from __future__ import annotations

from dataclasses import MISSING
from dataclasses import Field
from itertools import count
from pathlib import Path
from textwrap import indent
from types import UnionType
from typing import Any
from typing import Optional
from typing import Type
from typing import Union
from typing import get_args
from typing import get_origin
from typing import get_type_hints

from yaml import DocumentEndEvent
from yaml import DocumentStartEvent
from yaml import Event
from yaml import MappingEndEvent
from yaml import MappingStartEvent
from yaml import ScalarEvent
from yaml import SequenceEndEvent
from yaml import SequenceStartEvent
from yaml import StreamEndEvent
from yaml import StreamStartEvent
from yaml import parse



def is_dataclass(type_: Type) -> bool:
    """Tell whether ``type_`` exposes a ``__dataclass_fields__`` attribute."""
    try:
        getattr(type_, "__dataclass_fields__")
    except AttributeError:
        return False
    return True


def has_default(field: Field) -> bool:
    """Report whether ``field`` declares a ``default`` or a ``default_factory``."""
    # Only a field lacking both is required.
    return not (field.default is MISSING and field.default_factory is MISSING)


def get_default(field: Field) -> Any:
    """Produce the default value for ``field``.

    A plain ``default`` wins; otherwise ``default_factory`` is called to build
    a fresh value. The field must have one of the two (checked by assert).
    """
    if field.default is not MISSING:
        return field.default
    assert field.default_factory is not MISSING
    return field.default_factory()


def type_to_str(type_: Type) -> str:
    if type_ == list and hasattr(type_, "__args__"):
        return f"list[{type_to_str(type_.__args__[0])}]"
    elif type_ == dict and hasattr(type_, "__args__"):
        return (
            f"dict[{type_to_str(type_.__args__[0])}, {type_to_str(type_.__args__[1])}]"
        )

    return type_.__name__


def load(path: Path, type_: Type) -> Any:
    """Read the YAML file at ``path`` and convert its content to ``type_``.

    Supported types are basic types such as int, float, str, bool, list, dict
    and tuple, unions of them (``A | B``), and dataclasses.

    When the YAML notation differs from the desired result, set the
    ``yaml_type`` and ``yaml_convert`` keys in the ``metadata`` of
    ``dataclasses.field`` for a dataclass field: the YAML value is parsed and
    validated as ``yaml_type``, and the field receives the result of passing
    that value to ``yaml_convert``.
    """
    # Read raw bytes so PyYAML detects the encoding itself (UTF-8, or UTF-16
    # with a BOM) instead of depending on the platform's locale encoding.
    with open(path, "rb") as f:
        data = f.read()
    events = list(parse(data))
    return _load(events, 0, type_)[1]


# Scalar spellings accepted as booleans when a bool is expected.
TRUE_STRINGS = {"1", "on", "true", "yes"}
FALSE_STRINGS = {"0", "false", "no", "off"}


def _load(event_iter: list[Event], ind: int, type_: Type) -> tuple[int, Any]:
    if isinstance(type_, UnionType):
        args = getattr(type_, "__args__", None)
        assert args is not None
        errors_lst = []
        for type_i in args:
            try:
                return _load(event_iter, ind, type_i)
            except errors.YamlError as e:
                errors_lst.append(e)
        else:
            error_txt_lst = [
                f"Failed to parse yaml. Expected one of {list(map(type_to_str, args))}",
                "Errors for each type:",
            ]
            for type_i, error in zip(args, errors_lst):
                error_txt_lst.append(f"Type: {type_to_str(type_i)}")
                error_txt_lst.append(indent(str(error), " " * 4))
            raise errors.YamlError("\n".join(error_txt_lst))

    e = event_iter[ind]
    ind += 1
    match e:
        case StreamStartEvent():
            if isinstance(event_iter[ind], StreamEndEvent):
                raise errors.YamlTypeError(event_iter[ind].start_mark, "Empty yaml")
            ind, res = _load(event_iter, ind, type_)
            e = event_iter[ind]
            assert isinstance(e, StreamEndEvent)
            ind += 1
        case DocumentStartEvent():
            if isinstance(event_iter[ind], DocumentEndEvent):
                raise errors.YamlTypeError(event_iter[ind].start_mark, "Empty yaml")
            ind, res = _load(event_iter, ind, type_)
            e = event_iter[ind]
            assert isinstance(e, DocumentEndEvent)
            ind += 1
        case MappingStartEvent():
            if not (type_ == dict or is_dataclass(type_)):
                raise errors.YamlTypeError(
                    e.start_mark,
                    f"expected {type_to_str(type_)}, mapping found",
                )

            if type_ == dict:
                args = getattr(type_, "__args__", None)
                assert args is not None
                key_type = args[0]
                value_type = args[1]

                res = {}
                while True:
                    e = event_iter[ind]
                    if isinstance(e, MappingEndEvent):
                        break

                    ind, key = _load(event_iter, ind, key_type)
                    ind, value = _load(event_iter, ind, value_type)
                    res[key] = value

            else:
                # Does not allow for different order of fields
                fields: Optional[dict[str, Field]] = getattr(
                    type_, "__dataclass_fields__", None
                )
                assert fields is not None

                fields_dct = {}
                map_start_evnet = e

                while True:
                    e = event_iter[ind]
                    if isinstance(e, MappingEndEvent):
                        break

                    ind, key = _load(event_iter, ind, str)

                    if key not in fields:
                        raise errors.YamlTypeError(
                            e.start_mark,
                            f"unexpected field {key}",
                        )

                    field = fields[key]

                    if "yaml_type" in field.metadata:
                        field_type = field.metadata["yaml_type"]
                    else:
                        field_type = field.type

                    ind, value = _load(event_iter, ind, field_type)
                    if "yaml_convert" in field.metadata:
                        value = field.metadata["yaml_convert"](value)
                    fields_dct[key] = value

                for name, field in fields.items():
                    if name not in fields_dct:
                        if has_default(field):
                            fields_dct[name] = get_default(field)
                        else:
                            raise errors.YamlTypeError(
                                map_start_evnet.start_mark,
                                f"missing field {name}",
                            )

                res = type_(**fields_dct)

            ind += 1

        case SequenceStartEvent():
            if not (type_ == list or type_ == tuple):
                raise errors.YamlTypeError(
                    e.start_mark,
                    f"expected {type_to_str(type_)}, sequence found",
                )

            args = getattr(type_, "__args__", None)
            assert args is not None

            if type_ == list:
                value_type = args[0]

                res = []
                while True:
                    e = event_iter[ind]
                    if isinstance(e, SequenceEndEvent):
                        break

                    ind, value = _load(event_iter, ind, value_type)
                    res.append(value)

            else:
                res = []
                infinite_length = False
                last_type: Optional[Type] = None
                for i in count():
                    e = event_iter[ind]

                    if (not infinite_length) and args[i] == Ellipsis:
                        assert (
                            last_type is not None
                        ), "Ellipsis must be preceded by a type"
                        infinite_length = True

                    if isinstance(e, SequenceEndEvent):
                        if (not infinite_length) and i != len(args):
                            raise errors.YamlTypeError(
                                e.start_mark,
                                f"expected {len(args)} elements, {i} found",
                            )
                        break

                    if infinite_length:
                        assert last_type is not None
                        ind, value = _load(event_iter, ind, last_type)
                        res.append(value)
                    else:
                        ind, value = _load(event_iter, ind, args[i])
                        res.append(value)
                        last_type = args[i]

            ind += 1

        case ScalarEvent():
            value = getattr(e, "value", None)
            assert value is not None
            if type_ == int:
                try:
                    res = int(value)
                except ValueError:
                    raise errors.YamlTypeError(
                        e.start_mark,
                        f"expected int, {value!r} found",
                    )
            elif type_ == float:
                try:
                    res = float(value)
                except ValueError:
                    raise errors.YamlTypeError(
                        e.start_mark,
                        f"expected float, {value!r} found",
                    )
            elif type_ == str:
                res = value
            elif type_ == bool:
                if value in TRUE_STRINGS:
                    res = True
                elif value in FALSE_STRINGS:
                    res = False
                else:
                    raise errors.YamlTypeError(
                        e.start_mark,
                        f"expected bool, {value!r} found",
                    )
            else:
                raise errors.YamlTypeError(
                    e.start_mark,
                    f"expected {type_to_str(type_)}, scalar found",
                )

        case _:
            assert False, f"unreachable, event: {e}"

    return ind, res
