#!/usr/bin/python -u

"""Time-warps a speech parameter sequence based on a reference."""

# Copyright 2014, 2015 Matt Shannon

# This file is part of mcd.
# See `License` for details of license and warranty.


from mcd import util
from mcd import dtw
import mcd.metrics_fast as mt

import os
import sys
import argparse
import math
import numpy as np

def main(rawArgs):
    """Warp synthetic speech parameter files onto natural-speech timing.

    For each utterance id, the first parameter stream (by default `mgc`) is
    read from both NATDIR and SYNTHDIR, a dynamic time warping is computed
    between the two (ignoring the 0th component of each frame), and every
    parameter stream listed in `--exts` is then re-read from SYNTHDIR, warped
    with the resulting frame index sequence and written to OUTDIR.
    Per-utterance and overall cost-per-frame figures are printed to stdout.

    Args:
        rawArgs: full argument vector including the program name in position
            0 (e.g. `sys.argv`); only `rawArgs[1:]` is parsed.

    Raises:
        SystemExit: raised by argparse (exit status 2, usage message on
            stderr) when the arguments cannot be parsed, when
            `--param_orders` contains a non-integer entry, or when `--exts`
            and `--param_orders` do not list the same number of entries.
    """
    parser = argparse.ArgumentParser(
        description=(
            'Time-warps a speech parameter sequence based on a reference.'
            ' Dynamic time warping (DTW) is used to compute the time warping.'
            ' By default a speech parameter sequence consisting of three'
            ' portions (mgc,lf0,bap) is warped to match the timing of the'
            ' reference speech parameter sequence.'
            ' Only the first portion is used when computing the warping.'
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--exts', dest='exts', default='mgc,lf0,bap', metavar='EXTLIST',
        help=(
            'file extensions added to uttId to get file containing speech'
            ' parameters'
        )
    )
    parser.add_argument(
        '--param_orders', dest='paramOrders', default='40,1,5',
        metavar='ORDERLIST',
        help='orders of the parameter files (mgc,lf0,bap)'
    )
    parser.add_argument(
        dest='natDir', metavar='NATDIR',
        help='directory containing natural speech parameters'
    )
    parser.add_argument(
        dest='synthDir', metavar='SYNTHDIR',
        help='directory containing synthetic speech parameters'
    )
    parser.add_argument(
        dest='outDir', metavar='OUTDIR',
        help='directory to output warped speech parameters to'
    )
    parser.add_argument(
        dest='uttIds', metavar='UTTID', nargs='+',
        help='utterance ids (ext will be appended to these)'
    )
    args = parser.parse_args(rawArgs[1:])

    # Validate the two comma-separated lists up front and report problems
    # through argparse. Plain `assert`s would be stripped under `python -O`
    # and would otherwise surface as a bare traceback rather than a usage
    # error.
    try:
        paramOrders = [
            int(paramOrderStr)
            for paramOrderStr in args.paramOrders.split(',')
        ]
    except ValueError:
        parser.error(
            '--param_orders must be a comma-separated list of integers'
            ' (got %r)' % args.paramOrders
        )
    exts = args.exts.split(',')
    if len(exts) != len(paramOrders):
        parser.error(
            '--exts lists %d extension(s) but --param_orders lists %d'
            ' order(s); the two lists must have the same length' %
            (len(exts), len(paramOrders))
        )

    # Only the first stream is used to compute the warping.
    mgcParamOrder = paramOrders[0]
    mgcExt = exts[0]

    costFn = mt.logSpecDbDist

    # NOTE: every print below is written as a single-argument `print(...)`
    # so that the output is unchanged under Python 2 while the file also
    # parses under Python 3.
    minCostTot = 0.0
    framesTot = 0
    for uttId in args.uttIds:
        print('processing %s' % uttId)
        nat = util.readParamFileAlt(
            os.path.join(args.natDir, uttId+'.'+mgcExt),
            paramOrder=mgcParamOrder
        )
        synth = util.readParamFileAlt(
            os.path.join(args.synthDir, uttId+'.'+mgcExt),
            paramOrder=mgcParamOrder
        )
        # ignore 0th cepstral component
        nat = nat[:, 1:]
        synth = synth[:, 1:]

        minCost, path = dtw.dtw(nat, synth, costFn)
        frames = len(nat)

        minCostTot += minCost
        framesTot += frames

        print('MCD = %f (%d frames)' % (minCost / frames, frames))

        # Cost of each (natural, synthetic) frame pair on the DTW path; used
        # to reduce the path to a single synthetic frame index per natural
        # frame (the assert below checks the one-per-natural-frame property).
        pathCosts = [ costFn(nat[i], synth[j]) for i, j in path ]
        synthIndexSeq = dtw.projectPathBestCost(path, pathCosts)
        assert len(synthIndexSeq) == len(nat)

        # Frame bookkeeping: every synthetic frame that never appears in the
        # index sequence is "dropped", every extra occurrence of a frame is
        # "repeated".
        uniqueFrames = len(set(synthIndexSeq))
        repeatedFrames = len(synthIndexSeq) - uniqueFrames
        droppedFrames = len(synth) - uniqueFrames
        assert len(synth) - droppedFrames + repeatedFrames == len(nat)
        print('warping %s frames -> %s frames (%s repeated, %s dropped)' %
              (len(synth), len(nat), repeatedFrames, droppedFrames))
        print('')

        # Apply the same index sequence to every stream (including the first
        # one, this time with its 0th component intact).
        for paramOrder, ext in zip(paramOrders, exts):
            synthFull = util.readParamFileAlt(
                os.path.join(args.synthDir, uttId+'.'+ext),
                paramOrder=paramOrder
            )
            synthFullWarped = dtw.warpGeneral(synthFull, synthIndexSeq)
            util.writeParamFile(
                synthFullWarped,
                os.path.join(args.outDir, uttId+'.'+ext),
                paramOrder=paramOrder
            )

    print('overall MCD = %f (%d frames)' % (minCostTot / framesTot, framesTot))

if __name__ == '__main__':
    # Run as a script: hand over the complete argv (program name included);
    # main() skips element 0 itself before parsing.
    main(rawArgs=sys.argv)
