#!python
import sys

import pydis


def separate_int_flag(value):
    """Return the members of ``type(value)`` that are set in ``value``.

    Every member of the value's own type is tested; members equal to zero are
    skipped, and the rest are kept when ``member in value`` holds. The result
    is a list in the iteration order of the type.
    """
    flag_type = type(value)
    return [member for member in flag_type if member != 0 and member in value]


def get_action_string(action):
    """Return the short label for an operand action.

    ``action`` is used directly as an index into the label table, so it must
    be usable as a sequence index; an index outside the table raises
    ``IndexError``.
    """
    labels = 'INV R W RW CR CW RCW CRW'.split()
    return labels[action]


def print_instruction(instruction):
    """Print a multi-section report describing ``instruction`` to stdout.

    The BASIC and DISASM sections are always printed. The OPERANDS section is
    printed when ``instruction.operands`` is truthy, and the AVX section when
    the encoding is XOP, VEX, EVEX or MVEX.

    Side effect: sets ``print_segment_registers`` and ``print_operand_sizes``
    to ``True`` on ``pydis.default_formatter``; they are not reset afterwards.
    """
    print('== [    BASIC ] ============================================================================================'
          '=')
    # The two f-strings are concatenated, so the separating space after the
    # comma has to be part of the first literal.
    print(f'   MNEMONIC: {instruction.mnemonic} [ENC: {instruction.encoding.name}, MAP: {instruction.opcode_map.name}, '
          f'OPC: {instruction.opcode:02X}]')
    print(f'     LENGTH: {instruction.length:2}')
    print(f'        SSZ: {instruction.stack_width:2}')
    print(f'       EOSZ: {instruction.operand_width:2}')
    print(f'       EASZ: {instruction.address_width:2}')
    print(f'   CATEGORY: {instruction.meta.category.name}')
    print(f'    ISA-SET: {instruction.meta.isa_set.name}')
    print(f'    ISA-EXT: {instruction.meta.isa_ext.name}')
    print(f' EXCEPTIONS: {instruction.meta.exception_class.name}')

    if instruction.attributes != pydis.InstructionAttribute.NoAttributes:
        print(' ATTRIBUTES: ', end='')
        for attribute in separate_int_flag(instruction.attributes):
            print(attribute.name, end=' ')
        print()  # end the current line

    if instruction.operands:
        print()
        print_operands(instruction)

    if instruction.encoding in (pydis.InstructionEncoding.XOP, pydis.InstructionEncoding.VEX,
                                pydis.InstructionEncoding.EVEX, pydis.InstructionEncoding.MVEX):
        print_avx_info(instruction)

    # Configure the formatter before the instruction is rendered as text below.
    # NOTE(review): presumably str(instruction) goes through default_formatter - confirm against pydis.
    pydis.default_formatter.print_segment_registers = True
    pydis.default_formatter.print_operand_sizes = True

    print()
    print('== [   DISASM ] ============================================================================================'
          '=')
    print(f'  {instruction}')


def print_operands(instruction):
    """Print the OPERANDS table for ``instruction`` to stdout.

    One row is printed per entry of ``instruction.operands``, followed by a
    separator rule. The trailing VALUE column depends on the operand type:

    * Register: the register, right-aligned.
    * Memory: six lines (TYPE, SEG, BASE, INDEX, SCALE, DISP).
    * Immediate: ``(S|U R|_ size) 0x<value>``, where S/U is signed/unsigned
      and R/_ is relative/not relative.
    * Anything else: nothing; the row is simply terminated.
    """
    # The value column is 16 hex digits wide; masking to 64 bits renders
    # negative numbers as two's complement instead of "0x-0000...".
    mask64 = 0xFFFFFFFFFFFFFFFF

    print('== [ OPERANDS ] ============================================================================================'
          '=')
    print('##       TYPE  VISIBILITY  ACTION      ENCODING   SIZE  NELEM  ELEMSZ   ELEMTYPE                        VALU'
          'E')
    print('--  ---------  ----------  ------  ------------   ----  -----  ------   --------  --------------------------'
          '-')

    for i, operand in enumerate(instruction.operands):
        # The common columns are left unterminated (end=''); every branch
        # below must finish the line.
        print(f'{i:02}  {operand.type.name:>9}  {operand.visibility.name:>10}  {get_action_string(operand.action):>6}  '
              f'{operand.encoding.name:>12}  {operand.size:5}   {operand.element_count:4}  {operand.element_size:7}  '
              f'{operand.element_type:>8}', end='')

        if operand.type == pydis.OperandType.Register:
            print(f'  {str(operand.register):>27}')
        elif operand.type == pydis.OperandType.Memory:
            displacement = operand.memory.displacement & mask64
            print(f'  TYPE  ={operand.memory.type.name:>20}')
            # The 85-wide label field lines the '=' up with the TYPE row above.
            print(f'  {"SEG  ":>85} ={str(operand.memory.segment):>20}')
            print(f'  {"BASE ":>85} ={str(operand.memory.base):>20}')
            print(f'  {"INDEX":>85} ={str(operand.memory.index):>20}')
            print(f'  {"SCALE":>85} ={str(operand.memory.scale):>20}')
            print(f'  {"DISP ":>85} =  0x{displacement:016X}')
        elif operand.type == pydis.OperandType.Immediate:
            # Print signed and unsigned immediates alike so the row is always
            # terminated and the value is never silently dropped.
            signedness = 'S' if operand.immediate.is_signed else 'U'
            relative = 'R' if operand.immediate.is_relative else '_'
            value = operand.immediate.value & mask64
            print(f'  ({signedness} {relative} {operand.immediate.size:2}) 0x{value:016X}')
        else:
            print()  # end the current line

        print('--  ---------  ----------  ------  ------------   ----  -----  ------   --------  ----------------------'
              '-----')


def print_avx_info(instruction):
    """Print the AVX section for ``instruction`` to stdout.

    The vector length and broadcast mode are printed for every encoding.
    EVEX adds ROUNDING, SAE and MASK rows; MVEX adds ROUNDING, SAE, MASK, EH,
    SWIZZLE and CONVERT rows; any other encoding gets a single blank line
    instead.
    """
    static = ' (static)' if instruction.avx.broadcast.is_static else ''
    print('== [      AVX ] ==========================================================================================='
          '=')
    print(f'  VECTORLEN: {instruction.avx.vector_length:03}')
    print(f'  BROADCAST: {instruction.avx.broadcast.mode.name}{static}')

    if instruction.encoding == pydis.InstructionEncoding.EVEX:
        sae = 'Y' if instruction.avx.has_sae else 'N'
        control_mask = ' (control-mask)' if instruction.avx.mask.is_control_mask else ''
        print(f'   ROUNDING: {instruction.avx.rounding.name}')
        print(f'        SAE: {sae}')
        # The format spec must sit inside the braces; outside them ":>5" is
        # printed literally and no padding is applied.
        print(f'       MASK: {instruction.avx.mask.register} [{instruction.avx.mask.mode.name:>5}]{control_mask}')
    elif instruction.encoding == pydis.InstructionEncoding.MVEX:
        sae = 'Y' if instruction.avx.has_sae else 'N'
        eviction_hint = 'Y' if instruction.avx.has_eviction_hint else 'N'
        print(f'   ROUNDING: {instruction.avx.rounding.name}')
        print(f'        SAE: {sae}')
        print(f'       MASK: {instruction.avx.mask.register} [MERGE]')
        print(f'         EH: {eviction_hint}')
        print(f'    SWIZZLE: {instruction.avx.swizzle}')
        print(f'    CONVERT: {instruction.avx.conversion}')
    else:
        print()


def print_usage():
    """Write the one-line command-line usage summary to stderr."""
    program = sys.argv[0]
    usage = 'Usage: {} -[real|16|32|64] [hexbytes]'.format(program)
    print(usage, file=sys.stderr)


def main():
    """Decode the instruction given on the command line and print its report.

    ``sys.argv[1]`` selects the machine mode (``-real``, ``-16``, ``-32`` or
    ``-64``); all remaining arguments are joined and parsed as hex bytes.

    Returns:
        ``pydis.Status.InvalidParameter`` when the arguments are missing or
        invalid, the input is too long, or decoding fails; ``None`` after a
        report has been printed. All error messages are written to stderr.
    """
    if len(sys.argv) < 3:
        print_usage()
        return pydis.Status.InvalidParameter

    if sys.argv[1] == '-real':
        mode = pydis.MachineMode.Real16
        width = pydis.AddressWidth.Width16
    elif sys.argv[1] == '-16':
        mode = pydis.MachineMode.LongCompat16
        width = pydis.AddressWidth.Width16
    elif sys.argv[1] == '-32':
        mode = pydis.MachineMode.LongCompat32
        width = pydis.AddressWidth.Width32
    elif sys.argv[1] == '-64':
        mode = pydis.MachineMode.Long64
        width = pydis.AddressWidth.Width64
    else:
        print_usage()
        return pydis.Status.InvalidParameter

    try:
        # Joining the arguments lets the bytes be passed as one token or several.
        input_bytes = bytes.fromhex(''.join(sys.argv[2:]))
    except ValueError:
        print('Invalid hex value', file=sys.stderr)
        return pydis.Status.InvalidParameter

    if len(input_bytes) > pydis.MaxInstructionLength:
        print(f'Maximum number of {pydis.MaxInstructionLength} bytes exceeded', file=sys.stderr)
        return pydis.Status.InvalidParameter

    try:
        # Only the first decoded instruction is reported.
        instruction = next(pydis.decode(input_bytes, mode=mode, address_width=width))
    except StopIteration:
        # next() raises StopIteration with an empty message when nothing is
        # yielded; report that explicitly instead of printing a blank reason.
        print('Failed to decode instruction: no instruction was decoded', file=sys.stderr)
        return pydis.Status.InvalidParameter
    except Exception as e:
        # Top-level boundary: report any decoder error on stderr, like the
        # other failure paths, rather than letting a traceback escape.
        print(f'Failed to decode instruction: {e}', file=sys.stderr)
        return pydis.Status.InvalidParameter

    print_instruction(instruction)


if __name__ == '__main__':
    # Raising SystemExit directly is what sys.exit() does; main()'s return
    # value becomes the exit status.
    raise SystemExit(main())
