#!/usr/bin/env python
import ast
import argparse
import pprint

import libem
from libem.match.parameter import tools


def main():
    """Command-line entry point: match two entities and print the result.

    Parses the two positional entities together with the calibration,
    shortcut and output options, hands the collected settings to
    ``libem.calibrate``, runs ``libem.match`` inside ``libem.trace`` and
    prints the outcome. When ``--stats`` is given, the trace statistics
    are pretty-printed afterwards.
    """
    cli = argparse.ArgumentParser(description="Libem CLI tool")

    def add_switch(*flags, **spec):
        # Every on/off option is a store_true flag that defaults to False.
        cli.add_argument(*flags, action='store_true', default=False, **spec)

    cli.add_argument('e1', type=str, help='First entity')
    cli.add_argument('e2', type=str, help='Second entity')

    # TODO: consolidate and reuse the calibration options and parsing
    #  with the ones in benchmark/run.py and benchmark/util.py

    # Generic calibration: arbitrary KEY=VALUE parameter overrides.
    cli.add_argument(
        '-c', '--calibrate',
        metavar="KEY=VALUE",
        type=parse_key_value_pair,
        nargs='+',
        help="Key-value pairs separated by '=', "
             "e.g., key1=value1 key2=value2",
    )

    # Shortcuts for commonly calibrated parameters.
    cli.add_argument(
        "-m", "--model",
        dest='model',
        nargs='?',
        type=str,
        default=libem.parameter.model(),
        help="The LLM to use.",
    )
    add_switch("--browse", dest='browse', help="Enable the browse tool.")
    add_switch("--cot", dest='cot', help="Enable chain of thought.")
    add_switch("--confidence", dest='confidence',
               help="Report confidence score.")
    add_switch("-g", "--guess", dest='guess', help="Match by guessing.")

    # Output options.
    add_switch("-d", "--debug", dest='debug', help="Enable debug mode.")
    add_switch("-s", "--stats", dest='stats',
               help="Print stats from trace.")

    opts = cli.parse_args()

    # Start from the explicit KEY=VALUE pairs (if any) and append the
    # shortcut settings after them, so a shortcut wins over a duplicate
    # key once the pairs are turned into a dict.
    overrides = opts.calibrate or []
    if opts.model:
        overrides.append(('libem.match.parameter.model', opts.model))
    if opts.cot:
        overrides.append(('libem.match.parameter.cot', True))
    if opts.confidence:
        overrides.append(('libem.match.parameter.confidence', True))
    if opts.browse:
        # Extend the current tool list rather than replacing it.
        overrides.append(
            ('libem.match.parameter.tools', tools() + ['libem.browse'])
        )
    if opts.guess:
        overrides.append(('libem.match.parameter.guess', True))

    if opts.debug:
        libem.debug_on()

    if overrides:
        libem.calibrate(dict(overrides))

    with libem.trace as trace:
        outcome = libem.match(opts.e1, opts.e2)
    print("Match result:", outcome)

    if opts.stats:
        # Keep the stats in their original key order instead of sorting.
        pprint.pprint(trace.stats(), sort_dicts=False)


def parse_key_value_pair(arg_value):
    """Parse a ``KEY=VALUE`` command-line token into a ``(key, value)`` pair.

    The token is split on the first ``=`` only, so the value part may
    itself contain ``=``. The value is interpreted as a Python literal
    (number, list, dict, bool, ...) when possible; otherwise it is kept
    unchanged as a string.

    Args:
        arg_value: Raw argument string, e.g. ``"key=[1, 2]"``.

    Returns:
        A ``(key, value)`` tuple where ``key`` is the text before the first
        ``=`` and ``value`` is either the evaluated literal or the raw text
        after it.

    Raises:
        argparse.ArgumentTypeError: If ``arg_value`` contains no ``=``.
    """
    key, separator, raw_value = arg_value.partition('=')
    if not separator:
        raise argparse.ArgumentTypeError("Key-value pairs must be separated by '='")
    try:
        # Safely evaluate the value part to handle data structures like lists.
        value = ast.literal_eval(raw_value)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        # literal_eval is documented to raise any of these on malformed
        # input (e.g. TypeError for an unhashable dict key such as
        # "{[1]: 2}"); in every such case keep the value as a plain string.
        value = raw_value
    return key, value


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
