#!/usr/bin/python -u

"""Computes the MCD metric for two sequences of mel cepstra."""

# Copyright 2014, 2015 Matt Shannon

# This file is part of mcd.
# See `License` for details of license and warranty.


from mcd import util
import mcd.metrics_fast as mt

import os
import sys
import argparse
import re
import math
import numpy as np

def _buildArgParser():
    """Build the command-line argument parser for this utility."""
    parser = argparse.ArgumentParser(
        description=(
            'Computes the MCD metric for two sequences of mel cepstra.'
            ' Mel cepstral distortion (MCD) is a measure of the difference'
            ' between two sequences of mel cepstra.'
            ' This utility computes the MCD between two sequences of equal'
            ' length that are assumed to already be "aligned" in terms of'
            ' their timing.'
            ' Optionally certain segments (e.g. silence) can be omitted from'
            ' the calculation.'
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--ext', dest='ext', default='mgc', metavar='EXT',
        help=(
            'file extension added to uttId to get file containing speech'
            ' parameters'
        )
    )
    parser.add_argument(
        '--param_order', dest='paramOrder', default=40, type=int,
        metavar='ORDER',
        help='parameter order of the cepstral files'
    )
    parser.add_argument(
        '--remove_segments', dest='removeSegments', default=None,
        metavar='LABELREGEX',
        help='regex of segment labels to remove'
    )
    parser.add_argument(
        '--labels_dir', dest='labelsDir', default=None,
        metavar='LABELSDIR',
        help=(
            'directory containing phone-level label files (used for segment'
            ' removal)'
        )
    )
    parser.add_argument(
        '--frame_period', dest='framePeriod', default=0.005, type=float,
        metavar='FRAMEPERIOD',
        help='frame period in seconds (used for segment removal)'
    )
    parser.add_argument(
        dest='natDir', metavar='NATDIR',
        help='directory containing natural speech parameters'
    )
    parser.add_argument(
        dest='synthDir', metavar='SYNTHDIR',
        help='directory containing synthetic speech parameters'
    )
    parser.add_argument(
        dest='uttIds', metavar='UTTID', nargs='+',
        help='utterance ids (ext will be appended to these)'
    )
    return parser


def _uttCostAndFrames(nat, synth, costFn, includeFrames=None):
    """Return (summed cost, number of frames scored) for one utterance.

    `nat` and `synth` are iterated in lockstep and `costFn` is applied to
    each pair of frames. If `includeFrames` is given, it is iterated in
    lockstep too and only frames whose entry is truthy are scored.
    The caller is responsible for checking that all sequences have the same
    length (`zip` would otherwise silently truncate to the shortest).
    """
    if includeFrames is None:
        costs = [
            costFn(natFrame, synthFrame)
            for natFrame, synthFrame in zip(nat, synth)
        ]
    else:
        costs = [
            costFn(natFrame, synthFrame)
            for natFrame, synthFrame, includeFrame in zip(nat, synth,
                                                          includeFrames)
            if includeFrame
        ]
    return sum(costs), len(costs)


def main(rawArgs):
    """Compute and print the overall MCD for the given utterances.

    `rawArgs` is a full argv-style list: element 0 (the program name) is
    skipped and the remainder is parsed as command-line arguments.

    For each utterance id, the natural and synthetic parameter files are
    read, column 0 is dropped, and a per-frame cost is accumulated. The
    printed figure is the total cost divided by the total number of frames
    scored across all utterances.

    Exits via `parser.error` (SystemExit with status 2) if only one of
    --remove_segments / --labels_dir is given or if the regex is invalid.
    Raises ValueError if the natural and synthetic sequences (or the label
    alignment) for an utterance differ in length, or if no frames at all
    are left to score.
    """
    parser = _buildArgParser()
    args = parser.parse_args(rawArgs[1:])

    # Segment removal needs both the regex and the label directory. Explicit
    # checks are used rather than `assert`, which is stripped under `-O`.
    if (args.removeSegments is None) != (args.labelsDir is None):
        parser.error(
            '--remove_segments and --labels_dir must be specified together'
        )

    reRemoveSegments = None
    if args.removeSegments is not None:
        try:
            reRemoveSegments = re.compile(args.removeSegments)
        except re.error as exc:
            parser.error(
                'invalid --remove_segments regex \'%s\': %s' %
                (args.removeSegments, exc)
            )
        print('NOTE: removing segments matching regex \'%s\' using labels in'
              ' %s' % (args.removeSegments, args.labelsDir))

    costFn = mt.logSpecDbDist

    costTot = 0.0
    framesTot = 0
    for uttId in args.uttIds:
        print('processing %s' % uttId)
        paramFile = uttId+'.'+args.ext
        nat = util.readParamFileAlt(
            os.path.join(args.natDir, paramFile),
            paramOrder=args.paramOrder
        )
        synth = util.readParamFileAlt(
            os.path.join(args.synthDir, paramFile),
            paramOrder=args.paramOrder
        )
        # ignore 0th cepstral component
        nat = nat[:, 1:]
        synth = synth[:, 1:]

        if len(nat) != len(synth):
            raise ValueError(
                'utterance %s: natural (%d frames) and synthetic (%d frames)'
                ' sequences differ in length' %
                (uttId, len(nat), len(synth))
            )

        includeFrames = None
        if reRemoveSegments is not None:
            alignment = util.readLabFile(
                os.path.join(args.labelsDir, uttId+'.lab'),
                framePeriod=args.framePeriod
            )
            # Swap each segment's label for a keep/drop flag, then expand to
            # (presumably) one flag per frame -- depends on
            # util.expandAlignment, which is not visible here.
            alignmentInclude = [
                (startTime, endTime, not reRemoveSegments.search(label))
                for startTime, endTime, label in alignment
            ]
            includeFrames = list(util.expandAlignment(alignmentInclude))
            if len(includeFrames) != len(nat):
                raise ValueError(
                    'utterance %s: label alignment covers %d frames but'
                    ' parameter files have %d frames' %
                    (uttId, len(includeFrames), len(nat))
                )

        cost, frames = _uttCostAndFrames(nat, synth, costFn, includeFrames)
        costTot += cost
        framesTot += frames

    # Avoid an unexplained ZeroDivisionError when every frame was removed (or
    # the input files were empty).
    if framesTot == 0:
        raise ValueError(
            'no frames left to compute MCD over (empty input, or all'
            ' segments removed)'
        )

    print('overall MCD = %f (%d frames)' % (costTot / framesTot, framesTot))

# Script entry point: when executed directly (rather than imported), pass the
# full argv to main, which skips the leading program name itself.
if __name__ == "__main__":
    main(rawArgs=sys.argv)
