# type: ignore[attr-defined]
from typing import Optional

from enum import Enum
from random import choice

import typer
from rich.console import Console

from conplex_dti import cli, version
from conplex_dti.example import hello


class Color(str, Enum):
    """Color names accepted by the ``--color`` option of the greeting command.

    Members also subclass ``str``, so each one compares equal to its value.
    Declaration order is the order ``list(Color)`` yields, which is the pool
    the greeting command draws from when no color is given.
    """

    white = "white"
    red = "red"
    cyan = "cyan"
    magenta = "magenta"
    yellow = "yellow"
    green = "green"


# Root Typer application that the commands in this module register on.
# ``add_completion=False`` is passed so Typer does not add its
# shell-completion options to the CLI.
app = typer.Typer(
    name="conplex-dti",
    help="Adapting protein language models and contrastive learning for DTI prediction.",
    add_completion=False,
)
# Module-wide rich console used by the commands for styled output.
console = Console()


def version_callback(print_version: bool) -> None:
    """Show the package version and stop the CLI when the flag is set.

    Args:
        print_version: Value of the version flag; any falsy value (including
            ``None`` when the flag is absent) makes this a no-op.

    Raises:
        typer.Exit: After the version line has been printed.
    """
    if not print_version:
        return

    banner = f"[yellow]conplex-dti[/] version: [bold blue]{version}[/]"
    console.print(banner)
    raise typer.Exit()


# Parameters (documented here rather than in the docstring, because Typer
# shows the docstring verbatim as the command's help text):
#   name          -- required; passed to ``hello`` to build the greeting.
#   color         -- optional ``Color``; a random member is chosen when omitted.
#   print_version -- eager flag handled entirely by ``version_callback``.
@app.command(name="")
def main(
    name: str = typer.Option(..., help="Person to greet."),
    color: Optional[Color] = typer.Option(
        None,
        "-c",
        "--color",
        "--colour",
        case_sensitive=False,
        help="Color for print. If not specified then choice will be random.",
    ),
    print_version: bool = typer.Option(
        None,
        "-v",
        "--version",
        callback=version_callback,
        is_eager=True,
        help="Prints the version of the conplex-dti package.",
    ),
) -> None:
    """Print a greeting with a given name."""
    if color is None:
        color = choice(list(Color))

    greeting: str = hello(name)
    # Interpolate ``color.value`` rather than the member itself: since
    # Python 3.11, formatting a ``(str, Enum)`` member in an f-string yields
    # "Color.red" instead of "red", which is not a color name rich knows.
    console.print(f"[bold {color.value}]{greeting}[/]")


class Task(str, Enum):
    """Task names accepted by the ``--task`` option of the ``train`` command."""

    biosnap = "biosnap"
    bindingdb = "bindingdb"
    davis = "davis"
    biosnap_prot = "biosnap_prot"
    biosnap_mol = "biosnap_mol"
    dti_dg = "dti_dg"


# The option list below is a one-to-one translation of the argparse
# ``parser.add_argument`` calls that previously sat inside the parameter list
# (a SyntaxError). Flag spellings, value types, defaults and help strings are
# preserved; each parameter name is the former ``dest``.
@app.command(name="train")
def train(
    experiment_id: str = typer.Option(..., "--exp-id", help="Experiment ID"),
    config: str = typer.Option(
        "configs/default_config.yaml", "--config", help="YAML config file"
    ),
    wandb_proj: Optional[str] = typer.Option(
        None,
        "--wandb-proj",
        help="Weights and Biases Project",
    ),
    task: Optional[Task] = typer.Option(
        None,
        "--task",
        help=(
            "Task name. Could be biosnap, bindingdb, davis, biosnap_prot, "
            "biosnap_mol, dti_dg."
        ),
    ),
    drug_featurizer: Optional[str] = typer.Option(
        None, "--drug-featurizer", help="Drug featurizer"
    ),
    target_featurizer: Optional[str] = typer.Option(
        None, "--target-featurizer", help="Target featurizer"
    ),
    distance_metric: Optional[str] = typer.Option(
        None,
        "--distance-metric",
        help="Distance in embedding space to supervise with",
    ),
    epochs: Optional[int] = typer.Option(
        None, "--epochs", help="number of total epochs to run"
    ),
    batch_size: Optional[int] = typer.Option(
        None, "-b", "--batch-size", help="batch size"
    ),
    lr: Optional[float] = typer.Option(
        None,
        "--lr",
        "--learning-rate",
        help="initial learning rate",
    ),
    # NOTE(review): this help text duplicates ``--lr``; presumably ``--clr``
    # is a separate learning rate -- confirm and reword.
    clr: Optional[float] = typer.Option(
        None, "--clr", help="initial learning rate"
    ),
    replicate: Optional[int] = typer.Option(
        None, "--r", "--replicate", help="Replicate"
    ),
    device: Optional[int] = typer.Option(
        None, "--d", "--device", help="CUDA device"
    ),
    verbosity: Optional[int] = typer.Option(
        None, "--verbosity", help="Level at which to log"
    ),
    checkpoint: Optional[str] = typer.Option(
        None, "--checkpoint", help="Model weights to start from"
    ),
) -> None:
    """Run the training command."""
    # NOTE(review): the call is kept exactly as it was (no arguments). The
    # signature of ``cli.train`` is not visible from this module, so the
    # options parsed above are not forwarded yet -- TODO wire them through
    # once the expected interface is confirmed.
    cli.train()


# Invoke the Typer application only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    app()
