#!python
"""Script for computing the incremental difference between two coverage reports."""
import argparse
import math
import os
import re
import shutil
import subprocess
from collections import defaultdict
from difflib import SequenceMatcher
from enum import Enum, unique
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple

import numpy as np
from ansiwrap import ansilen
from scipy.optimize import linear_sum_assignment


def git_revision_short_hash() -> str:
    """Return the abbreviated git hash of ``HEAD`` for the current directory.

    Returns:
        The output of ``git rev-parse --short HEAD``, or ``"unknown"`` when git
        is not installed, the command fails (e.g. not inside a repository), or
        its output is not ASCII.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            # Keep git's own "fatal: ..." message out of the report output.
            stderr=subprocess.DEVNULL,
        )
        return output.decode("ascii").strip()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # Best effort: a missing hash must never abort report generation, but
        # unlike a bare `except:` this still lets KeyboardInterrupt through.
        return "unknown"


def tan_score(perc: float) -> float:
    """Map a coverage percentage onto a tangent scale.

    The input is first clamped to ``[0, 100]``. The result increases
    monotonically from 0.0 (at 0%) to roughly 6.955 (at 100%), growing
    steeply near the top, so changes close to full coverage weigh more.
    """
    clamped = min(100, max(0, perc))
    fraction = clamped / 110
    return math.tan(fraction * math.pi / 2)


@unique
class ANSICodes(str, Enum):
    """ANSI escape sequences for coloring and styling terminal text.

    Members are interpolated directly into f-strings throughout this script.
    Since Python 3.11, ``format()`` of a ``(str, Enum)`` member yields the
    member name (``ANSICodes.RED``) rather than its value, so ``__str__`` and
    ``__format__`` are overridden to always produce the raw escape sequence.
    """

    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    ENDC = "\033[0m"  # resets all colors/styles

    def __str__(self) -> str:
        """Return the raw escape sequence."""
        return self.value

    def __format__(self, format_spec: str) -> str:
        """Format the raw escape sequence (used by f-strings and ``format``)."""
        return format(self.value, format_spec)


@unique
class AbsoluteGradeThresholds(float, Enum):
    """Thresholds of the absolute coverage into grade levels.

    Each value is a minimum coverage percentage. A coverage value gets the
    highest level whose threshold it meets or exceeds (checked with ``>=``
    from ``A_PLUS`` down to ``D``); anything below ``D`` is the failing level.
    The ``float`` mix-in lets members be compared directly against floats.
    """

    A_PLUS = 97.0
    A = 92.0
    B = 83.0
    C = 74.0
    D = 65.0


@unique
class ChangeThresholds(float, Enum):
    """Thresholds of the incremental coverage change into grade levels.

    Values are compared (with ``>=``, from ``A_PLUS`` down to ``D``) against
    ``tan_score(new) - tan_score(old)``, not against raw percentage points.
    Negative thresholds mean a small regression can still earn ``C`` or ``D``;
    anything below ``D`` is the failing level.
    """

    A_PLUS = 1.0
    A = 0.5
    B = 0.0  # no regression on the tangent scale
    C = -0.333
    D = -0.666


class FuncCoverage:
    """A summary of the coverage information for a single function.

    Attributes:
        path: path of the source file defining the function.
        line_num: line number at which the function is defined.
        func_name: the function name as it appears in the report.
        coverage_perc: coverage of the function, in percent.
    """

    # Matches report lines such as
    #   github.com/org/repo/pkg/file.go:42:	DoThing		85.7%
    # The path is matched lazily up to the first ".go:<line>:" so that paths
    # containing dots (e.g. "github.com/...") are accepted.
    _GOLANG_LINE_RE = re.compile(
        r"(?P<path>.+?\.go):(?P<line_num>\d+):\s+(?P<func_name>\w+)\s+"
        r"(?P<coverage_perc>\d+(?:\.\d+)?)%"
    )

    def __init__(
        self, path: str, line_num: int, func_name: str, coverage_perc: float
    ) -> None:
        """Create a new function coverage statement."""
        self.path = path
        self.line_num = line_num
        self.func_name = func_name
        self.coverage_perc = coverage_perc

    @property
    def full_name(self) -> str:
        """Access the qualified name of the function as ``path:func_name``."""
        return f"{self.path}:{self.func_name}"

    @staticmethod
    def golang_coverage_line(line: str) -> Optional["FuncCoverage"]:
        """Parse `line` as a golang function coverage report line.

        Returns:
            The parsed statement, or None if `line` does not match the
            expected ``<path>.go:<line>:<ws><func><ws><perc>%`` format (e.g.
            the ``total:`` summary line).
        """
        match = FuncCoverage._GOLANG_LINE_RE.match(line)
        if match is None:
            return None
        return FuncCoverage(
            path=match.group("path"),
            line_num=int(match.group("line_num")),
            func_name=match.group("func_name"),
            coverage_perc=float(match.group("coverage_perc")),
        )

    def cost(self, other: "FuncCoverage") -> float:
        """Compute a heuristic matching cost between this statement and `other`.

        The cost is the string edit dissimilarity of the two full names,
        scaled to [0, 100], plus the absolute difference in line numbers
        divided by 1000. The line-number term therefore acts mainly as a
        tie-breaker between functions with the same full name in a file. The
        cost is 0.0 exactly when both the full names and line numbers agree.
        """
        # SequenceMatcher.ratio() is in [0, 1] with 1 meaning identical; turn
        # it into a cost and scale it by 100.
        ret = (
            1.0 - SequenceMatcher(None, self.full_name, other.full_name).ratio()
        ) * 100

        # Add the line-number difference scaled by 1 / 1000 as a tie-breaker.
        ret += abs(self.line_num - other.line_num) / 1000
        return ret

    def color_abs_coverage(self, is_markdown: bool = False) -> str:
        """Format the coverage percentage, colored by its grade.

        Args:
            is_markdown: if True, return plain ``"<perc>%"`` without ANSI codes.
        """
        cov = self.coverage_perc
        text = f"{cov:.1f}%"
        if is_markdown:
            return text
        if cov >= AbsoluteGradeThresholds.A_PLUS:
            color = ANSICodes.GREEN
        elif cov >= AbsoluteGradeThresholds.A:
            color = ANSICodes.CYAN
        elif cov >= AbsoluteGradeThresholds.B:
            color = ANSICodes.BLUE
        elif cov >= AbsoluteGradeThresholds.C:
            color = ANSICodes.YELLOW
        elif cov >= AbsoluteGradeThresholds.D:
            color = ANSICodes.MAGENTA
        else:
            color = ANSICodes.RED
        # Use `.value` so the raw escape code is emitted regardless of how the
        # running Python version formats `(str, Enum)` members.
        return f"{color.value}{text}{ANSICodes.ENDC.value}"


def load_coverage(
    file_name: str, parse_fn: Callable[[str], Optional["FuncCoverage"]]
) -> List["FuncCoverage"]:
    """Load a function coverage report from `file_name` using `parse_fn`.

    Args:
        file_name: the name of the file containing the coverage report.
        parse_fn: the function used to parse each line of the coverage report;
            it returns None for lines that are not coverage statements.

    Returns:
        All successfully parsed coverage statements, in file order.
    """
    # Read explicitly as UTF-8 rather than the platform-dependent default, and
    # stream the lines instead of materializing them twice.
    with open(file_name, "r", encoding="utf-8") as file:
        parsed = (parse_fn(line) for line in file)
        return [cov for cov in parsed if cov is not None]


def color_change(
    old_cover: FuncCoverage, new_cover: FuncCoverage, is_markdown: bool = False
) -> str:
    """Make a string representing the change in coverage percentage.

    The change is always shown with an explicit sign and one decimal place
    (e.g. ``+2.5%`` or ``-14.3%``). Unless `is_markdown` is set, it is wrapped
    in an ANSI color graded by the change in `tan_score`.

    Args:
        old_cover: the coverage before the change.
        new_cover: the coverage after the change.
        is_markdown: if True, return plain text without ANSI escape codes.
    """
    old = old_cover.coverage_perc
    new = new_cover.coverage_perc
    # Previously the negative branch skipped the `.1f` formatting and printed
    # raw float noise such as "-14.299999999999997%".
    delta = f"{new - old:+.1f}%"
    if is_markdown:
        return delta

    # The tangent transform gives more weight to changes as coverage
    # approaches 100%.
    score_change = tan_score(new) - tan_score(old)
    if score_change >= ChangeThresholds.A_PLUS:
        color = ANSICodes.GREEN
    elif score_change >= ChangeThresholds.A:
        color = ANSICodes.CYAN
    elif score_change >= ChangeThresholds.B:
        color = ANSICodes.BLUE
    elif score_change >= ChangeThresholds.C:
        color = ANSICodes.YELLOW
    elif score_change >= ChangeThresholds.D:
        color = ANSICodes.MAGENTA
    else:
        color = ANSICodes.RED
    # `.value` emits the raw escape code on every Python version.
    return f"{color.value}{delta}{ANSICodes.ENDC.value}"


class Assignment:
    """A matching between the function coverage statements of two reports.

    Attributes:
        matches: ``(old, new)`` pairs judged to be the same code unit whose
            coverage changed by at least the percent-change threshold.
        deleted: old statements not matched to any new statement.
        added: new statements not matched to any old statement.
    """

    def __init__(self) -> None:
        """Create an empty assignment."""
        self.matches: List[Tuple[FuncCoverage, FuncCoverage]] = []
        self.deleted: List[FuncCoverage] = []
        self.added: List[FuncCoverage] = []

    @staticmethod
    def _coverage_changed(
        old: "FuncCoverage", new: "FuncCoverage", perc_change_threshold: float
    ) -> bool:
        """Return True if `old` and `new` differ by at least the threshold."""
        return abs(new.coverage_perc - old.coverage_perc) >= perc_change_threshold

    @staticmethod
    def compute_assignment(
        old_cover: List["FuncCoverage"],
        new_cover: List["FuncCoverage"],
        match_threshold: float,
        perc_change_threshold: float,
    ) -> "Assignment":
        """Compute an assignment between `old_cover` and `new_cover`.

        Identical statements (same full name and line number) are paired
        first. All remaining statements are then paired by a least-cost
        linear assignment over `FuncCoverage.cost`; see
        https://en.wikipedia.org/wiki/Assignment_problem .

        Args:
            old_cover: the list of old function coverage statements.
            new_cover: the list of new function coverage statements.
            match_threshold: the maximum cost for two statements to match.
            perc_change_threshold: the minimum absolute change in coverage
                percentage for a matched pair to be reported in `matches`.

        Returns:
            The assignment between the coverage of `old_cover` and `new_cover`.
        """
        assign = Assignment()

        # Phase 1: exact matches. `FuncCoverage.cost` is 0 exactly when the
        # full names and line numbers agree, so these pairs can be found by key
        # lookup instead of evaluating the (expensive) cost for every pair.
        # Identical lines are always paired even if a different global
        # assignment might have lower cost, since they are the same code unit.
        new_by_key = defaultdict(list)
        for j, new in enumerate(new_cover):
            new_by_key[(new.full_name, new.line_num)].append(j)
        for indices in new_by_key.values():
            # Reverse so `pop()` yields the earliest still-unmatched index.
            indices.reverse()

        exact_old: Set[int] = set()
        exact_new: Set[int] = set()
        for i, old in enumerate(old_cover):
            candidates = new_by_key.get((old.full_name, old.line_num))
            if not candidates:
                continue
            j = candidates.pop()
            exact_old.add(i)
            exact_new.add(j)
            new = new_cover[j]
            if Assignment._coverage_changed(old, new, perc_change_threshold):
                assign.matches.append((old, new))
        old_cover = [old for i, old in enumerate(old_cover) if i not in exact_old]
        new_cover = [new for j, new in enumerate(new_cover) if j not in exact_new]

        # Phase 2: least-cost assignment over whatever is left on both sides.
        old_matched: Set[int] = set()
        new_matched: Set[int] = set()
        if old_cover and new_cover:
            # row/col indices returned by linear_sum_assignment have equal length.
            costs = np.array(
                [[old.cost(new) for new in new_cover] for old in old_cover]
            )
            rows, cols = linear_sum_assignment(costs)
            for row, col in zip(rows.tolist(), cols.tolist()):
                # An assigned pair only counts as a match within the threshold.
                if costs[row, col] > match_threshold:
                    continue
                old_matched.add(row)
                new_matched.add(col)
                old, new = old_cover[row], new_cover[col]
                if Assignment._coverage_changed(old, new, perc_change_threshold):
                    assign.matches.append((old, new))

        # Unmatched old statements were deleted; unmatched new ones were added.
        assign.deleted.extend(
            cov for i, cov in enumerate(old_cover) if i not in old_matched
        )
        assign.added.extend(
            cov for j, cov in enumerate(new_cover) if j not in new_matched
        )
        return assign

    def as_terminal_report(self) -> str:
        """Render this assignment as ANSI-styled text for the terminal.

        Sections (changes, deletions, additions) appear only when non-empty
        and are separated by a blank line.
        """
        bold, reset = ANSICodes.BOLD.value, ANSICodes.ENDC.value
        report: List[List[str]] = []
        if self.matches:
            subreport = [f"{bold}COVERAGE CHANGES{reset}:"]
            subreport.extend(
                fmt_func_coverage(old_cov=o, new_cov=n) for o, n in self.matches
            )
            report.append(subreport)

        if self.deleted:
            subreport = [f"{bold}DELETED{reset}:"]
            subreport.extend(fmt_func_coverage(old_cov=o) for o in self.deleted)
            report.append(subreport)

        if self.added:
            subreport = [f"{bold}ADDED{reset}:"]
            subreport.extend(fmt_func_coverage(new_cov=n) for n in self.added)
            report.append(subreport)

        return "\n\n".join("\n".join(sr) for sr in report)

    def as_markdown_report(self) -> str:
        """Render this assignment as a markdown report with a score card."""
        report: List[List[str]] = []
        if self.matches:
            subreport = [
                "## Coverage Changes in Existing Code:",
                "",
                "| code unit | old coverage | new coverage | coverage change |",
                "| :-- | :--: | :--: | --: |",
            ]
            subreport.extend(
                func_coverage_markdown(old_cov=o, new_cov=n) for o, n in self.matches
            )
            report.append(subreport)

        if self.deleted:
            subreport = [
                "## Deleted Code:",
                "",
                "| code unit | coverage |",
                "| :-- | :--: |",
            ]
            subreport.extend(func_coverage_markdown(old_cov=o) for o in self.deleted)
            report.append(subreport)

        if self.added:
            subreport = [
                "## New Code:",
                "",
                "| code unit | coverage |",
                "| :-- | :--: |",
            ]
            subreport.extend(func_coverage_markdown(new_cov=n) for n in self.added)
            report.append(subreport)

        report_text = "\n\n".join("\n".join(sr) for sr in report)
        score_card = ScoreCard(self)
        return f"""# Incremental Golang Coverage (commit {git_revision_short_hash()})

{score_card}

{report_text if report_text else "No changes in Golang coverage!"}"""


@unique
class Grade(str, Enum):
    """Grades for coverage, declared from best (``A_PLUS``) to worst (``F``).

    Members are interpolated into the markdown score card. Since Python 3.11,
    ``format()`` of a ``(str, Enum)`` member yields ``Grade.A_PLUS`` rather than
    its value, so ``__str__`` and ``__format__`` return the emoji value.
    """

    A_PLUS = "💯"
    A = "✅✅✅"
    B = "✅"
    C = "🫤"
    D = "⛔"
    F = "⛔⛔⛔"

    def __str__(self) -> str:
        """Return the grade's emoji representation."""
        return self.value

    def __format__(self, format_spec: str) -> str:
        """Format the emoji value (used by f-strings and ``format``)."""
        return format(self.value, format_spec)

    def rank(self) -> int:
        """Rank grades by declaration order: ``A_PLUS`` is highest, ``F`` is 1."""
        members = list(Grade)
        return len(members) - members.index(self)

    @staticmethod
    def absolute_grade(coverage_perc: float) -> "Grade":
        """Compute the grade for the absolute coverage percentage of a code unit."""
        if coverage_perc >= AbsoluteGradeThresholds.A_PLUS:
            return Grade.A_PLUS
        elif coverage_perc >= AbsoluteGradeThresholds.A:
            return Grade.A
        elif coverage_perc >= AbsoluteGradeThresholds.B:
            return Grade.B
        elif coverage_perc >= AbsoluteGradeThresholds.C:
            return Grade.C
        elif coverage_perc >= AbsoluteGradeThresholds.D:
            return Grade.D
        else:
            return Grade.F

    @staticmethod
    def relative_grade(score_change: float) -> "Grade":
        """Compute the grade for a change in `tan_score` of a code unit."""
        if score_change >= ChangeThresholds.A_PLUS:
            return Grade.A_PLUS
        elif score_change >= ChangeThresholds.A:
            return Grade.A
        elif score_change >= ChangeThresholds.B:
            return Grade.B
        elif score_change >= ChangeThresholds.C:
            return Grade.C
        elif score_change >= ChangeThresholds.D:
            return Grade.D
        else:
            return Grade.F


class ScoreCard:
    """An overall assessment of the incremental code coverage."""

    # One-line verdict shown for each possible overall (i.e. worst) grade.
    _ASSESSMENTS = {
        Grade.A_PLUS: "You are a coverage beast!",
        Grade.A: "Amazing coverage!",
        Grade.B: "Good job with covering code.",
        Grade.C: "This PR has reasonable coverage but could be improved.",
        Grade.D: "This PR may require coverage improvements; please review.",
        Grade.F: "The PR appears to have major coverage deficiencies; please review.",
    }

    def __init__(self, assign: Assignment) -> None:
        """Create a new score card from an assignment.

        Changed units are scored on both their absolute coverage and their
        coverage change; new units only on their absolute coverage.
        """
        self.scores: List[Tuple[Grade, str]] = []
        if assign.matches:
            self.scores.append(ScoreCard._score_change_absolute(assign.matches))
            self.scores.append(ScoreCard._score_change_relative(assign.matches))
        if assign.added:
            self.scores.append(ScoreCard._score_new_absolute(assign.added))

    def __str__(self) -> str:
        """Pretty print the score card; empty if nothing was scored."""
        if not self.scores:
            return ""
        # The overall grade is the worst (lowest-ranked) individual grade.
        overall_grade = min((grade for grade, _ in self.scores), key=lambda g: g.rank())
        assess = ScoreCard._ASSESSMENTS[overall_grade]
        # `.value` renders the emoji regardless of how this Python version
        # formats `(str, Enum)` members.
        breakdown = "\n".join(
            f"  * {grade.value}  {description}" for grade, description in self.scores
        )
        return f"{assess}\n\nGrade breakdown:\n{breakdown}"

    @staticmethod
    def _score_change_absolute(
        changed_cov: List[Tuple[FuncCoverage, FuncCoverage]]
    ) -> Tuple[Grade, str]:
        """Grade the lowest new coverage among the changed units."""
        worst_cov = min(new.coverage_perc for _, new in changed_cov)
        return (
            Grade.absolute_grade(worst_cov),
            f"Changed code had least coverage of {worst_cov:.1f}%",
        )

    @staticmethod
    def _score_change_relative(
        changed_cov: List[Tuple[FuncCoverage, FuncCoverage]]
    ) -> Tuple[Grade, str]:
        """Grade the worst change in `tan_score` among the changed units."""
        # Tuples compare by score first; ties fall back to the percentages.
        worst_score, worst_old, worst_new = min(
            (
                tan_score(new.coverage_perc) - tan_score(old.coverage_perc),
                old.coverage_perc,
                new.coverage_perc,
            )
            for old, new in changed_cov
        )
        return (
            Grade.relative_grade(worst_score),
            f"Changed code had worst coverage change of {worst_old:.1f}% -> {worst_new:.1f}%",
        )

    @staticmethod
    def _score_new_absolute(added_cov: List[FuncCoverage]) -> Tuple[Grade, str]:
        """Grade the lowest coverage among the newly added units."""
        worst_cov = min(c.coverage_perc for c in added_cov)
        return (
            Grade.absolute_grade(worst_cov),
            f"New code had least coverage of {worst_cov:.1f}%",
        )


def fmt_coverage_change(
    old_cov: Optional[FuncCoverage] = None,
    new_cov: Optional[FuncCoverage] = None,
) -> str:
    """Build the right-hand coverage column for a terminal report line.

    The column looks like ``<change>  [ <old>  ->  <new> ]``: the change part
    and the arrow appear only when both statements are given. The bracketed
    section is padded to a fixed visible width of 16, with the padding split
    around the arrow.
    """
    width = 16
    both = old_cov is not None and new_cov is not None
    change_str = f"{color_change(old_cov, new_cov)}  " if both else ""
    arrow = "->" if both else ""
    old_str = "" if old_cov is None else old_cov.color_abs_coverage()
    new_str = "" if new_cov is None else new_cov.color_abs_coverage()

    padding = " " * (width - ansilen(old_str) - ansilen(arrow) - ansilen(new_str))
    half = len(padding) // 2
    centered = padding[:half] + arrow + padding[half:]
    return f"{change_str}[ {old_str}{centered}{new_str} ]  "


def fmt_func_coverage(
    old_cov: Optional[FuncCoverage] = None,
    new_cov: Optional[FuncCoverage] = None,
) -> str:
    """Format one terminal report line for `old_cov` and/or `new_cov`.

    The underlined function name(s) go on the left (``old -> new`` when the
    name changed), and the coverage summary is right-aligned to the terminal
    width.
    """
    old_name = old_cov.full_name if old_cov is not None else None
    new_name = new_cov.full_name if new_cov is not None else None
    names = [old_name, new_name] if old_name != new_name else [new_name]
    underline, reset = ANSICodes.UNDERLINE.value, ANSICodes.ENDC.value
    left = "  " + " -> ".join(f"{underline}{n}{reset}" for n in names if n)
    right = fmt_coverage_change(old_cov, new_cov)

    # os.get_terminal_size() raises OSError when stdout is not a terminal
    # (output piped to a file, CI logs); shutil falls back to $COLUMNS or 80.
    cols = shutil.get_terminal_size().columns
    fill = " " * (cols - ansilen(left) - ansilen(right))
    return f"{left}{fill}{right}"

def coverage_change_markdown(
    old_cov: Optional[FuncCoverage] = None,
    new_cov: Optional[FuncCoverage] = None,
) -> Tuple[str, str, str]:
    """Compute the plain-text markdown cells for a coverage comparison.

    Returns:
        ``(change, old, new)``. `change` is the signed change followed by two
        spaces, or empty unless both statements are given. `old`/`new` are the
        coverage percentages, or empty when the matching statement is absent.
    """
    if old_cov is not None and new_cov is not None:
        change_str = f"{color_change(old_cov, new_cov, is_markdown=True)}  "
    else:
        change_str = ""

    old_str = "" if old_cov is None else old_cov.color_abs_coverage(is_markdown=True)
    new_str = "" if new_cov is None else new_cov.color_abs_coverage(is_markdown=True)
    return change_str, old_str, new_str

def func_coverage_markdown(
    old_cov: Optional[FuncCoverage] = None,
    new_cov: Optional[FuncCoverage] = None,
) -> str:
    """Format a markdown table row comparing `old_cov` and `new_cov`.

    Cells are the italicized name(s), the struck-through old coverage, the
    new coverage, and the change. Empty cells are dropped, so rows for
    deleted or added code have two columns.
    """
    old_name = old_cov.full_name if old_cov else None
    new_name = new_cov.full_name if new_cov else None
    if old_name != new_name:
        candidates = [old_name, new_name]
    else:
        candidates = [new_name]
    code_unit = " -> ".join(f"*{name}*" for name in candidates if name)

    change, old_perc, new_perc = coverage_change_markdown(old_cov, new_cov)
    if old_perc:
        old_perc = f"~~{old_perc}~~"

    cells = [cell for cell in (code_unit, old_perc, new_perc, change) if cell]
    return "| " + " | ".join(cells) + " |"


@unique
class OutputFormat(Enum):
    """Enumerates the possible formats for the coverage report.

    Each value names the `Assignment` method that renders that format. Plain
    strings are used rather than ``functools.partial`` objects because
    ``partial`` becomes a method descriptor in Python 3.14 (with a
    FutureWarning in 3.13), and then it would no longer be an enum member.
    """

    TERMINAL = "as_terminal_report"
    MARKDOWN = "as_markdown_report"

    @staticmethod
    def parse(s: str) -> "OutputFormat":
        """Parse the (case-insensitive) name `s` into its format.

        Raises:
            ValueError: if `s` names no format. Unlike KeyError, argparse
                turns this into a normal usage error when `parse` is used as
                an argument ``type``.
        """
        try:
            return OutputFormat[s.upper()]
        except KeyError:
            choices = ", ".join(fmt.name for fmt in OutputFormat)
            raise ValueError(
                f"unknown output format {s!r} (choose from {choices})"
            ) from None

    def __call__(self, assignment: "Assignment") -> str:
        """Render `assignment` in this format and return the report text."""
        return getattr(assignment, self.value)()

    def __str__(self) -> str:
        """Make a string representation of this value (its name)."""
        return self.name


def hash_hint(filepath: str) -> str:
    """Extract a 40-hex-digit git commit hash embedded in a file's name.

    Only the final path component is searched; directories are ignored.
    Returns the first such run, or an empty string if none is present.
    """
    found = re.search(r"([0-9a-fA-F]{40,40})", Path(filepath).name)
    return found.group(1) if found else ""


def fmt_no_coverage(
    sum_of_old: Dict[str, float],
    sum_of_new: Dict[str, float],
    output_format: OutputFormat,
) -> None:
    """Print how the number of files with no coverage changed.

    Args:
        sum_of_old: per-file sum of function coverage percentages in the old
            report (keyed by file path).
        sum_of_new: the same sums for the new report.
        output_format: ANSI styling is applied only for ``OutputFormat.TERMINAL``.
    """
    # A file counts as uncovered when its summed function coverage is below 1%.
    no_cov_old = [path for path, total in sum_of_old.items() if total < 1]
    no_cov_new = [path for path, total in sum_of_new.items() if total < 1]
    no_cov_diff = len(no_cov_new) - len(no_cov_old)

    # Plain text unless printing to a terminal. Previously the reset code was
    # emitted even in non-terminal (e.g. markdown) output.
    bold = color = end = ""
    if output_format == OutputFormat.TERMINAL:
        bold = ANSICodes.BOLD.value
        end = ANSICodes.ENDC.value
        # Green when the number of uncovered files did not grow.
        color = (ANSICodes.GREEN if no_cov_diff <= 0 else ANSICodes.RED).value

    print(f"No. files with no coverage old: {bold}{len(no_cov_old)}{end}")
    print(f"No. files with no coverage new: {color}{bold}{len(no_cov_new)}{end}")
    print(f"Change in files without cover : {color}{bold}{no_cov_diff:>3}{end}")

def parse_args() -> argparse.Namespace:
    """Parse the arguments from the CLI.

    Returns:
        A namespace with ``old_coverage_file``, ``new_coverage_file``,
        ``match_threshold``, ``perc_change_threshold``, ``output_format`` and
        ``show_no_coverage``.
    """
    parser = argparse.ArgumentParser(
        description="A script that computes incremental coverage reports.",
    )
    parser.add_argument("old_coverage_file", help="The original coverage report.")
    parser.add_argument("new_coverage_file", help="The new coverage report.")
    parser.add_argument(
        "--match-threshold",
        # The matching cost is a float, so fractional thresholds are allowed;
        # integer inputs such as "20" are still accepted.
        type=float,
        default=20,
        help="Threshold for determining a match in [0, 100] (default: 20).",
    )
    parser.add_argument(
        "--perc-change-threshold",
        type=float,
        default=0.1,
        help="Percent threshold to consider coverage different (default: 0.1).",
    )
    parser.add_argument(
        "--output-format",
        type=OutputFormat.parse,
        default=OutputFormat.TERMINAL,
        choices=list(OutputFormat),
        help="Selects the format for the coverage report (default: TERMINAL).",
    )
    parser.add_argument(
        "--show-no-coverage",
        action="store_true",
        help="Display the number of files with no coverage at all.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Parse both golang function-coverage reports (old first, then new).
    old_coverage, new_coverage = (
        load_coverage(report_path, FuncCoverage.golang_coverage_line)
        for report_path in (args.old_coverage_file, args.new_coverage_file)
    )

    # Match functions across the two reports and print the report.
    assign = Assignment.compute_assignment(
        old_coverage, new_coverage, args.match_threshold, args.perc_change_threshold
    )
    print(args.output_format(assign))

    if args.show_no_coverage:
        # Sum the function coverage percentages of every file in each report.
        sum_of_old = defaultdict(float)
        sum_of_new = defaultdict(float)
        for totals, report in ((sum_of_old, old_coverage), (sum_of_new, new_coverage)):
            for cov in report:
                totals[cov.path] += cov.coverage_perc

        hash_old = hash_hint(args.old_coverage_file)
        hash_new = hash_hint(args.new_coverage_file)
        hash_cmp = f"{hash_old}..{hash_new}" if hash_old and hash_new else ""
        print(f"\nNO COVERAGE: {hash_cmp}")

        fmt_no_coverage(sum_of_old, sum_of_new, args.output_format)
