#!/usr/bin/env python3
import argparse
import os
import warnings

# matplotlib is required for plotting; print an error and exit with status 1 if it cannot be imported
try:
    import matplotlib
    import matplotlib.pyplot as plt
    # NOTE(review): mplot3d is not referenced by name in the visible code; presumably it is
    # imported so that the '3d' projection requested in plot_poses is available - confirm
    from mpl_toolkits import mplot3d
    # on POSIX systems without a DISPLAY variable, switch matplotlib to the 'Agg' backend
    if os.name == 'posix' and 'DISPLAY' not in os.environ:
        print("No display found, using 'Agg' as matplotlib backend.")
        matplotlib.use('Agg')
except (ImportError, ModuleNotFoundError):
    print("Error: package 'matplotlib' not found")
    # sys is only needed on this error path, hence the local import
    import sys
    sys.exit(1)

import numpy as np

import motion3d as m3d


def plot_positions(ax, positions, *args, **kwargs):
    """Scatter 3D points on ``ax`` and fit equally sized axis limits around them.

    The limits of all three axes are centered on the middle of the point cloud
    and use the largest half-extent of any axis, so one data unit has the same
    range on x, y and z.

    Args:
        ax: 3D axes object providing ``scatter3D``, ``set_xlim``, ``set_ylim``
            and ``set_zlim``.
        positions: Array-like of shape (N, >=3). Only the first three columns
            are used, so homogeneous points with a trailing w column are
            accepted as well.
        *args: Forwarded to ``ax.scatter3D``.
        **kwargs: Forwarded to ``ax.scatter3D``.
    """
    positions = np.asarray(positions)
    if positions.size == 0:
        # nothing to draw; np.min / np.max would raise on a zero-size array
        return

    # restrict to the spatial columns so that an extra w column cannot affect the limits
    xyz = positions[:, :3]
    min_positions = np.min(xyz, axis=0)
    max_positions = np.max(xyz, axis=0)
    mid_positions = (max_positions + min_positions) / 2
    max_range = np.max((max_positions - min_positions) / 2)

    ax.scatter3D(xyz[:, 0], xyz[:, 1], xyz[:, 2], *args, **kwargs)
    ax.set_xlim(mid_positions[0] - max_range, mid_positions[0] + max_range)
    ax.set_ylim(mid_positions[1] - max_range, mid_positions[1] + max_range)
    ax.set_zlim(mid_positions[2] - max_range, mid_positions[2] + max_range)


def plot_orientations(ax, positions1, positions2, step, *args, **kwargs):
    """Draw line segments between corresponding rows of two point arrays.

    For every ``step``-th row index, starting at 0, a solid line is drawn from
    ``positions1[row]`` to ``positions2[row]`` using the first three columns.

    Args:
        ax: Axes object providing ``plot``.
        positions1: 2D array with the segment start points.
        positions2: 2D array with the segment end points.
        step: Stride between the rows that are drawn.
        *args: Forwarded to ``ax.plot`` after the '-' format string.
        **kwargs: Forwarded to ``ax.plot``.
    """
    for row in range(0, len(positions1), step):
        # one [start, end] pair per spatial dimension: [[x1, x2], [y1, y2], [z1, z2]]
        segment = [[positions1[row, dim], positions2[row, dim]] for dim in range(3)]
        ax.plot(*segment, '-', *args, **kwargs)


def plot_poses(filename, pos_step=1, axes_step=0, axes_length=1.0, use_origin=False, skip_show=False):
    """Load motion data from ``filename`` and plot its poses in a 3D figure.

    Args:
        filename: Path of the motion file passed to ``m3d.M3DReader.read``.
        pos_step: Stride between the poses whose position is drawn; must be >= 1.
        axes_step: Stride, counted in drawn positions, between the poses whose
            orientation axes are drawn. Values <= 0 hide the axes.
        axes_length: Length of the drawn orientation axes. Values <= 0 hide the axes.
        use_origin: If True, change the frame of the poses using the inverse of the
            normalized origin stored in the motion data. A warning is issued if
            the data has no origin.
        skip_show: If True, the figure is created but ``plt.show()`` is not called.

    Raises:
        ValueError: If ``pos_step`` is smaller than 1.
    """
    # fail early with a clear message instead of a late "slice step cannot be zero";
    # negative strides would additionally drop all orientation axes silently
    if pos_step < 1:
        raise ValueError(f"pos_step must be a positive integer, got {pos_step!r}")

    # load data
    motion_data, _ = m3d.M3DReader.read(filename, unsafe=True)

    # convert to poses
    poses = motion_data.getTransforms()
    poses.removeStamps_()
    poses.normalized_()
    poses.asType_(m3d.TransformType.kMatrix)
    poses.asPoses_()

    # transform to origin
    if use_origin:
        origin = motion_data.getOrigin()
        if origin is None:
            warnings.warn("No origin given.")
        else:
            poses.changeFrame_(origin.normalized().inverse())

    # create figure
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # homogeneous probe points in the local pose frame, built once instead of per pose
    show_axes = axes_step > 0 and axes_length > 0
    origin_point = np.array([0.0, 0.0, 0.0, 1.0])
    axis_tips = {
        'ax': np.array([axes_length, 0.0, 0.0, 1.0]),
        'ay': np.array([0.0, axes_length, 0.0, 1.0]),
        'az': np.array([0.0, 0.0, axes_length, 1.0]),
    }

    # create plot data
    plot_data = {'pos': [], 'ax': [], 'ay': [], 'az': []}
    for p in poses:
        matrix = p.getMatrix()
        plot_data['pos'].append(matrix @ origin_point)
        if show_axes:
            for key, tip in axis_tips.items():
                plot_data[key].append(matrix @ tip)
    plot_data = {k: np.array(v) for k, v in plot_data.items()}

    # plot data
    if len(plot_data['pos']) == 0:
        # the limit computation cannot handle an empty point set, so keep the axes empty
        warnings.warn("No poses to plot.")
    else:
        plot_positions(ax, plot_data['pos'][::pos_step])
        if show_axes:
            # axes are drawn for every axes_step-th of the positions that are actually shown
            orientation_step = pos_step * axes_step
            plot_orientations(ax, plot_data['pos'], plot_data['ax'], step=orientation_step, c='r')
            plot_orientations(ax, plot_data['pos'], plot_data['ay'], step=orientation_step, c='g')
            plot_orientations(ax, plot_data['pos'], plot_data['az'], step=orientation_step, c='b')

    # axes labels
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')

    # show figure
    if not skip_show:
        if os.name == 'posix' and 'DISPLAY' not in os.environ:
            print("No display found, skipping show.")
        else:
            plt.show()


def main():
    """Parse the command line arguments and plot the poses of the given motion file."""
    parser = argparse.ArgumentParser(description="Plot poses stored in m3d format.")
    parser.add_argument('data', type=str, help="Motion data (*.m3d)")
    parser.add_argument('--pos-step', type=int, default=1, help="Step size between position points (default: 1)")
    parser.add_argument('--axes-step', type=int, default=0, help="Step size between orientation axes, use 0 to hide axes (default: 0)")
    parser.add_argument('--axes-length', type=float, default=1.0, help="Length of orientation axes (default: 1.0)")
    parser.add_argument('--origin', action='store_true', help="Transform poses to origin")
    # testing aid that is hidden from the help output
    parser.add_argument('--skip-show', action='store_true', help=argparse.SUPPRESS)
    args = parser.parse_args()

    # a step of 0 cannot be used as a slice stride and negative steps are meaningless here,
    # so report them as a usage error instead of failing with a traceback while plotting
    if args.pos_step < 1:
        parser.error("--pos-step must be a positive integer")

    plot_poses(args.data, pos_step=args.pos_step, axes_step=args.axes_step, axes_length=args.axes_length,
               use_origin=args.origin, skip_show=args.skip_show)


# run the command line interface only when this file is executed as a script
if __name__ == '__main__':
    main()
