#! /usr/bin/env python
import os
import click
from tqdm import tqdm

# Root click command group. It does no work itself; the sub-commands in
# this file register on it through ``@main.command()``.
@click.group()
def main():
    pass


@main.command()
@click.option('-i', '--input-path', required=True)
@click.option('-o', '--output-path', default='')
def video2frame(input_path, output_path):
    """Dump the frames of a video to numbered PNG files.

    Frames are written as 00000.png, 00001.png, ... inside OUTPUT_PATH.
    When OUTPUT_PATH is empty, a directory named after the input file
    (without its extension) is used.
    """
    import cv2

    if output_path == '':
        output_path = os.path.splitext(input_path)[0]

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise click.ClickException(f'Cannot open video: {input_path}')
    os.makedirs(output_path, exist_ok=True)

    try:
        # The reported frame count can be inaccurate or unavailable for some
        # containers, so it only sizes the progress bar; the loop itself
        # runs until the decoder stops returning frames.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=frame_count if frame_count > 0 else None) as progress:
            index = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                cv2.imwrite(os.path.join(output_path, f'{index:05d}.png'), frame)
                index += 1
                progress.update(1)
    finally:
        # Always free the decoder, even if writing a frame fails.
        cap.release()

    print(f'Save to {output_path}')

@main.command()
@click.option('-i', '--input-path', required=True)
@click.option('-o', '--output-path', default='')
@click.option('--fps', default=24, show_default=True,
              help='Frame rate of the generated video.')
def frame2video(input_path, output_path, fps=24):
    """Assemble the image files of a directory into a video.

    Frames are taken in lexicographic filename order, so they should be
    zero-padded (e.g. 00000.png, 00001.png, ...). When OUTPUT_PATH is
    empty, the video is written next to the directory as
    "<input-path>.mp4".
    """
    import moviepy.editor as mpy

    if output_path == '':
        # normpath drops a trailing separator so that "frames/" becomes
        # "frames.mp4" instead of the hidden file "frames/.mp4".
        output_path = os.path.normpath(input_path) + '.mp4'

    # Skip sub-directories and dotfiles (e.g. .DS_Store); they are not frames.
    input_frames = []
    for fname in sorted(os.listdir(input_path)):
        frame_path = os.path.join(input_path, fname)
        if fname.startswith('.') or not os.path.isfile(frame_path):
            continue
        input_frames.append(frame_path)

    if not input_frames:
        raise click.ClickException(f'No frames found in {input_path}')

    mpy.ImageSequenceClip(input_frames, fps=fps).write_videofile(output_path)
    print(f'Save to {output_path}')


@main.command(help="将图片名称按照数字顺序重命名")
@click.option('-i', '--input-path', default='.')
@click.option('-o', '--output-path', default='.')
def rename_frames(input_path, output_path):
    """Copy numerically named files to consecutive zero-padded names.

    Files in ``input_path`` whose name starts with an integer followed by
    a dot (e.g. ``3.png``, ``12.jpg``) are ordered by that integer and
    copied into ``output_path`` as ``00000.<ext>``, ``00001.<ext>``, ...
    Entries that are not regular files or whose name does not start with
    an integer are skipped. The originals are left untouched.
    """
    # Function-scope import, matching how this file imports other helpers.
    import shutil

    numbered = []
    for name in os.listdir(input_path):
        try:
            number = int(name.split('.')[0])
        except ValueError:
            # Not a numbered frame (e.g. ".DS_Store", "notes.txt").
            click.echo(f'Skip non-numeric name: {name}', err=True)
            continue
        if os.path.isfile(os.path.join(input_path, name)):
            numbered.append((number, name))
    numbered.sort()

    os.makedirs(output_path, exist_ok=True)
    for i, (_, name) in enumerate(numbered):
        src = os.path.join(input_path, name)
        ext = os.path.splitext(name)[1]  # includes the leading dot, or ''
        dst = os.path.join(output_path, f"{i:05d}{ext}")
        try:
            # shutil.copy avoids the shell entirely, so names containing
            # spaces or shell metacharacters are handled safely.
            shutil.copy(src, dst)
        except shutil.SameFileError:
            # Source already carries its target name (in-place run).
            continue
    
    
@main.command()
@click.option('-i', '--input-path', required=True)
@click.option('-o', '--output-path', default='')
def video2audio(input_path, output_path):
    """Extract the audio track of a video into an audio file.

    When OUTPUT_PATH is empty, the audio is written next to the input
    file with the extension replaced by ".mp3".
    """
    import moviepy.editor as mpy

    if output_path == '':
        output_path = os.path.splitext(input_path)[0] + '.mp3'

    clip = mpy.VideoFileClip(input_path)
    try:
        # ``audio`` is None when the video has no audio track.
        if clip.audio is None:
            raise click.ClickException(f'No audio track found in {input_path}')
        clip.audio.write_audiofile(output_path)
    finally:
        # Release the underlying reader processes/file handles.
        clip.close()

    print(f'Save to {output_path}')


@main.command()
@click.argument('input-path')
def detect_reid_pkl(input_path):
    """Summarise a pickled mapping stored at INPUT_PATH.

    Prints each key with the length of its value, pretty-prints the first
    element of every non-empty value, and finally prints the summed length.
    """
    import pandas as pd
    from pprint import pprint

    # NOTE(security): unpickling can execute arbitrary code; only use this
    # command on pickle files from a trusted source.
    pkl = pd.read_pickle(input_path)

    total = 0  # avoids shadowing the builtin ``sum``
    for key, value in pkl.items():
        count = len(value)
        total += count
        print(key, count)
        if count > 0:
            pprint(value[0])
    print("Total: ", total)

# Run the click command group only when this file is executed as a script.
if __name__ == '__main__':
    main()