#!/usr/bin/env python
from pyPheWAS.pyPhewasCorev2 import print_start_msg, display_kwargs
import math
import pandas as pd
import argparse
from pathlib import Path
import time

def parse_args():
    """Define the command-line interface and parse the process arguments.

    Every option is read as a string; optional file/code arguments default to
    the empty string so callers can test them for truthiness.

    Returns:
        argparse.Namespace: parsed options, one attribute per flag below.
    """
    # (flag, required, default, help text) -- kept in the order shown by --help
    option_table = (
        ('--phenotype', True, None, 'Name of the input phenotype file'),
        ('--group', False, '', 'Name of the group file to add target variable map to'),
        ('--groupout', True, None, 'Name of the output group file'),
        ('--phenotypeout', False, '', 'Name of the output phenotype file containing only selected case and control subjects'),
        ('--case_codes', True, None, 'Case ICD codes (filename or comma-separated list)'),
        ('--ctrl_codes', False, '', 'Control ICD codes (filename or comma-separated list)'),
        ('--exclude_codes', False, '', 'Exclusion ICD codes (filename, comma-separated list, or \"phewas_map\")'),
        ('--code_freq', True, None, 'Minimum frequency of codes (If 2 comma-separated values are given and ctrl_codes is given, 2nd argument is applied to controls)'),
        ('--thresh_type', False, 'absolute', 'Code frequency threshold type (absolute [default] or daily)'),
        ('--response_var', False, 'target', 'Name of response variable (e.g. target, genotype, response [default=target])'),
        ('--path', False, '.', 'Path to all input files and destination of output files (default = current directory)'),
    )

    cli = argparse.ArgumentParser(description="pyPheWAS Phenotype Assignment Tool")
    for flag, is_required, default_value, help_text in option_table:
        cli.add_argument(flag, required=is_required, default=default_value, type=str, help=help_text)

    return cli.parse_args()



def parse_codes(codes, codes_path, code_type):
    """Turn a code specification into a list of unique ICD code strings.

    Args:
        codes (str): either the name of a ``.csv``/``.txt`` file containing
            comma- and/or newline-separated codes, or a comma-separated list
            of codes given directly.
        codes_path (pathlib.Path): directory the code file is read from
            (only used when ``codes`` names a file).
        code_type (str): label used in the progress message only.

    Returns:
        list[str]: the unique, non-empty codes with all whitespace removed,
        sorted so that downstream messages and reports are reproducible.
    """
    if codes.endswith(('.csv', '.txt')):
        print('Reading %s group codes from file' % code_type)
        with open(codes_path / codes, 'r') as code_f:
            chunks = code_f.read().splitlines()
    else:
        chunks = [codes]

    unique_codes = set()
    for chunk in chunks:
        # drop every whitespace character (spaces, tabs, stray newlines) before splitting
        unique_codes.update("".join(chunk.split()).split(','))
    # blank lines and trailing commas produce empty tokens; they are not codes
    unique_codes.discard('')
    return sorted(unique_codes)



def export_phewas_exclusions(case_codes, phecode_lists, icd_list, path, groupname):
    """Write the case codes, their PheCodes and the related ICD codes to a report file.

    The report is saved as ``<path>/<groupname>_exclusion_list.csv`` and holds
    three sections: the case ICD codes, the PheCodes they map to, and the
    related ICD codes that are used as exclusions.

    Args:
        case_codes (list[str]): case ICD codes, written one per line.
        phecode_lists: two DataFrames ``[ICD9 rows, ICD10 rows]`` with the
            columns 'PheCode', 'Phenotype', 'Excl. Phecodes' and
            'Excl. Phenotypes'.
        icd_list: two DataFrames ``[ICD9 rows, ICD10 rows]`` with the columns
            'ICD_CODE', 'PheCode', 'Phenotype' and 'ICD9 String' /
            'ICD10 String' respectively.
        path (pathlib.Path): output directory.
        groupname (str): prefix of the report file name.

    Returns:
        None. The input DataFrames are left unmodified.
    """
    def _stack(frames):
        # Tag each map with its ICD version and unify the description column.
        # assign()/rename() return new frames, so the caller's data is untouched.
        icd9 = frames[0].assign(ICD_TYPE=9).rename(columns={'ICD9 String': 'ICD_String'})
        icd10 = frames[1].assign(ICD_TYPE=10).rename(columns={'ICD10 String': 'ICD_String'})
        return pd.concat([icd9, icd10])

    case_phecodes = _stack(phecode_lists)
    ex_codes = _stack(icd_list)

    separator = '\n--------------------------------------------------\n'
    with open(path / (groupname + '_exclusion_list.csv'), 'w') as out:
        out.write('Case ICD codes\n')
        out.write('\n'.join(case_codes) + '\n')
        out.write(separator)
        out.write('Corresponding Case PheCodes\n')
        case_phecodes.to_csv(out, index=False, columns=['ICD_TYPE', 'PheCode', 'Phenotype', 'Excl. Phecodes', 'Excl. Phenotypes'])
        out.write(separator)
        out.write('Related ICD Codes (Exclusions)\n')
        ex_codes.to_csv(out, index=False, columns=['ICD_TYPE', 'ICD_CODE', 'ICD_String', 'PheCode', 'Phenotype'])
    return



def _phewas_related_rows(code_map, case_codes):
    """Return (case PheCode rows, related exclusion rows) for one ICD-to-PheCode map."""
    # rows of the map for the case codes, one row per distinct PheCode
    case_phecodes = code_map[code_map.ICD_CODE.isin(case_codes)].drop_duplicates(subset="PheCode")
    # The mask starts out True wherever PheCode is missing.
    # NOTE(review): this keeps unmapped ICD codes in the exclusion list -- confirm intended.
    related = pd.isna(code_map["PheCode"])
    for excl_range in case_phecodes['Excl. Phecodes']:
        if not isinstance(excl_range, str):
            continue  # no exclusion range recorded for this PheCode
        bounds = excl_range.replace(" ", "").split('-')
        if len(bounds) < 2:
            continue  # not a "low-high" range
        # NOTE(review): bounds are strings, so this comparison is lexicographic if the
        # PheCode column holds strings -- confirm the column dtype in the map.
        related |= (code_map['PheCode'] >= bounds[0]) & (code_map['PheCode'] <= bounds[1])
    exclusions = code_map[related].drop_duplicates(subset=['ICD_CODE', 'PheCode'])
    return case_phecodes, exclusions


def get_exclusion_from_phewas(case_codes, path, groupname):
    """Collect the ICD codes related to the case codes according to the PheWAS maps.

    Each case code is mapped to its PheCode; every ICD code whose PheCode lies
    within one of the corresponding 'Excl. Phecodes' ranges is treated as
    related. The result is also written to
    ``<path>/<groupname>_exclusion_list.csv``.

    Args:
        case_codes (list[str]): case ICD codes.
        path (pathlib.Path): directory the exclusion report is written to.
        groupname (str): prefix of the exclusion report file name.

    Returns:
        list[str]: the unique related ICD codes (ICD9 and ICD10 combined).
    """
    print("\nGetting case-related ICD codes from PheWAS map...")
    from pyPheWAS.pyPhewasCorev2 import icd9_codes, icd10_codes

    icd9_phecodes, ex9 = _phewas_related_rows(icd9_codes, case_codes)
    icd10_phecodes, ex10 = _phewas_related_rows(icd10_codes, case_codes)

    # report the file that is actually written by export_phewas_exclusions
    print("Saving related ICD codes to %s" % (path / (groupname + '_exclusion_list.csv')))
    export_phewas_exclusions(case_codes, [icd9_phecodes, icd10_phecodes], [ex9, ex10], path, groupname)

    # flatten both maps into a single ICD code list
    return list(set(ex9['ICD_CODE']).union(set(ex10['ICD_CODE'])))



def find_and_count(df, codes):
    """Flag records whose ICD code is in ``codes`` and count the flags per subject.

    The frame is modified in place; two columns are added or overwritten:
      * 'gen'  -- 1 if the record's 'ICD_CODE' is exactly one of ``codes``, else 0
      * 'genc' -- the sum of 'gen' over all records sharing the record's 'id'

    Args:
        df (pandas.DataFrame): records with at least 'id' and 'ICD_CODE' columns.
        codes (list[str]): ICD codes to look for.

    Returns:
        None
    """
    # Exact membership test: codes are literal strings, not patterns, so the
    # '.' in a code such as '250.0' must not act as a regex wildcard. Missing
    # ICD codes simply count as "no match".
    df['gen'] = df['ICD_CODE'].isin(codes).astype(int)
    df['genc'] = df.groupby('id')['gen'].transform('sum')  # count all instances per subject
    return



"""
Print Start Message
"""
start = time.time()  # reference point for the runtime reported when the script finishes
print_start_msg()
print('\ncreatePhenotypeFile: Response Variable Assignment Tool\n')


# Retrieve and validate all arguments.
args = parse_args()

kwargs = {
    'phenotype': args.phenotype,
    'group': args.group,
    'groupout': args.groupout,
    'phenotypeout': args.phenotypeout,
    'case_codes': args.case_codes,
    'ctrl_codes': args.ctrl_codes,
    'exclude_codes': args.exclude_codes,
    'code_freq': args.code_freq,
    'thresh_type': args.thresh_type,
    'response_var': args.response_var,
    'path': Path(args.path),
}

# Assert that files are valid (optional file arguments are empty strings when not given)
assert kwargs['phenotype'].endswith('.csv'), "%s is not a valid phenotype file, must be a .csv file" % (kwargs['phenotype'])
assert kwargs['groupout'].endswith('.csv'), "%s is not a valid output file, must be a .csv file" % (kwargs['groupout'])
if kwargs['group']:
    assert kwargs['group'].endswith('.csv'), "%s is not a valid group file, must be a .csv file" % (kwargs['group'])
if kwargs['phenotypeout']:
    assert kwargs['phenotypeout'].endswith('.csv'), "%s is not a valid output file, must be a .csv file" % (kwargs['phenotypeout'])

# Assert that options are valid
assert kwargs['thresh_type'] in ['absolute', 'daily'], "%s is not a valid threshold type" % kwargs['thresh_type']

# Print Arguments
display_kwargs(kwargs)

# Bind every argument to a module-level name explicitly
path = kwargs['path']
phenotype = kwargs['phenotype']
group = kwargs['group']
groupout = kwargs['groupout']
phenotypeout = kwargs['phenotypeout']
case_codes = kwargs['case_codes']
ctrl_codes = kwargs['ctrl_codes']
exclude_codes = kwargs['exclude_codes']
code_freq = kwargs['code_freq']
thresh_type = kwargs['thresh_type']
response_var = kwargs['response_var']

# Fill paths (optional files stay as empty strings when they were not requested)
phenotype = path / phenotype
groupout = path / groupout
if group:
    group = path / group
if phenotypeout:
    phenotypeout = path / phenotypeout

# Assert that all input files exist
assert phenotype.exists(), "%s does not exist" % phenotype
if group:
    assert group.exists(), "%s does not exist" % group
# code arguments only name a file when they end in .csv or .txt
for code_arg in (case_codes, ctrl_codes, exclude_codes):
    if code_arg.endswith(('.csv', '.txt')):
        assert (path / code_arg).exists(), "%s does not exist" % (path / code_arg)

# Read group file
if group:
    print('Reading group data from file')
    group_data = pd.read_csv(group)

# Make code frequency a list of integers (one or more comma-separated values)
code_freq = [int(x) for x in code_freq.replace(" ", "").split(',')]

"""
Parse codes
"""

# Case codes are always given; control and exclusion codes are optional
# (empty string when not supplied) and are left untouched in that case.
case_codes = parse_codes(case_codes, path, "case")
ctrl_codes = parse_codes(ctrl_codes, path, "control") if ctrl_codes else ctrl_codes

if exclude_codes:
    # "phewas_map" derives the exclusion list from the case codes instead of reading one
    exclude_codes = (
        get_exclusion_from_phewas(case_codes, path, groupout.stem)
        if exclude_codes == "phewas_map"
        else parse_codes(exclude_codes, path, "exclusion")
    )
    # a code cannot both define controls and exclude subjects
    if ctrl_codes:
        overlap = set(ctrl_codes) & set(exclude_codes)
        assert not overlap, "Provided control code list overlap with exclusion code list (%s)" % ','.join(overlap)



"""
Find cases
"""
print('\nReading ICD data from file')
phen = pd.read_csv(phenotype, dtype={'ICD_CODE': str})

# For the daily threshold, keep a single record per subject, ICD code and
# AgeAtICD value so that repeated entries are counted only once.
if thresh_type == 'daily':
    phen.drop_duplicates(subset=['id', 'ICD_CODE', 'AgeAtICD'], inplace=True)

print('Finding cases with ICD code(s): %s' % '|'.join(case_codes))
find_and_count(phen, case_codes)
# a subject is a case when its matching-record count reaches the first threshold
case_ids = set(phen[phen['genc'] >= code_freq[0]]["id"])

if len(case_ids) == 0:
    # Series.sum() skips missing values, so records without an ICD code cannot
    # turn the total into NaN and trigger the wrong message.
    assert phen['gen'].sum() > 0, "No case subjects found (supplied case ICD codes not found)"
    assert False, "No case subjects found (supplied frequency threshold was not met)"

print('Found %d case subjects' % len(case_ids))
# drop case subjects from phen df
phen = phen[phen['genc'] < code_freq[0]]

"""
Process Exclusions
"""
# Subjects that carry case codes without reaching the frequency threshold are
# neither cases nor controls: count them, then keep only code-free subjects.
below_threshold = phen['genc'] > 0
ambiguous = len(set(phen[below_threshold]["id"]))
phen = phen[phen['genc'] == 0]
print("\nDropping %d ambiguous subjects (have case codes but don't meet frequency threshold)" % ambiguous)
print("%d subjects left in potential control pool" % phen.drop_duplicates(subset='id').shape[0])

# Related or user-specified code exclusions
if exclude_codes:
    print('\nExcluding subjects with ICD code(s): %s' % '|'.join(exclude_codes))
    find_and_count(phen, exclude_codes)
    # any subject with at least one exclusion code leaves the control pool
    has_excluded_code = phen['genc'] > 0
    excluded = len(set(phen[has_excluded_code]["id"]))
    phen = phen[phen['genc'] == 0]
    print("%d subjects excluded" % excluded)
    print("%d subjects left in potential control pool\n" % phen.drop_duplicates(subset='id').shape[0])

assert phen.shape[0] > 0, "No subjects remaining for control group"

"""
Find controls
"""
if ctrl_codes:
    print('\nFinding controls with ICD code(s): %s' % '|'.join(ctrl_codes))
    find_and_count(phen, ctrl_codes)
    # a second frequency value, when given, is the threshold for controls
    cf = code_freq[0] if len(code_freq) == 1 else code_freq[1]
    ctrl_ids = set(phen[phen['genc'] >= cf]["id"])
else:
    # without control codes every remaining subject is a control
    ctrl_ids = set(phen["id"])

if len(ctrl_ids) == 0:
    # Series.sum() skips missing values, so records without an ICD code cannot
    # turn the total into NaN and trigger the wrong message.
    assert phen['gen'].sum() > 0, "No control subjects found (supplied control ICD codes not found)"
    assert False, "No control subjects found (supplied frequency threshold was not met)"

print('Found %d control subjects' % len(ctrl_ids))

"""
Save Output
"""
print("\nFormatting outputs...")

# One row per subject: cases carry response 1, controls carry response 0
case_rows = pd.DataFrame(case_ids).rename(columns={0: 'id'})
case_rows[response_var] = 1

ctrl_rows = pd.DataFrame(ctrl_ids).rename(columns={0: 'id'})
ctrl_rows[response_var] = 0

out_df = pd.concat([case_rows, ctrl_rows], ignore_index=True)

if group:
    # keep only subjects present in the group file; clashing group columns get an '_old' suffix
    print('Merging response variable assignment with provided group file')
    out_df = pd.merge(out_df, group_data, how='inner', on='id', suffixes=('', '_old'))

num_case = (out_df[response_var] == 1).sum()
num_ctrl = (out_df[response_var] == 0).sum()
print('Cases: %d\tControls: %d' % (num_case, num_ctrl))

print('Saving group file to %s' % groupout)
out_df.to_csv(groupout, index=False)

# optionally, save new phenotype file with just the selected subjects
if phenotypeout:
    print('\nMerging response variable assignment with phenotype file')
    # the in-memory frame was filtered during selection, so start from the file again
    all_records = pd.read_csv(phenotype, dtype={'ICD_CODE': str})
    phen = pd.merge(out_df["id"], all_records, how='inner', on='id')
    print('Saving new phenotype file to %s' % phenotypeout)
    phen.to_csv(phenotypeout, index=False)



"""
Calculate runtime
"""
interval = time.time() - start
# break the elapsed whole seconds into hours, minutes and seconds
hour, leftover = divmod(math.floor(interval), 3600)
minute, second = divmod(leftover, 60)

# only show the units that are needed
if hour > 0:
    time_str = '%dh:%dm:%ds' % (hour, minute, second)
elif minute > 0:
    time_str = '%dm:%ds' % (minute, second)
else:
    time_str = '%ds' % second

print('createPhenotypeFile Complete [Runtime: %s]' % time_str)
