#!/usr/bin/env python3

import os
import shutil
import yaml
import edflow


def create_edflow_project(project_name, replace: bool = False, **kwargs):
    """
    Create a fresh project directory filled with edflow compatible skeleton files.

    The model, dataset and iterator templates shipped in edflow's
    ``edsetup_files`` folder are copied into the new project, and the template
    ``config.yaml`` is written next to them with its ``model``, ``dataset`` and
    ``iterator`` entries pointing at the copied files.

    Parameters
    ----------
    project_name (str) : Name of your new project. This will also be the name of the parent folder.
    replace (bool) : If True, an already existing ``project_name`` is DELETED (including
        everything inside it) and created from scratch. If False, an existing
        ``project_name`` raises an error.
    kwargs : optional file paths for the config, model, dataset and iterator classes as
        <class>_path="path to file", e.g. config_path="config/config.yaml". The paths are
        relative to (i.e. EXCLUDING) the project directory; missing sub-directories
        are created.

    Raises
    ------
    IsADirectoryError : If ``project_name`` already exists and ``replace`` is False.
    """
    if os.path.exists(project_name):
        if not replace:
            raise IsADirectoryError(
                f"Directory: {project_name} already exists. Provide a different name for the project "
                "OR set replace to True"
            )
        # Replacing means starting from scratch: remove the old project first,
        # otherwise creating the directory below would fail.
        if os.path.isdir(project_name) and not os.path.islink(project_name):
            shutil.rmtree(project_name)
        else:
            os.remove(project_name)
    os.makedirs(project_name)

    # Folder inside the installed edflow package that holds the templates.
    parent_source = os.path.join(os.path.dirname(edflow.__file__), "edsetup_files")

    destination_config = os.path.join(
        project_name, kwargs.get("config_path", "config.yaml")
    )
    source_config = os.path.join(parent_source, "config.yaml")

    # Read-only access is enough; the packaged template is never modified.
    with open(source_config, "r") as source_config_file:
        source_config_dict = yaml.load(source_config_file, Loader=yaml.FullLoader)
    if source_config_dict is None:
        # An empty template yields None; fall back to an empty config.
        source_config_dict = {}

    # (config key, kwargs key, template file name)
    skeleton_files = (
        ("model", "model_path", "model.py"),
        ("dataset", "dataset_path", "dataset.py"),
        ("iterator", "iterator_path", "iterator.py"),
    )

    training_parameters_dict = {}
    for config_key, path_key, default in skeleton_files:
        destination_path = os.path.join(project_name, kwargs.get(path_key, default))
        # User supplied paths may point into sub-directories that do not exist yet.
        os.makedirs(os.path.dirname(destination_path), exist_ok=True)
        shutil.copy(os.path.join(parent_source, default), destination_path)

        # Turn "project/sub/model.py" into the import path "project.sub.model":
        # drop only the file extension and split on real path separators.
        path_without_extension, _ = os.path.splitext(destination_path)
        module_parts = [
            part
            for part in os.path.normpath(path_without_extension).split(os.sep)
            if part not in ("", ".")
        ]
        # The class name is the capitalized file name, e.g. model.py -> Model.
        # NOTE(review): the copied template presumably only defines a class of
        # that name when the default file name is kept - confirm for custom names.
        class_name = module_parts[-1].capitalize()
        training_parameters_dict[config_key] = ".".join(module_parts + [class_name])

    source_config_dict.update(training_parameters_dict)

    os.makedirs(os.path.dirname(destination_config), exist_ok=True)
    with open(destination_config, "w") as new_config_file:
        yaml.dump(source_config_dict, new_config_file)


if __name__ == "__main__":
    # Command line entry point: `-n <project_name>` selects the folder to
    # create; without it the project is called "edflow_project".
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-n",
        type=str,
        default="edflow_project",
        metavar="project_name",
    )
    options = cli.parse_args()

    create_edflow_project(options.n)
