#!/usr/bin/env python

import lpcdbeclient
import numpy as np
from time import strftime
import os, logging
import matplotlib.pyplot as plt
import sklearn.metrics as metrics

# Module-level logger, looked up under the literal name 'root'.
# NOTE(review): no handler/level is configured in this file; presumably
# lpcdbeclient sets this logger up on import -- confirm.
logger = logging.getLogger('root')

#____________________________________________________________________
def make_roc(y_pred, y_true, outfile):
    """Plot a ROC curve with its AUC in the legend and save it to ``outfile``.

    Args:
        y_pred: Predicted scores, forwarded to ``sklearn.metrics.roc_curve``
            as its second argument.
        y_true: True labels, forwarded to ``sklearn.metrics.roc_curve`` as
            its first argument.
        outfile: Destination path handed to ``savefig``.
    """
    fpr, tpr, _ = metrics.roc_curve(y_true, y_pred)
    roc_auc = metrics.auc(fpr, tpr)

    # Draw on a dedicated figure instead of pyplot's implicit "current"
    # figure: otherwise a second call in the same process would overlay its
    # curve on top of the previous one, and the figure would never be freed.
    fig, ax = plt.subplots()
    try:
        ax.set_title('Receiver Operating Characteristic')
        ax.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
        ax.legend(loc='lower right')
        # Diagonal reference line from (0, 0) to (1, 1).
        ax.plot([0, 1], [0, 1], 'r--')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_ylabel('True Positive Rate')
        ax.set_xlabel('False Positive Rate')
        fig.savefig(outfile)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


#____________________________________________________________________
def main():
    """Run inference on a test file and store timing, raw outputs and a ROC plot.

    Parses the command line, loads the data with ``lpcdbeclient.Data.from_file``,
    sends it through ``lpcdbeclient.Client.inference_data``, logs the mean and
    standard deviation of the returned times (scaled by 1000 and reported as
    ms), and writes ``output.log``, ``truths.npy``, ``predictions.npy``,
    ``times.npy`` and ``roc.png`` into a freshly created, timestamped
    ``lpcdbe-output-*`` directory under the current working directory.
    """
    import argparse
    import socket

    parser = argparse.ArgumentParser()
    parser.add_argument('--max-count', type=int, default=None, help='Maximum number of events to inference')
    parser.add_argument('-m', '--module', type=str, default='test3-module-kl', help='Name of the azureml module to send the image to')
    parser.add_argument('-a', '--address', type=str, default='azurb.fnal.gov', help='Address of the DBE')
    parser.add_argument( '--testfile', type=str, default='/uscms/home/klijnsma/nobackup/data/dbe/test/test_file_0.h5', help='Path to the .h5 file in which the images are stored')
    args = parser.parse_args()

    data = lpcdbeclient.Data.from_file(args.testfile)

    client = lpcdbeclient.Client(
        module_name = args.module,
        address = args.address
        )

    with lpcdbeclient.utils.EnterExitPrint('Sending images for inferencing...', 'Done inferencing'):
        truths, predictions, times = client.inference_data(data, max_count=args.max_count)

    # NOTE(review): the factor 1000 presumably converts seconds to ms -- confirm
    # the unit of the times returned by inference_data.
    avg_time = np.mean(times) * 1000.
    err_time = np.std(times) * 1000.
    out_string = 'Average inference time: {0:.3f} +/- {1:.3f} ms'.format(avg_time, err_time)
    logger.info(out_string)


    #____________________________________________________________________
    # Save to an output dir

    outdir = 'lpcdbe-output-{0}'.format(strftime('%m%d-%H%M%S'))
    os.makedirs(outdir)

    # HOSTNAME is frequently a shell variable that is not exported to child
    # processes; fall back to the OS-reported name instead of raising
    # KeyError after the (expensive) inference has already finished.
    hostname = os.environ.get('HOSTNAME') or socket.gethostname()

    with open(os.path.join(outdir, 'output.log'), 'w') as f:
        f.write('Ran on {0} at {1}\n'.format(
            hostname,
            strftime('%Y-%m-%d %H:%M:%S (%Z)')
            ))
        f.write(out_string)

    np.save(os.path.join(outdir, 'truths.npy'), truths)
    np.save(os.path.join(outdir, 'predictions.npy'), predictions)
    np.save(os.path.join(outdir, 'times.npy'), times)

    # The indexing requires 3-dimensional arrays; element [0, 0] of each entry
    # is used as the score / label.
    # NOTE(review): meaning of the trailing two axes is not visible here -- confirm.
    y_pred = np.array(predictions)[:,0,0]
    y_true = np.array(truths)[:,0,0]
    make_roc(y_pred, y_true, os.path.join(outdir, 'roc.png'))


#____________________________________________________________________
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
