#!/usr/bin/python -u

"""Computes the MCD DTW metric for two sequences of mel cepstra."""

# Copyright 2014, 2015 Matt Shannon

# This file is part of mcd.
# See `License` for details of license and warranty.


from mcd import util
from mcd import dtw
import mcd.metrics_fast as mt

import os
import sys
import argparse
import math
import numpy as np

def main(rawArgs):
    """Command-line entry point: report the overall MCD DTW over utterances.

    For every utterance id given on the command line, the natural and
    synthetic parameter files `<dir>/<uttId>.<ext>` are read with
    `util.readParamFileAlt`, the 0th cepstral component (first column) is
    dropped, and the two sequences are aligned with `dtw.dtw` using
    `mt.logSpecDbDist` as the per-frame cost.
    The minimum alignment costs are summed over all utterances and divided
    by the total number of natural frames to give the printed overall MCD.

    Progress and the final result are written to standard output.

    Args:
        rawArgs: full argument vector in `sys.argv` form; element 0 (the
            program name) is skipped before parsing.

    Raises:
        SystemExit: raised by argparse on `--help` or on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description=(
            'Computes the MCD DTW metric for two sequences of mel cepstra.'
            ' Mel cepstral distortion (MCD) is a measure of the difference'
            ' between two sequences of mel cepstra.'
            ' This utility computes the MCD between two sequences allowing for'
            ' possible differences in timing.'
            ' Specifically it uses dynamic time warping (DTW) to compute the'
            ' minimum MCD that can be obtained by "aligning" the two sequences'
            ' subject to certain constraints on the form of the alignment.'
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--ext', dest='ext', default='mgc', metavar='EXT',
        help=(
            'file extension added to uttId to get file containing speech'
            ' parameters'
        )
    )
    parser.add_argument(
        '--param_order', dest='paramOrder', default=40, type=int,
        metavar='ORDER',
        help='parameter order of the cepstral files'
    )
    parser.add_argument(
        dest='natDir', metavar='NATDIR',
        help='directory containing natural speech parameters'
    )
    parser.add_argument(
        dest='synthDir', metavar='SYNTHDIR',
        help='directory containing synthetic speech parameters'
    )
    parser.add_argument(
        dest='uttIds', metavar='UTTID', nargs='+',
        help='utterance ids (ext will be appended to these)'
    )
    # rawArgs[0] is the program name, which argparse must not see.
    args = parser.parse_args(rawArgs[1:])

    costFn = mt.logSpecDbDist

    minCostTot = 0.0
    framesTot = 0
    for uttId in args.uttIds:
        # Single-argument print(...) gives the same output on Python 2 and
        # also parses on Python 3.
        print('processing %s' % uttId)
        nat = util.readParamFileAlt(
            os.path.join(args.natDir, uttId+'.'+args.ext),
            paramOrder=args.paramOrder
        )
        synth = util.readParamFileAlt(
            os.path.join(args.synthDir, uttId+'.'+args.ext),
            paramOrder=args.paramOrder
        )
        # ignore 0th cepstral component
        nat = nat[:, 1:]
        synth = synth[:, 1:]

        # Only the minimum cost is needed here; the alignment path is unused.
        minCost, _ = dtw.dtw(nat, synth, costFn)
        # Normalise by the number of natural frames, not the path length.
        frames = len(nat)

        minCostTot += minCost
        framesTot += frames

    print(
        'overall MCD = %f (%d frames)' % (minCostTot / framesTot, framesTot)
    )

# Script entry point: pass the complete argument vector (program name
# included) through to main when this file is executed directly.
if __name__ == "__main__":
    main(sys.argv)
