#!/usr/bin/python3
from samson.utilities.cli import start_repl, HASHES, CURVES, PKI
import argparse
import sys


# Root parser. The name of the selected sub-command is stored in
# ``arguments.command`` (it stays ``None`` when no sub-command is given).
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')

# --- `hash` sub-command -----------------------------------------------------
# The description lists every key of HASHES so `hash -h` shows the valid types.
hash_parser = subparsers.add_parser(
    'hash',
    description=f"Available hash types: {', '.join(HASHES)}",
)

hash_parser.add_argument('type')
# Optional positional: the __main__ block falls back to stdin when it is absent.
hash_parser.add_argument('text', nargs='?')
# Comma-separated key=value pairs forwarded to the hash constructor.
hash_parser.add_argument('--args', nargs='?')

# --- `pki` sub-command ------------------------------------------------------
# RawDescriptionHelpFormatter keeps the blank line and line breaks of the
# multi-line description intact in the help output.
pki_parser = subparsers.add_parser(
    'pki',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description=f"""Available PKI types: {', '.join(PKI)}

Available curves: {', '.join(CURVES)}
""",
)

pki_parser.add_argument('action')
pki_parser.add_argument('type')
# Comma-separated key=value pairs forwarded to the PKI constructor.
pki_parser.add_argument('--args', nargs='?')
# Optional positional: key file to read for the `parse` action.
pki_parser.add_argument('filename', nargs='?')
# Export the public key instead of the private key for the `generate` action.
pki_parser.add_argument('--pub', action='store_true')
pki_parser.add_argument('--encoding', nargs='?')

# NOTE: the command line is parsed at module level (i.e. also on import); the
# __main__ block below reads the result from this module-level name.
arguments = parser.parse_args()

def try_parse(val):
    """Return ``int(val)`` when that conversion succeeds, else ``val`` itself.

    Only ``ValueError`` is handled; any other exception raised by ``int()``
    (for example ``TypeError`` for ``None``) propagates to the caller.
    """
    try:
        return int(val)
    except ValueError:
        # Not an integer literal: hand the value back untouched.
        return val


def _parse_kwargs(raw, convert):
    """Turn a ``--args`` string such as ``"a=1,b=2"`` into a keyword dict.

    Parameters:
        raw: Comma-separated ``key=value`` pairs, or ``None``/empty string
            when no extra arguments were supplied.
        convert: Callable applied to every value string (``int`` for hashes,
            ``try_parse`` for PKI types).

    Returns:
        A dict mapping each key to its converted value; empty when ``raw``
        is falsy.

    Exits through ``parser.error`` (usage message, ``SystemExit``) when an
    entry has no ``=`` or when ``convert`` raises ``ValueError``.
    """
    if not raw:
        return {}

    kwargs = {}
    for pair in raw.split(','):
        # Split on the first '=' only, so a value may itself contain '='.
        key, sep, value = pair.partition('=')
        if not sep:
            parser.error(f"malformed argument {pair!r}; expected key=value")

        try:
            kwargs[key] = convert(value)
        except ValueError:
            parser.error(f"invalid value for argument {key!r}: {value!r}")

    return kwargs


def _lookup(table, key, kind):
    """Return ``table[key]``, or exit with a usage error listing valid keys."""
    try:
        return table[key]
    except KeyError:
        available = ', '.join(map(str, table))
        parser.error(f"unknown {kind} {key!r}; available: {available}")


def _run_hash(args):
    """Hash the given text (or raw stdin) and print the digest as hex."""
    hash_cls = _lookup(HASHES, args.type.lower(), 'hash type')
    hash_obj = hash_cls(**_parse_kwargs(args.args, int))

    if args.text is None:
        # Read stdin as raw bytes: text mode would fail on non-UTF-8 input
        # and translate newlines, which changes the digest of piped files.
        data = sys.stdin.buffer.read()
    else:
        data = args.text.encode('utf-8')

    # NOTE(review): relies on the hash result's ``hex()`` returning a
    # bytes-like object that can be decoded to ``str`` -- as the original did.
    print(hash_obj.hash(data).hex().decode())


def _run_pki(args):
    """Generate and export a key, or parse and print an existing one."""
    if args.action not in ('generate', 'parse'):
        # Previously an unrecognised action silently did nothing.
        parser.error(f"unknown pki action {args.action!r}; available: generate, parse")

    pki_cls = _lookup(PKI, args.type.lower(), 'PKI type')
    ctor_kwargs = _parse_kwargs(args.args, try_parse)

    if "curve" in ctor_kwargs:
        # The constructor is given the curve's ``G`` attribute under the
        # keyword ``G`` rather than the curve name itself.
        ctor_kwargs["G"] = _lookup(CURVES, ctor_kwargs.pop("curve"), 'curve').G

    if args.action == 'generate':
        pki_obj = pki_cls(**ctor_kwargs)
        export = pki_obj.export_public_key if args.pub else pki_obj.export_private_key

        # Only forward ``encoding`` when the user supplied one, so the
        # library's own default applies otherwise.
        export_kwargs = {'encoding': args.encoding} if args.encoding else {}
        print(export(**export_kwargs).decode())
        return

    # action == 'parse'
    if args.filename:
        with open(args.filename, 'rb') as f:
            key_to_parse = f.read()
    else:
        # Raw bytes, matching the 'rb' file path above (binary keys such as
        # DER cannot be read through text-mode stdin).
        key_to_parse = sys.stdin.buffer.read()

    print(pki_cls.import_key(key_to_parse))


def main():
    """Entry point: start the REPL with no CLI arguments, else dispatch.

    Reads the module-level ``arguments`` namespace produced by
    ``parser.parse_args()``.
    """
    if len(sys.argv) == 1:
        start_repl()
    elif arguments.command == 'hash':
        _run_hash(arguments)
    elif arguments.command == 'pki':
        _run_pki(arguments)


if __name__ == '__main__':
    main()
    