#!/usr/bin/env python3
import argparse
from enum import auto, Enum
import sys
import warnings

import numpy as np

import motion3d as m3d


class KittiFrame(Enum):
    """
    Coordinate frames of the KITTI odometry setup that can be selected as origin.

    The member names are what the command line accepts for ``--frame``; the
    integer values only serve to tell the members apart.
    """

    # Cameras 0 to 3
    CAM0 = 1
    CAM1 = 2
    CAM2 = 3
    CAM3 = 4
    # Velodyne laser scanner
    VELO = 5


def read_calib_file(filepath):
    """
    Read in a calibration file and parse into a dictionary.

    Every line of the form ``key: v0 v1 ...`` becomes one entry that maps the
    key (surrounding whitespace removed) to a 1-D float array of its values.
    Lines without a ``:`` separator (e.g. blank lines) and lines whose values
    are not all floats are skipped.

    Note: The implementation is adopted from pykitti.utils.read_calib_file():
          https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py

    Args:
        filepath: Path of the calibration text file.

    Returns:
        dict mapping each key to a numpy array of floats.
    """
    data = {}
    with open(filepath, 'r') as f:
        for line in f:
            key, separator, value = line.partition(':')
            if not separator:
                # Blank or malformed line: there is no "key: value" pair to parse
                continue
            # The only non-float values in these files are dates, which
            # we don't care about anyway
            try:
                numbers = [float(x) for x in value.split()]
            except ValueError:
                continue
            data[key.strip()] = np.array(numbers)
    return data


def load_kitti_odometry_calib(calib_filename, frame=None):
    """
    Load a KITTI odometry calibration file and return the transform for a frame.

    Note: The implementation is adopted from pykitti.odometry._get_file_lists():
          https://github.com/utiasSTARS/pykitti/blob/master/pykitti/odometry.py

    Args:
        calib_filename: Calibration file providing the entries 'P0' to 'P3'
            and 'Tr' (12 values each).
        frame: KittiFrame member selecting the coordinate frame.

    Returns:
        For a camera frame, the inverse of the matrix transform between that
        camera and the velodyne; for KittiFrame.VELO, a default-constructed
        m3d.MatrixTransform.

    Raises:
        NotImplementedError: If frame is not a supported KittiFrame member
            (this includes the default None).
    """
    calib = read_calib_file(calib_filename)

    # 3x4 projection matrices of all four cameras. P0 is not used below, but it
    # is still looked up and reshaped so that a broken entry is reported.
    projections = [np.reshape(calib[key], (3, 4)) for key in ('P0', 'P1', 'P2', 'P3')]

    def rectified_offset(projection):
        # Rectified extrinsics from cam0 to camN: identity plus a shift along x
        offset = np.eye(4)
        offset[0, 3] = projection[0, 3] / projection[0, 0]
        return offset

    offsets = [rectified_offset(projection) for projection in projections[1:]]

    # Homogeneous 4x4 matrix built from the 'Tr' entry of the file.
    # NOTE(review): in pykitti this matrix maps velodyne coordinates into the
    # rectified cam0 frame - confirm the direction convention motion3d expects.
    cam0_velo = np.vstack([np.reshape(calib['Tr'], (3, 4)), [0, 0, 0, 1]])

    # Matrix of every camera frame, chained through cam0
    camera_matrices = [
        (KittiFrame.CAM0, cam0_velo),
        (KittiFrame.CAM1, offsets[0].dot(cam0_velo)),
        (KittiFrame.CAM2, offsets[1].dot(cam0_velo)),
        (KittiFrame.CAM3, offsets[2].dot(cam0_velo)),
    ]

    # Select coordinate frame
    for candidate, matrix in camera_matrices:
        if frame == candidate:
            return m3d.MatrixTransform(matrix, unsafe=True).inverse()
    if frame == KittiFrame.VELO:
        return m3d.MatrixTransform()
    raise NotImplementedError(f"Coordinate frame '{frame}' not implemented")


def convert_kitti_odometry_poses(poses_filename, output_filename, times_filename=None,
                                 calib_filename=None, frame=None,
                                 binary=False, precision=16, normalize=False):
    """
    Convert a KITTI odometry poses file into a motion3d file.

    Args:
        poses_filename: Text file with one pose per line (12 values each).
        output_filename: Path of the motion3d file to write.
        times_filename: Optional text file with one timestamp per pose. The
            values are multiplied by 1e9 to obtain nanoseconds, i.e. they are
            interpreted as seconds.
        calib_filename: Optional calibration file used to set the origin.
            It is only evaluated if frame is given as well; otherwise a
            warning is issued.
        frame: KittiFrame member selecting the origin frame; its lower-case
            name is stored as frame id.
        binary: Write the binary instead of the ASCII file type.
        precision: Precision passed on to the writer.
        normalize: Call normalized_() on every transform before storing it.

    Raises:
        ValueError: If the poses do not have 12 columns or the number of
            timestamps differs from the number of poses.

    Exits the process with status 1 if writing the output fails.
    """
    # read data; ndmin keeps the arrays indexable by row even if a file
    # contains only a single line
    poses_data = np.loadtxt(poses_filename, ndmin=2)
    times_data = np.loadtxt(times_filename, ndmin=1) if times_filename is not None else None

    # check size (explicit exceptions, since asserts vanish with "python -O")
    if poses_data.shape[1] != 12:
        raise ValueError(
            f"Expected 12 values per pose, got {poses_data.shape[1]} in '{poses_filename}'")
    if times_data is not None and poses_data.shape[0] != times_data.shape[0]:
        raise ValueError(
            f"Number of timestamps ({times_data.shape[0]}) does not match "
            f"number of poses ({poses_data.shape[0]})")

    # create container
    tdata = m3d.TransformContainer(has_stamps=times_data is not None, has_poses=True)
    for row, pose in enumerate(poses_data):
        transform = m3d.MatrixTransform(pose, unsafe=True)
        if normalize:
            transform.normalized_()
        if times_data is None:
            tdata.append(transform)
        else:
            # round instead of truncating, so that float errors of the
            # multiplication cannot shift the stamp by one nanosecond
            timestamp = m3d.Time.FromNSec(np.uint64(np.rint(times_data[row] * 1e9)))
            tdata.append(timestamp, transform)

    # create motion data
    motion_data = m3d.MotionData(m3d.TransformType.kMatrix, tdata)

    # add origin
    if calib_filename is not None:
        if frame is None:
            warnings.warn("No frame given, cannot load calibration.")
        else:
            calib = load_kitti_odometry_calib(calib_filename, frame)
            motion_data.setOrigin(calib)
            motion_data.setFrameId(frame.name.lower())

    # write and check
    file_type = m3d.M3DFileType.kBinary if binary else m3d.M3DFileType.kASCII
    status = m3d.M3DWriter.write(output_filename, motion_data, file_type, precision=precision)
    if status != m3d.M3DIOStatus.kSuccess:
        print(f"Error while writing: {m3d.M3DIOStatus(status)} ({status})")
        sys.exit(1)


class ParseEnum(argparse.Action):
    """
    argparse action that stores the enum member matching the given name.

    The enum class is handed over through the extra ``enum_type`` keyword of
    ``add_argument()``. Its member names become the accepted choices and the
    default is None unless specified otherwise.
    """

    def __init__(self, option_strings, enum_type, *args, **kwargs):
        self._enum_type = enum_type
        # The choices are always the member names, whatever the caller passed
        kwargs['choices'] = [member.name for member in enum_type]
        kwargs.setdefault('default', None)
        super().__init__(*args, option_strings=option_strings, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Only the first entry of a list or tuple of values is considered
        raw = values[0] if isinstance(values, (list, tuple)) else values
        name = str(raw)
        try:
            member = self._enum_type[name]
        except KeyError:
            parser.error("Input {} is not a field of enum {}".format(values, self._enum_type))
        else:
            setattr(namespace, self.dest, member)


def main():
    """Parse the command line and run the KITTI odometry pose conversion."""
    cli = argparse.ArgumentParser(description="Convert KITTI odometry poses.")

    # Positional input and output files
    cli.add_argument('poses', type=str, help="Input poses (*.txt)")
    cli.add_argument('output', type=str, help="Output motion data (*.m3d)")

    # Optional input files
    optional_files = (
        ('--times', "Timestamps file (*.txt)"),
        ('--calib', "Calibration file (*.txt)"),
    )
    for flag, description in optional_files:
        cli.add_argument(flag, type=str, default=None, help=description)

    # Origin frame and output settings
    cli.add_argument('--frame', action=ParseEnum, enum_type=KittiFrame, default=None,
                     help="Coordinate frame (usually CAM0)")
    cli.add_argument('--binary', action='store_true', help="Write binary")
    cli.add_argument('--precision', type=int, default=16, help="ASCII precision")
    cli.add_argument('--normalize', action='store_true',
                     help="Normalize transformations before saving")

    options = cli.parse_args()
    convert_kitti_odometry_poses(
        options.poses,
        options.output,
        times_filename=options.times,
        calib_filename=options.calib,
        frame=options.frame,
        binary=options.binary,
        precision=options.precision,
        normalize=options.normalize,
    )


# Run the command line interface only when executed as a script, not on import
if __name__ == '__main__':
    main()
