#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from typing import Optional

import pandas as pd

from rich import box
from rich import print as rich_print
from rich.table import Table

from submititnow.jt import utils
from submititnow import cli
import typer


# Typer application object; the functions decorated with ``@app.command``
# in this module are registered on it as sub-commands.
app = typer.Typer()

# Custom rich box style that draws horizontal rules only.
# The string holds eight lines of exactly four characters each, so the
# leading/trailing spaces are significant and must not be stripped.
# Per rich's ``box.Box`` format the lines are, in order: top, head, head_row,
# mid, row, foot_row, foot, bottom. Here the top, head_row, foot_row and
# bottom lines use a heavy rule, the row separator uses a light rule, and
# every edge/divider position is a space, so no vertical borders are drawn.
CUSTOM_HORIZONTALS: box.Box = box.Box(
    """\
 ━━ 
    
 ━━ 
    
 ── 
 ━━ 
    
 ━━ 
"""
)


def df_to_table(
    df: pd.DataFrame,
    rich_table: Table,
    show_index: bool = True,
    index_name: Optional[str] = None,
) -> Table:
    """Populate a rich Table with the contents of a pandas DataFrame.

    Most of the logic taken from:
    https://gist.github.com/neelabalan/33ab34cf65b43e305c3f12ec6db05938#file-df_to_table-py-L12

    Consecutive rows that share the same "Exp ID" value are shown as one
    group: a section separator is inserted whenever the "Exp ID" changes, and
    the "Exp ID" / "Exp Info" cells are blanked on every row of a group except
    the first. If the DataFrame has no "Exp ID" column, rows are added as-is
    without any grouping.

    Args:
        df (DataFrame): A Pandas DataFrame to be converted to a rich Table.
        rich_table (Table): A rich Table that should be populated by the DataFrame values.
        show_index (bool): Add a column with a row count to the table. Defaults to True.
        index_name (str, optional): The column name to give to the index column.
            Defaults to None, showing no value.

    Returns:
        Table: The rich Table instance passed, populated with the DataFrame values.
    """
    # Number of leading cells in each rendered row that do not come from `df`.
    offset = 1 if show_index else 0
    if show_index:
        rich_table.add_column(str(index_name) if index_name else "")

    styles = {
        "Exp ID": "bold dark_blue",
        "Job ID": "bold bright_blue",
        "Job Description": "rosy_brown",
        "Exp Info": "rosy_brown",
    }
    for column in df.columns:
        rich_table.add_column(str(column), style=styles.get(column))

    group_by_exp = "Exp ID" in df.columns
    # `get_indexer` reports -1 for labels that are absent; drop those so a
    # missing column never blanks the last cell. Shift the rest by `offset`
    # so they address the rendered row, which may start with the index cell.
    blank_positions = [
        offset + int(position)
        for position in df.columns.get_indexer(["Exp ID", "Exp Info"])
        if position >= 0
    ]

    previous_exp_id = None
    for index, (_, df_row) in enumerate(df.iterrows()):
        row = [str(index)] if show_index else []
        row += [str(value) for value in df_row]
        if group_by_exp:
            exp_id = df_row["Exp ID"]
            if index > 0:
                if exp_id != previous_exp_id:
                    # New experiment: close the previous group before adding the row.
                    rich_table.add_section()
                else:
                    # Same experiment as the row above: avoid repeating its details.
                    for position in blank_positions:
                        row[position] = ""
            previous_exp_id = exp_id
        rich_table.add_row(*row)

    return rich_table


def stylish_job_status(msg: str):
    """Wrap a job status message in rich markup chosen by its prefix.

    The message is matched against a fixed list of status prefixes, in
    order; the first prefix that matches selects the style. Messages that
    match none of them get a fallback style.

    Args:
        msg (str): The job status text.

    Returns:
        str: ``msg`` enclosed in ``[style]...[/style]`` markup tags.
    """
    prefix_styles = (
        ("PENDING", "bold yellow"),
        ("RUNNING", "bold bright_green"),
        ("FAILED: Out Of Memory", "bold dark_red"),
        ("FAILED: Triggered", "bold red3"),
        ("CANCELLED", "bold indian_red1"),
        ("COMPLETED", "bold green4"),
    )
    chosen = next(
        (style for prefix, style in prefix_styles if msg.startswith(prefix)),
        "bold medium_violet_red",
    )
    return "[{0}]{1}[/{0}]".format(chosen, msg)


@app.command(name="jobs", help="Show the status of all jobs within an experiment.")
def display_jobs_info(
    exp_name: str = typer.Argument(..., help="The name of the experiment."),
    exp_id: Optional[int] = typer.Argument(None, help="The experiment ID."),
    max_rows: int = typer.Option(
        default=20, help="Max number of rows to display in reverse chronological order."
    ),
):
    """Print a dashboard table with the job statuses of an experiment.

    Args:
        exp_name: Name of the experiment; used to build ``utils.JTExp``.
        exp_id: Optional experiment ID, forwarded to ``prepare_job_states_df``
            and, when given, shown in the table title.
        max_rows: Forwarded to ``prepare_job_states_df`` as the row limit.
    """
    exp = utils.JTExp(exp_name)
    df = exp.prepare_job_states_df(max_rows, exp_id)
    # Colour each status by wrapping it in rich markup.
    df["Job Status"] = df["Job Status"].apply(stylish_job_status)

    # Initiate a Table instance to be modified
    table_title = f":test_tube: [bold yellow]Experiment Dashboard for [hot_pink]{exp_name}[/hot_pink]"
    # Compare against None rather than truthiness so an explicitly passed
    # ID of 0 is still reflected in the title.
    if exp_id is not None:
        table_title = (
            f"{table_title} [bold yellow]with ID [hot_pink]{exp_id}[/hot_pink]"
        )
    table = Table(
        show_header=True,
        header_style="bold bright_white",
        highlight=True,
        title=table_title,
    )

    # Modify the table instance to have the data from the DataFrame
    table = df_to_table(df, table, show_index=False)

    table.box = CUSTOM_HORIZONTALS
    print()
    rich_print(table)


@app.command(name="err", help="Show the stderr log of a job")
def show_stderr(job_id: str):
    """Display the file that ``utils.get_job_filepath`` returns for the job's "err" kind."""
    cli.show_file_content(utils.get_job_filepath(job_id, "err"))


@app.command(name="out", help="Show the stdout log of a job")
def show_stdout(job_id: str):
    """Display the file that ``utils.get_job_filepath`` returns for the job's "out" kind."""
    cli.show_file_content(utils.get_job_filepath(job_id, "out"))


@app.command(name="sh", help="Show the SLURM sbatch shell script of a job")
def show_submission_sh(job_id: str):
    """Display the file that ``utils.get_job_filepath`` returns for the job's "sh" kind."""
    cli.show_file_content(utils.get_job_filepath(job_id, "sh"))


@app.command(name="ls", help="List all experiments.")
def list_experiments():
    """Print a one-column table of experiments, most recently modified first.

    Entries of ``utils.EXPERIMENTS_ROOT_DIR`` are kept only when
    ``utils.JTExp(name).exists()`` is true, and are ordered by the
    modification time of each experiment's ``exp_dir``, newest first.
    """

    def modified_at(name):
        # Sort key: modification time of the experiment's directory.
        return os.path.getmtime(utils.JTExp(name).exp_dir)

    existing = [
        entry
        for entry in os.listdir(utils.EXPERIMENTS_ROOT_DIR)
        if utils.JTExp(entry).exists()
    ]
    existing.sort(key=modified_at, reverse=True)

    table = Table(show_header=True, header_style="bold magenta", highlight=True)
    table.box = box.HEAVY_EDGE
    table.add_column(
        ":test_tube: [bold yellow]Experiments", justify="center", style="turquoise2"
    )
    for entry in existing:
        table.add_row(entry)

    print()
    rich_print(table)


# Run the Typer CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()
