#!/home/travis/miniconda/bin/python

import numpy as np
import h5py
import sys
import os
from scipy.stats import kde
import argparse

# Command-line interface.
#
#   --fmax    upper frequency bound; only modes with frequency below it
#             are collected
#   --cutoff  property (lifetime) values at or above this are reset to 0
#             before the KDE is run
#   --nbins   grid resolution option (default 100)
#   --title   optional plot title
#   filenames input HDF5 file(s); only the first one is read
parser = argparse.ArgumentParser(
    description="Plot property density with gaussian KDE")
parser.add_argument("--fmax", help="Max frequency to plot",
                    type=float, default=None)
parser.add_argument("--cutoff",
                    help=("Property below this value is included in data "
                          "before running Gaussian-KDE"),
                    type=float, default=None)
parser.add_argument('--nbins', type=int, default=100,
                    help=("Number of bins in which data are assigned, "
                          "i.e., determining resolution of plot"))
parser.add_argument("--title", dest="title", default=None, help="Plot title")
# nargs='+': the script opens filenames[0] unconditionally, so at least one
# file must be supplied; argparse now reports a usage error otherwise.
parser.add_argument('filenames', nargs='+',
                    help="Input HDF5 file (only the first one is read)")
args = parser.parse_args()

#
# Matplotlib configuration
#
# The backend has to be chosen before pyplot is imported; 'Agg' is selected
# here and the figure is later written to a file rather than shown.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rc
# Typeset all text through LaTeX with a serif face.
rc('text', usetex=True)
# Alternative serif face: 'Times New Roman'
rc('font', family='serif', serif='Liberation Serif')
# Optional: plt.rcParams['pdf.fonttype'] = 42

#
# Initial setting
#
# Only the first positional argument is used as input.
if not args.filenames:
    parser.error("at least one HDF5 filename is required")

filename = args.filenames[0]
if not os.path.isfile(filename):
    print("File %s doesn't exist." % filename)
    sys.exit(1)

# Open explicitly read-only: the script never writes to the file, and older
# h5py releases default to a read/write mode when none is given.
f = h5py.File(filename, 'r')

# A falsy --fmax (not given, or 0) means "no upper frequency limit".
max_freq = args.fmax if args.fmax else None

# A missing or empty --title means "no title".
title = args.title if args.title else None

#
# Data collection
#
# Three parallel flat lists are built: one entry per collected mode holding
# its frequency, its lifetime and the weight of the entry it came from.
freqs = []
mode_prop = []
weights = []

# Index into the first axis of the 'gamma' dataset.
# NOTE(review): presumably a temperature index -- confirm against the file
# layout; the file must have at least gamma_index + 1 entries on that axis.
gamma_index = 30

# Lifetime: tau = 1 / (2 * 2*pi * gamma)
for w, freq, gamma in zip(f['weight'],
                          f['frequency'][:],
                          f['gamma'][gamma_index]):
    # Non-positive gamma is replaced by -1 so the division is always defined;
    # the resulting negative tau is then clamped to 0.
    tau = 1.0 / np.where(gamma > 0, gamma, -1) / (2 * 2 * np.pi)
    tau = np.where(tau > 0, tau, 0)
    if args.cutoff:
        # Values at or above the cutoff are zeroed, not removed.
        tau = np.where(tau < args.cutoff, tau, 0)

    if max_freq is None:
        freqs += list(freq)
        mode_prop += list(tau)
        weights += [w] * len(freq)
    else:
        # Compute the selection mask once and reuse it for all three lists.
        keep = freq < max_freq
        freqs += list(np.extract(keep, freq))
        mode_prop += list(np.extract(keep, tau))
        # Plain int so the list repetition never depends on how NumPy
        # scalars interact with sequence multiplication.
        weights += [w] * int(np.count_nonzero(keep))

# Everything needed has been copied into Python lists; release the file.
f.close()

x = np.array(freqs)
y = np.array(mode_prop)
z = weights

#
# Running Gaussian-KDE
#
# http://stackoverflow.com/questions/19390320/scatterplot-contours-in-matplotlib
#
# Import the estimator from its public location; reaching it through the
# scipy.stats.kde submodule is deprecated in recent SciPy releases.
from scipy.stats import gaussian_kde

# Grid resolution comes from the --nbins option (default 100).
nbins = args.nbins
if nbins < 2:
    sys.exit("--nbins must be at least 2")
xmax = np.max(x)
ymax = np.max(y)

# Unweighted 2-D density estimate over (frequency, property) pairs.
kernel = gaussian_kde([x, y])

# First pass: nbins x nbins grid spanning the full data range.
xi, yi = np.mgrid[x.min():x.max():nbins*1j, y.min():y.max():nbins*1j]
zi = kernel(np.vstack([xi.flatten(), yi.flatten()])).reshape(xi.shape)

# Find the highest y-row whose peak density exceeds 10% of the global
# maximum; `indices` lists that row and every row below it, top to bottom.
zi_max = np.max(zi)
row_peaks = zi.max(axis=0)
significant_rows = np.flatnonzero(row_peaks > zi_max / 10)
if significant_rows.size:
    indices = list(range(int(significant_rows[-1]), -1, -1))
else:
    indices = []
if not indices:
    sys.exit("Gaussian-KDE density has no significant region to plot.")

# Second pass: refine the y resolution so that roughly nbins grid rows
# cover the significant part of the y range found above.
ynbins = nbins ** 2 // len(indices)
xi, yi = np.mgrid[x.min():x.max():nbins*1j, y.min():y.max():ynbins*1j]
zi = kernel(np.vstack([xi.flatten(), yi.flatten()])).reshape(xi.shape)

#
# Plotting
#
# Upper y bound for scatter points, derived from the number of significant
# grid rows and pulled in by one bin.
threshold = ymax / nbins * len(indices)
threshold = threshold / nbins * (nbins - 1)
epsilon = 1e-5

# Keep only points strictly inside the plotted window; this also drops the
# zeroed-out property values.
in_view = ((epsilon < y) & (y < threshold) &
           (epsilon < x) & (x < xmax - epsilon))
x_cut = x[in_view]
y_cut = y[in_view]
z_cut = np.asarray(z)[in_view]

fig = plt.figure()
# Only the first nbins y-columns of the refined grid are drawn.
plt.pcolormesh(xi[:, :nbins], yi[:, :nbins], zi[:, :nbins])
plt.colorbar()

# Overlay the individual modes, coloured by weight.
if len(x_cut):
    plt.scatter(x_cut, y_cut, c=z_cut, s=0.1)

# Limits are passed positionally: the xmin/xmax and ymin/ymax keyword
# spellings are not accepted by current Matplotlib releases.
plt.xlim(0, xmax)
# Fall back to the threshold when no point survived the filter, instead of
# failing on the maximum of an empty array.
y_top = np.max(y_cut) if len(y_cut) else threshold
plt.ylim(0, y_top + epsilon)
if title:
    plt.title(title, fontsize=20)
plt.xlabel('Phonon frequency (THz)', fontsize=18)
plt.ylabel('Lifetime (ps)', fontsize=18)

# plt.show()
fig.savefig("lifetime.png")
