#!/usr/bin/env python3
import argparse
from enum import Enum
from pathlib import Path
import sys

try:
    from rosbags.highlevel import AnyReader as Reader
except (ImportError, ModuleNotFoundError):
    print("Error: package 'rosbags' not found")
    sys.exit(1)

import motion3d as m3d


def pose_to_m3d(pose):
    """Build a motion3d quaternion transform from a pose-like object.

    Args:
        pose: Object exposing ``position`` (with ``x``, ``y``, ``z``) and
            ``orientation`` (with ``w``, ``x``, ``y``, ``z``).

    Returns:
        The ``m3d.QuaternionTransform`` constructed from the position and the
        orientation, the latter passed in (w, x, y, z) order.
    """
    position = pose.position
    orientation = pose.orientation

    xyz = [position.x, position.y, position.z]
    # scalar part first, followed by the vector part
    wxyz = [orientation.w, orientation.x, orientation.y, orientation.z]

    return m3d.QuaternionTransform(xyz, wxyz)


def transform_to_m3d(transform):
    """Build a motion3d quaternion transform from a transform-like object.

    Args:
        transform: Object exposing ``translation`` (with ``x``, ``y``, ``z``)
            and ``rotation`` (with ``w``, ``x``, ``y``, ``z``).

    Returns:
        The ``m3d.QuaternionTransform`` constructed from the translation and
        the rotation, the latter passed in (w, x, y, z) order.
    """
    shift = transform.translation
    rot = transform.rotation

    xyz = [shift.x, shift.y, shift.z]
    # scalar part first, followed by the vector part
    wxyz = [rot.w, rot.x, rot.y, rot.z]

    return m3d.QuaternionTransform(xyz, wxyz)


def msg_to_transform(msg, base_frame=None, child_frame=None):
    """Convert a deserialized pose or transform message to a motion3d transform.

    The message type is read from ``msg.__msgtype__``. Supported types are
    ``geometry_msgs/msg/PoseStamped``,
    ``geometry_msgs/msg/PoseWithCovarianceStamped`` and
    ``geometry_msgs/msg/TransformStamped``.

    Args:
        msg: Deserialized message exposing ``__msgtype__``.
        base_frame: ``header.frame_id`` a TransformStamped message must have
            to be converted. Required for TransformStamped messages.
        child_frame: ``child_frame_id`` a TransformStamped message must have
            to be converted. Required for TransformStamped messages.

    Returns:
        The converted transform, or ``None`` for a TransformStamped message
        whose frames do not match ``base_frame`` and ``child_frame``.

    Raises:
        SystemExit: With status 1, after printing an error to stderr, if the
            message type is not supported or if a TransformStamped message is
            given without both frames.
    """
    msgtype = msg.__msgtype__

    if msgtype == 'geometry_msgs/msg/PoseStamped':
        return pose_to_m3d(msg.pose)

    if msgtype == 'geometry_msgs/msg/PoseWithCovarianceStamped':
        # the actual pose is nested below the covariance wrapper
        return pose_to_m3d(msg.pose.pose)

    if msgtype == 'geometry_msgs/msg/TransformStamped':
        # check arguments
        if base_frame is None or child_frame is None:
            # diagnostics go to stderr so that stdout stays clean for pipelines
            print("Error: base frame and child frame are required for TransformStamped messages",
                  file=sys.stderr)
            sys.exit(1)

        # check frames: messages for other frame pairs are skipped by the caller
        if msg.header.frame_id == base_frame and msg.child_frame_id == child_frame:
            return transform_to_m3d(msg.transform)
        return None

    print(f"Error: message type '{msgtype}' is not supported", file=sys.stderr)
    sys.exit(1)


def convert_rosbag_poses(src, topic, output_filename, use_msg_stamp=False, binary=False, **kwargs):
    """Read the transforms of one rosbag topic and write them as motion data.

    Args:
        src: Path of the input rosbag, opened with the rosbags reader.
        topic: Name of the topic to convert.
        output_filename: Destination passed to ``m3d.M3DWriter.write``.
        use_msg_stamp: If true, each transform is stamped with the timestamp
            yielded by ``reader.messages()`` (interpreted as nanoseconds);
            otherwise ``msg.header.stamp`` is used.
        binary: If true, write ``kBinary`` instead of ``kASCII``.
        **kwargs: Forwarded to ``msg_to_transform`` (``base_frame`` and
            ``child_frame``).

    Raises:
        SystemExit: With status 1, after printing an error to stderr, if the
            topic is not in the bag, if no message yields a transform, or if
            the writer does not report success.
    """
    # open bag
    with Reader([Path(src)]) as reader:
        # check topic and get its connections
        connections = reader.topics[topic].connections if topic in reader.topics else []
        if not connections:
            print(f"Error: topic '{topic}' not found in rosbag", file=sys.stderr)
            sys.exit(1)

        # create container
        # NOTE(review): the meaning of the two boolean flags is not visible here;
        # presumably they enable stamps and pose semantics - confirm against motion3d.
        container = m3d.TransformContainer(True, True)

        # iterate topic
        frame_id = None
        for conn, t, data in reader.messages(connections=connections):
            # deserialize
            msg = reader.deserialize(data, conn.msgtype)

            # transform (None means the message was filtered out by frame)
            transform = msg_to_transform(msg, **kwargs)
            if transform is None:
                continue

            # stamp
            if use_msg_stamp:
                stamp = m3d.Time.FromNSec(t)
            else:
                stamp = m3d.Time.FromSecNSec(msg.header.stamp.sec, msg.header.stamp.nanosec)

            # append
            container.append(stamp, transform)

            # frame id: taken from the first message that was converted
            if frame_id is None:
                frame_id = msg.header.frame_id

    # the bag is no longer needed from here on, so it is closed before writing

    # check container size
    if container.size() == 0:
        print("Error: no transformations found", file=sys.stderr)
        sys.exit(1)

    # create motion data
    motion_data = m3d.MotionData(m3d.TransformType.kQuaternion, container)
    if frame_id is not None:
        motion_data.setFrameId(frame_id)

    # write and check
    file_type = m3d.M3DFileType.kBinary if binary else m3d.M3DFileType.kASCII
    status = m3d.M3DWriter.write(output_filename, motion_data, file_type)
    if status != m3d.M3DIOStatus.kSuccess:
        print(f"Error while writing: {m3d.M3DIOStatus(status)} ({status})", file=sys.stderr)
        sys.exit(1)


def main():
    """Parse the command line and run the rosbag pose conversion."""
    cli = argparse.ArgumentParser(description="Convert poses from ROS1 or ROS2 rosbag.")

    # positional arguments
    cli.add_argument('src', type=str, help="Input rosbag")
    cli.add_argument('topic', type=str, help="ROS Topic")
    cli.add_argument('output', type=str, help="Output motion data (*.m3d)")

    # flags
    cli.add_argument('--msg-stamp', action='store_true', help="Use message stamp instead of header stamp")
    cli.add_argument('--binary', action='store_true', help="Write binary")

    # frame selection, only relevant for TransformStamped topics
    cli.add_argument('--base-frame', type=str, default=None, help="Base frame for TransformStamped messages")
    cli.add_argument('--child-frame', type=str, default=None, help="Child frame for TransformStamped messages")

    opts = cli.parse_args()

    frame_selection = {'base_frame': opts.base_frame, 'child_frame': opts.child_frame}
    convert_rosbag_poses(
        opts.src,
        opts.topic,
        opts.output,
        use_msg_stamp=opts.msg_stamp,
        binary=opts.binary,
        **frame_selection,
    )


# Run the command line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
