#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import pandas as pd
import numpy as np
import argparse
import time
import os.path as osp
from pathlib import Path
import math


def parse_args():
    """Build the command-line parser for the tool and parse ``sys.argv``.

    Five options are mandatory (phenotype file, group file, output file,
    event column, event type); three are optional with defaults
    (``--path`` = ``'.'``, ``--precision`` = ``5``, ``--dob_column`` = ``'DOB'``).

    Returns:
        argparse.Namespace: the parsed options, one attribute per flag.
    """
    parser = argparse.ArgumentParser(description="pyPheWAS Date to Age Conversion Tool")

    # Mandatory string options: (flag, help text). Registration order is kept
    # because it determines the order shown in the --help output.
    mandatory = (
        ('--phenotype', 'Name of the phenotype file (e.g. icd9_data.csv)'),
        ('--group', 'Name of the group file (e.g. groups.csv)'),
        ('--phenotypeout', 'Name of the output file (original phenotype data + event ages)'),
        ('--eventcolumn', 'Name of the event column in the phenotype file'),
        ('--etype', 'Type of event data (CPT or ICD)'),
    )
    for flag, help_text in mandatory:
        parser.add_argument(flag, required=True, type=str, help=help_text)

    # Optional options: (flag, converter, default, help text).
    optional = (
        ('--path', str, '.', 'Path to all input files and destination of output files'),
        ('--precision', int, 5, 'Decimal precision of age in the output file (default: 5)'),
        ('--dob_column', str, 'DOB', 'Name of the birth date column in the group file (default: DOB)'),
    )
    for flag, converter, fallback, help_text in optional:
        parser.add_argument(flag, required=False, default=fallback, type=converter, help=help_text)

    return parser.parse_args()

"""
Print Start Message
"""
# Record the wall-clock start time before anything else runs: the runtime
# reported at the very end of the script is measured from this value.
start = time.time()
# print_start_msg is not defined in this file; presumably it comes from the
# star-import of pyPheWAS.pyPhewasCorev2 at the top -- verify there.
print_start_msg()
print('\nconvertEventToAge: Date to Age Conversion Tool\n')


"""
Retrieve and validate all arguments.
"""
args = parse_args()

# Collected in a dict (in this order) so the whole configuration can be
# validated and then printed in one call to display_kwargs.
kwargs = {'phenotype': args.phenotype,
          'group': args.group,
          'path': Path(args.path),
          'phenotypeout': args.phenotypeout,
          'eventcolumn': args.eventcolumn,
          'precision': args.precision,
          'dob_column': args.dob_column,
          'etype': args.etype
}

# Assert that valid files were given
assert kwargs['phenotype'].endswith('.csv'), "%s is not a valid phenotype file, must be a .csv file" % (kwargs['phenotype'])
assert kwargs['group'].endswith('.csv'), "%s is not a valid group file, must be a .csv file" % (kwargs['group'])
# The message must use the real key 'phenotypeout'; a non-existent key here
# would raise KeyError while building the message and hide the actual problem.
assert kwargs['phenotypeout'].endswith('.csv'), "%s is not a valid output file, must be a .csv file" % (kwargs['phenotypeout'])

# Assert that valid event type was given
assert kwargs['etype'] in ['CPT','ICD'], "%s is not a valid data type. Must be CPT or ICD" % (kwargs['etype'])

# Print Arguments
display_kwargs(kwargs)

# Bind every argument to a module-level name explicitly (rather than via
# locals().update) so the names used by the rest of the script are visible
# to readers and static analysis tools.
path = kwargs['path']
eventcolumn = kwargs['eventcolumn']
precision = kwargs['precision']
dob_column = kwargs['dob_column']
etype = kwargs['etype']

# Fill paths: all input and output files are resolved relative to --path
phenotype = path / kwargs['phenotype']
group = path / kwargs['group']
phenotypeout = path / kwargs['phenotypeout']

# Assert that all input files exist (the output file is created later)
assert osp.exists(phenotype), "%s does not exist" % phenotype
assert osp.exists(group), "%s does not exist" % group


"""
Read Files
"""
print('Reading input files')
# Group (subject-level) table first, then the phenotype (event-level) table.
group_df = pd.read_csv(group)
phen = pd.read_csv(phenotype)
# Snapshot the phenotype columns as read from disk; the output file is
# written with exactly these columns plus the computed age column.
out_cols = phen.columns.tolist()


"""
Convert Specified Event to Age
"""
print('Starting conversion')
# Name of the computed column, e.g. 'AgeAtICD' or 'AgeAtCPT'.
age_col = 'AgeAt' + etype

# Parse birth dates and event dates. The infer_datetime_format flag is not
# passed: it is deprecated (and a no-op) since pandas 2.0.
group_df['nDOB'] = pd.to_datetime(group_df[dob_column])
phen['nEvent_date'] = pd.to_datetime(phen[eventcolumn])
# Inner join: only events whose subject 'id' appears in the group file are kept.
df = pd.merge(group_df, phen, on='id')

# Age in years = whole days between birth and event / mean Gregorian year.
# .dt.days is used instead of .astype('timedelta64[D]') because pandas >= 2.0
# refuses to cast timedeltas to day resolution.
elapsed_days = (df['nEvent_date'] - df['nDOB']).dt.days
df[age_col] = (elapsed_days / 365.2425).round(precision)

# An event dated before the subject's birth indicates bad data. Every subject
# with at least one such event is removed entirely and written to error files.
neg_mask = df[age_col] < 0.0
if neg_mask.any():
    bad_ids = df.loc[neg_mask, 'id'].drop_duplicates()
    nsub = bad_ids.shape[0]
    psub = float(nsub) / float(df['id'].nunique(dropna=False)) * 100.0
    print('\nWARNING -- %d events from %d subjects have negative ages' % (neg_mask.sum(), nsub))
    print('Removing %d (%.2f%%) subjects and saving to %s' % (nsub, psub, path / 'age_calc_error_*.csv'))

    # Boolean row masks replace a temporary marker column, so nothing is
    # assigned into a slice and no input column can be clobbered.
    group_is_bad = group_df['id'].isin(bad_ids)
    event_is_bad = df['id'].isin(bad_ids)

    # Save the removed subjects (group rows), without the helper date column
    f = path / 'age_calc_error_group.csv'
    group_df.loc[group_is_bad].drop(columns=['nDOB']).to_csv(f, index=False)

    # Save all events of the removed subjects, without the helper date columns
    f = path / 'age_calc_error_phen.csv'
    df.loc[event_is_bad].drop(columns=['nDOB', 'nEvent_date']).to_csv(f, index=False)
    df = df.loc[~event_is_bad]

    # Save new group file that no longer contains the removed subjects
    f = path / (group.stem + '__fixed.csv')
    print('Saving new group file (%d subjects removed) to %s\n' % (nsub, f))
    group_df = group_df.loc[~group_is_bad].drop(columns=['nDOB'])
    group_df.to_csv(f, index=False)

# Output = original phenotype columns + the age column. Guard against listing
# the age column twice when the phenotype file already contained it.
if age_col not in out_cols:
    out_cols.append(age_col)

print('Saving %s data to %s' % (etype, phenotypeout))
df.to_csv(phenotypeout, index=False, columns=out_cols)


# Break the elapsed wall-clock time (whole seconds) into hours / minutes /
# seconds and print it, omitting leading units that are zero.
interval = time.time() - start
hour, leftover = divmod(math.floor(interval), 3600)
minute, second = divmod(leftover, 60)

if hour > 0:
    time_str = '%dh:%dm:%ds' % (hour, minute, second)
else:
    time_str = '%dm:%ds' % (minute, second) if minute > 0 else '%ds' % second

print('convertEventToAge Complete [Runtime: %s]' % time_str)