#!/usr/bin/python

# Matplotlib prints annoying messages
# NOTE: the "ignore" filter installed below has no category or module
# restriction, so it silences every warning raised in this process, not
# only Matplotlib's.
import warnings
warnings.filterwarnings("ignore")

# Numerical stuff
import numpy as np
import matplotlib
from pandas import read_csv # pandas != panda
# Select the non-interactive Agg backend right after importing matplotlib,
# so figures are rendered without needing a display.
matplotlib.use('Agg')

# System utilities
import cPickle as pkl
import os
import os.path as op
import sys
import shutil
import csv
import logging
import traceback

# Analyses and Derivatives
import panda.methods.derivatives as der
import panda.methods.viz as viz

# Import pipeline specific utilities
import panda.utils.bids_s3 as bids_s3
from panda.utils import execute_cmd

# Import params (dependent on chan_locs)
from panda.default_config import params


def _export_aws_credentials(cred):
    """Read AWS keys from the CSV file `cred` and export them as env vars.

    The keys are taken from the second row of the file (row index 1):
    column 1 becomes AWS_ACCESS_KEY_ID and column 2 becomes
    AWS_SECRET_ACCESS_KEY. AWS_DEFAULT_REGION is set to 'us-east-1'.

    Args:
        cred: path to the credentials CSV file.

    Raises:
        ValueError: if the file has no second row, or that row has fewer
            than three columns.
    """
    # Binary mode is what the Python 2 csv module expects; the context
    # manager guarantees the handle is closed even if parsing fails.
    with open(cred, 'rb') as credfile:
        rows = list(csv.reader(credfile))

    if len(rows) < 2 or len(rows[1]) < 3:
        raise ValueError(
            "Credentials file %r must have a second row with at least "
            "three columns." % (cred,))

    # set env vars to current credentials
    os.environ['AWS_ACCESS_KEY_ID'] = str(rows[1][1])
    os.environ['AWS_SECRET_ACCESS_KEY'] = str(rows[1][2])
    os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'


def main(buck, dset, sub, ses, cred):
    """Run the full pipeline for one subject/session and push results to S3.

    Steps: create the local working directories under 'temp/', optionally
    export AWS credentials, download the session data and the shared
    channel-location file, run every function in params['functions'] in
    order (calling every derivative and plot function after each step),
    pickle the final data matrix as float32 and copy the whole local
    output directory back to the bucket.

    Args:
        buck: name of the S3 bucket.
        dset: dataset name; input is read from 'data/<dset>' and output is
            written to 'data/<dset>_out' in the bucket.
        sub: subject identifier (used in 'sub-<sub>').
        ses: session identifier (used in 'ses-<ses>').
        cred: path to a CSV file with AWS credentials, or None to rely on
            credentials already available in the environment.

    Raises:
        OSError: if any of the local working directories already exists.
        ValueError: if `cred` is given but does not hold usable keys.
    """
    # Organize directories
    file_name = "sub-%s_ses-%s.pkl" % (sub, ses)
    path = "sub-%s/ses-%s/eeg" % (sub, ses)
    remote_in_dir = "data/%s" % dset
    local_in_dir = "temp/%s" % dset
    dset_out = dset + '_out'
    remote_out_dir = "data/%s" % dset_out
    local_out_dir = "temp/%s" % dset_out
    out_path = local_out_dir + '/' + path
    os.makedirs(local_in_dir + '/' + path)
    os.makedirs(out_path)

    # Set pipeline structure
    fs = params['functions']
    dvs = params['derivatives']
    pts = params['plots']
    p_global = params['p_global']

    # Make directories for qa output
    plot_folders = p_global['plot_folders']
    for key in plot_folders:
        os.makedirs(out_path + '/' + plot_folders[key])

    # Make directories for derivative output
    derivative_folders = p_global['derivatives']
    for key in derivative_folders:
        os.makedirs(out_path + '/' + derivative_folders[key])

    # if running locally, need to add credentials
    if cred is not None:
        _export_aws_credentials(cred)

    # Start
    # Single-argument prints produce the same bytes as the original
    # multi-argument print statements (including the trailing " \n").
    print("")
    print("Pipeline start: subject %s session %s \n" % (sub, ses))

    # Get the data
    print("Saving data for subject from S3. \n")
    bids_s3.get_data(
        bucket=buck,
        remote_dir=remote_in_dir,
        local=local_in_dir,
        subj=sub,
        ses=ses,
        public=False,
        folder=True)

    # Get the channel location (global for all subjects in a set)
    print("Saving channel locations from S3. \n")
    bids_s3.get_data(
        bucket=buck,
        remote_dir=remote_in_dir + '/chan_locs.pkl',
        local=local_in_dir,
        public=True,
        folder=False)

    # Get which chans are eeg / eog
    # The configured eog channel numbers are shifted down by one and every
    # other index in 0..127 is treated as an eeg channel.
    # NOTE(review): the shifted values are written back into the shared
    # params dict, so a second call to main() in the same process would
    # shift them again -- confirm main() is only ever called once.
    eog_chans = np.array(p_global['eog_chans']) - 1
    eeg_chans = np.setdiff1d(np.arange(128), eog_chans)
    p_global.update({'eog_chans': eog_chans,
                     'eeg_chans': eeg_chans})

    # Load channel locations
    # NOTE(security): both pickles below were just downloaded from S3;
    # unpickling can execute arbitrary code, so the bucket must be trusted.
    print("Loading channel locations from disk. \n")
    with open(local_in_dir + "/chan_locs.pkl", 'rb') as file_handle:
        chan_locs = pkl.load(file_handle)

    # Load the data
    print("Loading data matrix from disk. \n")
    with open(local_in_dir + '/' + path + '/' + file_name, 'rb') as file_handle:
        data = pkl.load(file_handle)

    # Set the channel location parameter
    p_global['inter'].update({'chan_locs': chan_locs})

    # Initial data/metadata tuple
    curr = (data, {'funct': None,
                   'fig_path': None,
                   'out_path': out_path,
                   'eog_in': True})

    # Run all the things
    for f in fs:
        funct = f.__name__
        # Record the step name in the metadata dict before running it.
        curr[1]['funct'] = funct
        print('Running function %s' % funct)
        # Each step receives the (data, metadata) pair unpacked and its
        # return value becomes the new pair.
        curr = f(*curr, p_global=p_global)
        for d in dvs:
            print('|---> making %s matrix derivative' % (d.__name__))
            d(*curr, p_global=p_global)

        for p in pts:
            print('|---> making %s' % (p.__name__))
            p(*curr, p_global=p_global)

    # Save final result
    with open(out_path + '/' + file_name, 'wb') as file_handle:
        pkl.dump(curr[0].astype(np.float32), file_handle, -1)

    print("Pushing results to S3.")
    # NOTE(review): bucket and directory names are interpolated unquoted;
    # how execute_cmd runs this string is not visible here -- confirm the
    # arguments can never contain shell metacharacters.
    cmd = "".join(['aws s3 cp ', local_out_dir, ' s3://', buck, '/',
                   remote_out_dir + '/', ' --recursive --acl public-read-write'])

    execute_cmd(cmd)

    print('Success!')

if __name__ == "__main__":
    # Usage: <script> BUCKET DATASET SUBJECT SESSION [CREDENTIALS_CSV]
    if len(sys.argv) not in (5, 6):
        sys.exit("usage: %s bucket dataset subject session [credentials.csv]"
                 % sys.argv[0])

    buck, dset, sub, ses = sys.argv[1:5]
    # The credentials file is optional; None means "use the environment".
    cred = sys.argv[5] if len(sys.argv) == 6 else None

    try:
        main(buck, dset, sub, ses, cred)
    except Exception:
        logging.error(traceback.format_exc())
        # Exit non-zero so callers can detect the failure; the finally
        # clause below still runs before the process terminates.
        sys.exit(1)
    finally:
        # Remove the working directory exactly once, on success and on
        # failure alike. Errors are ignored so a missing 'temp' directory
        # cannot mask the real outcome of the run.
        shutil.rmtree('temp', ignore_errors=True)
