#!python

import asyncio
import argparse
from mt import cv, np
from mt.base import filetype

def _alpha_quadrants(bgr, alpha):
    '''Builds the 2x2 tiled preview used for images carrying an alpha channel.

    Layout of the returned (2h, 2w, 3) uint8 canvas:

    - top-left: colour premultiplied by alpha (colour * alpha / 255, rounded)
    - top-right: the alpha plane, written to channel 0 only
    - bottom-left: the colour channels as given
    - bottom-right: left black

    Parameters
    ----------
    bgr : array of shape (h, w, 3)
        Colour channels, already in B, G, R order.
    alpha : array of shape (h, w)
        Alpha plane.
    '''
    h, w = alpha.shape[:2]
    canvas = np.zeros((h*2, w*2, 3), dtype=np.uint8)

    canvas[h:, :w] = bgr
    canvas[:h, w:, 0] = alpha

    # Same arithmetic as a per-channel multiply, done once via broadcasting.
    weighted = bgr.astype(float)*alpha.astype(float)[..., np.newaxis]/255
    canvas[:h, :w] = np.round(weighted).astype(np.uint8)
    return canvas


def get_image(imm):
    '''Produces an image for display using OpenCV.

    Parameters
    ----------
    imm : object
        Anything exposing `pixel_format` (str) and `image` (numpy array).

    Returns
    -------
    numpy array
        - 'gray' / 'bgr': `imm.image` itself, untouched.
        - 'rgb': a contiguous copy with the last axis reversed (i.e. BGR).
        - 'rgba' / 'bgra': a (2h, 2w, 3) uint8 tiled preview, see
          `_alpha_quadrants`.

    Raises
    ------
    ValueError
        If the pixel format is none of the above.
    '''
    fmt = imm.pixel_format

    if fmt in ['gray', 'bgr']:
        return imm.image

    if fmt == 'rgb':
        return np.ascontiguousarray(np.flip(imm.image, axis=-1))

    if fmt == 'rgba':
        # Channels 2, 1, 0 of RGBA are B, G, R.
        return _alpha_quadrants(imm.image[:, :, 2::-1], imm.image[:, :, 3])

    if fmt == 'bgra':
        return _alpha_quadrants(imm.image[:, :, :3], imm.image[:, :, 3])

    raise ValueError("Imm with pixel format '{}' is not supported.".format(imm.pixel_format))

def view(image, max_width=640, as_ansi=True):
    '''Displays a BGR or single-channel image.

    Parameters
    ----------
    image : numpy array
        Image of shape (h, w), (h, w, 1) or (h, w, C) in OpenCV channel order.
    max_width : int
        Images wider than this are downscaled to this width, keeping the
        aspect ratio.
    as_ansi : bool
        If True, the image is converted to RGB and printed to the terminal via
        `cv.to_ansi`. Otherwise it is shown in a highgui window until a key is
        pressed.
    '''
    if max_width < image.shape[1]:
        # Clamp to 1: integer division can hit 0 for very wide, short images,
        # and a zero-sized target makes the resize fail.
        height = max(1, image.shape[0]*max_width//image.shape[1])
        image = cv.resize(image, dsize=(max_width, height))
    if as_ansi:
        # BGR2RGB rejects single-channel input, so grayscale images need the
        # dedicated conversion code to become 3-channel RGB.
        is_gray = image.ndim == 2 or image.shape[2] == 1
        code = cv.COLOR_GRAY2RGB if is_gray else cv.COLOR_BGR2RGB
        img2 = cv.cvtColor(image, code)
        print(cv.to_ansi(img2))
    else:
        cv.namedWindow('image')
        print("Press any key to exit.")
        cv.imshow('image', image)
        cv.waitKey(0)

async def main(args, context_vars: dict = {}):
    '''Loads the file named by `args.imm_file`, prints a summary and displays it.

    Files ending in '.imm' are loaded with `cv.immload_asyn` and their pixel
    format and metadata are printed. Any other file is treated as a plain
    image: its type is taken from the extension for '.jp2', '.jpg' and '.jpeg'
    and from `filetype.image_match_asyn` otherwise.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide `imm_file` (str), `max_width` (int) and `use_highgui`
        (bool).
    context_vars : dict
        Forwarded unchanged to the loaders. It is never mutated here, so the
        shared default is harmless.
    '''
    if args.imm_file.endswith('.imm'):
        imm = await cv.immload_asyn(args.imm_file, context_vars=context_vars)
        print("Image path: {}".format(args.imm_file))
        print("Pixel format: {}".format(imm.pixel_format))
        print("Resolution: {}x{}".format(imm.image.shape[1], imm.image.shape[0]))
        print("Meta:")
        print(imm.meta)
        view(get_image(imm), max_width=args.max_width, as_ansi=not args.use_highgui)
        return

    lowered = args.imm_file.lower()
    if lowered.endswith('.jp2'):
        image_type = 'jpx'
    elif lowered.endswith(('.jpg', '.jpeg')):
        image_type = 'jpg'
    else:
        image_type_detail = await filetype.image_match_asyn(args.imm_file, context_vars=context_vars)
        image_type = None if image_type_detail is None else image_type_detail.extension
    if image_type is None:
        print("Not an image file: {}".format(args.imm_file))
        return

    print("Image path: {}".format(args.imm_file))
    print("File type: {}".format(image_type))
    image = await cv.imload(args.imm_file, flags=cv.IMREAD_UNCHANGED, context_vars=context_vars)
    if image is None:
        # NOTE(review): presumably the loader returns None on undecodable
        # data, as OpenCV's decoders do -- guard instead of crashing on .shape.
        print("Unable to decode image: {}".format(args.imm_file))
        return

    # Single-channel images may come back as 2-D arrays, which have no
    # shape[2]; treat them as one channel instead of raising IndexError.
    num_channels = 1 if image.ndim == 2 else image.shape[2]
    if image_type.startswith('jp'):
        pixel_format = {4: 'bgra', 3: 'bgr', 1: 'gray'}.get(num_channels)
    elif image_type == 'png':
        # NOTE(review): PNGs are labelled RGB(A) here while JPEGs are BGR(A);
        # confirm this matches the channel order cv.imload actually returns.
        pixel_format = {4: 'rgba', 3: 'rgb', 1: 'gray'}.get(num_channels)
    else:
        pixel_format = None

    print("Resolution: {}x{}".format(image.shape[1], image.shape[0]))
    if pixel_format is not None:
        imm = cv.Image(image, pixel_format=pixel_format)
        image = get_image(imm)
    view(image, max_width=args.max_width, as_ansi=not args.use_highgui)

if __name__ == '__main__':
    # Command-line entry point: parse the options, then run the async viewer.
    parser = argparse.ArgumentParser(
        description="Tool to view an image with metadata (IMM) file.",
    )
    parser.add_argument(
        '--max_width',
        type=int,
        default=640,
        help="The maximum width to view. Default is 640.",
    )
    parser.add_argument(
        '-X', '--use_highgui',
        action='store_true',
        help="Uses OpenCV's highgui module to display the image.",
    )
    parser.add_argument(
        'imm_file',
        type=str,
        help="The file to view.",
    )
    args = parser.parse_args()

    # The 'async' flag is passed through to main() as its context_vars.
    coroutine = main(args, context_vars={'async': True})
    asyncio.run(coroutine)
