#!/usr/bin/env python
# -*- coding: utf-8 -*-

'''
Created on January 14th, 2015

@author: Benjamin Yvernault, Electrical Engineering, Vanderbilt University

Script to generate the folder for an NDAR submission using the image03 template.
Inputs:
    Information about XNAT: project/subjects
    Information outside XNAT: CSV file with the following header: xnat_subject_id,GUID,study_subject_id,interview_date,interview_age,gender
    Information for scan: CSV file with the following header: xnat_scantype,xnat_series_description,assessor_type_qc,ndar_scantype,image_num_dimensions,slice_acquisition,pulse_sequence,citation,experiment_id,image_file,image_thumbnail_file,data_file2,data_file2_definition
    
Example of values below for T1,fMRI,dti:

xnat_scantype,xnat_series_description,assessor_type_qc,ndar_scantype,image_num_dimensions,slice_acquisition,pulse_sequence,citation,experiment_id,image_file,image_thumbnail_file,data_file2,data_file2_definition
T1,,VBMQA,MR structural (T1),3,,MPRAGE(GE),DOI: 10.2174/1573405054038726,,NIFTI,SNAPSHOTS,,
fMRI,,fMRIQA,fMRI,4,4,BOLD T2* EPI (GE),DOI: 10.1002/jmri.20583,20,NIFTI,SNAPSHOTS,,
dti,,dtiQA_v2,MR diffusion,4,4,DTI Stejskal-Tanner (SE),DOI: 10.1371/journal.pone.0061737,,NIFTI,SNAPSHOTS,"bval,bvec",bval & bvec (.zip)

'''

from dicom.tag import Tag
from dax import XnatUtils
import csv,redcap,sys,os,glob,dicom

########################################## DEFAULT VARIABLES ##########################################
# Expected headers of the two input CSV files. read_csv compares the first row
# of a file with these (as sets); when it does not match, that row is treated as
# data whose columns are assumed to follow the order below.
DEFAULT_SCAN_HEADER=[
    'xnat_scantype',
    'xnat_series_description',
    'assessor_type_qc',
    'ndar_scantype',
    'image_num_dimensions',
    'slice_acquisition',
    'pulse_sequence',
    'citation',
    'experiment_id',
    'image_file',
    'image_thumbnail_file',
    'data_file2',
    'data_file2_definition',
]
DEFAULT_SUBJECT_HEADER=[
    'xnat_subject_id',
    'GUID',
    'study_subject_id',
    'interview_date',
    'interview_age',
    'gender',
]

########################################## image_03 TEMPLATE ##########################################
# Column names of the NDAR image03 template, in the order in which they are
# written to NDAR_submission.csv (the header line and every data row built by
# get_row follow this order).
Ordered_keys=['subjectkey','src_subject_id','interview_date','interview_age','gender','comments_misc','image_file','image_thumbnail_file','image_description','experiment_id','scan_type','scan_object','image_file_format','data_file2','data_file2_type','image_modality','scanner_manufacturer_pd','scanner_type_pd','scanner_software_versions_pd','magnetic_field_strength','mri_repetition_time_pd','mri_echo_time_pd','flip_angle','acquisition_matrix','mri_field_of_view_pd','patient_position','photomet_interpret','receive_coil','transmit_coil','transformation_performed','transformation_type','image_history','image_num_dimensions','image_extent1','image_extent2','image_extent3','image_extent4','extent4_type','image_extent5','extent5_type','image_unit1','image_unit2','image_unit3','image_unit4','image_unit5','image_resolution1','image_resolution2','image_resolution3','image_resolution4','image_resolution5','image_slice_thickness','image_orientation','qc_outcome','qc_description','qc_fail_quest_reason','decay_correction','frame_end_times','frame_end_unit','frame_start_times','frame_start_unit','pet_isotope','pet_tracer','time_diff_inject_to_image','time_diff_units','pulse_seq','slice_acquisition','software_preproc']

# Source of the value of every image03 column (every key of Ordered_keys must
# be present here, get_row indexes this dict for each of them):
#   - tuple (group, element): hexadecimal DICOM tag read directly from the scan's
#     DICOM file by get_row; a tag that cannot be read yields ' '.
#   - any string: passed to get_value, which computes some columns from the
#     DICOM dataset, copies others from the subject/scan records, and otherwise
#     falls back to scan_record[column] or, last, to the string below itself.
#     'options', 'XNAT' and 'ds' are therefore placeholders (CSV input, XNAT
#     download step, DICOM dataset); if nothing provides a value the placeholder
#     is written literally. NOTE(review): 'image_description' does not appear to
#     be filled anywhere in this script, so 'XNAT' may end up in the CSV -- confirm.
#   - other strings ('', 'No', 'Live', ...) are constant defaults.
NDAR_csv_dict={'subjectkey':'options',
               'src_subject_id':'options',
               'interview_date':'options',
               'interview_age':'options',
               'gender':'options',
               'comments_misc':'',
               'image_file':'XNAT',
               'image_thumbnail_file':'XNAT',
               'image_description':'XNAT',
               'experiment_id':'options',
               'scan_type':'options',
               'scan_object':'Live',
               'image_file_format':'NIFTI',
               'data_file2':'options',
               'data_file2_type':'options',
               'image_modality':'MRI', 
               'scanner_manufacturer_pd':('0008','0070'),
               'scanner_type_pd':('0008','1090'),
               'scanner_software_versions_pd':('0018','1020'),
               'magnetic_field_strength':('0018','0087'),
               'mri_repetition_time_pd':('0018','0080'),
               'mri_echo_time_pd':('0018','0081'),
               # odd group number: private (vendor-specific) tag
               'flip_angle':('2001','1023'),
               'acquisition_matrix':'ds',
               'mri_field_of_view_pd':'',
               'patient_position':('0018','5100'),
               'photomet_interpret':('0028','0004'),
               'receive_coil':'ds',
               'transmit_coil':'ds', 
               'transformation_performed':'No', 
               'transformation_type':'',
               'image_history':'',
               'image_num_dimensions':'options',
               'image_extent1':('0028','0010'),
               'image_extent2':('0028','0011'),
               # slices / volumes counted from the per-frame groups (getNumSlicesAndVols)
               'image_extent3':'ds', 
               'image_extent4':'ds',  
               'extent4_type':'ds',
               'image_extent5':'',
               'extent5_type':'',
               # NOTE(review): 'frame number' as unit of the spatial axes 1-3 looks
               # questionable -- confirm against the image03 data dictionary.
               'image_unit1':'frame number',
               'image_unit2':'frame number',
               'image_unit3':'frame number',
               'image_unit4':'frame number',
               'image_unit5':'',
               'image_resolution1':'ds',
               'image_resolution2':'ds', 
               'image_resolution3':'ds',
               'image_resolution4':'ds', 
               'image_resolution5':'0',
	           'image_slice_thickness':('0018','0088'),
               'image_orientation':'ds', 
               'qc_outcome':'XNAT',
               'qc_description':'options',
               'qc_fail_quest_reason':'',
               'decay_correction':'',
               'frame_end_times':'',
               'frame_end_unit':'',
               'frame_start_times':'',
               'frame_start_unit':'',
               'pet_isotope':'',
               'pet_tracer':'',
               'time_diff_inject_to_image':'',
               'time_diff_units':'',
               'pulse_seq':'options',
               'slice_acquisition':'options',
               'software_preproc':''}

########################################## DICOM FUNCTIONS ##########################################
def getNumSlicesAndVols(ds):
    """Return (number of slices, number of volumes) of an enhanced multi-frame DICOM.

    The slice count is the largest slice number stored for any frame; the volume
    count is the number of frames (0028,0008) divided by the slice count
    (integer division).

    :param ds: DICOM dataset as returned by dicom.read_file
    :return: tuple (nSlices, nVols)
    :raises ValueError: when the dataset reports no frames
    """
    nFrames = ds[0x0028,0x0008].value
    per_frame = ds[0x5200,0x9230]
    # (2001,100a) inside the (2005,140f) sequence of each frame is presumably the
    # vendor-private slice number of that frame -- confirm for non-Philips data.
    nSlices = max(per_frame[i][0x2005,0x140f][0][0x2001,0x100a].value
                  for i in range(nFrames))
    # explicit floor division keeps an integer result even under true division
    nVols = nFrames // nSlices
    return nSlices,nVols
    
def getAxis(x,y,z):
    """Return the anatomical direction letter a direction cosine triple points along.

    A component selects its axis only when its absolute value is above 0.8 and
    strictly larger than the two others; its sign then picks the letter:
    x -> 'R' if negative else 'L', y -> 'A' if negative else 'P',
    z -> 'F' if negative else 'H'. Returns '' when no component dominates.
    """
    threshold = 0.8
    magnitudes = (abs(x), abs(y), abs(z))
    letters = ('R' if x < 0 else 'L',
               'A' if y < 0 else 'P',
               'F' if z < 0 else 'H')
    for index, size in enumerate(magnitudes):
        others = magnitudes[:index] + magnitudes[index + 1:]
        if size > threshold and all(size > other for other in others):
            return letters[index]
    return ''

def getLabel(ds):
    """Classify the acquisition plane of an enhanced DICOM from its first frame.

    Reads the orientation (0020,0037) stored in the first frame's (0020,9116)
    sequence, converts the row and column direction cosines to direction letters
    with getAxis, and returns 'Axial' (R/L with A/P), 'Coronal' (R/L with H/F),
    'Sagittal' (A/P with H/F) -- in either row/column order -- or '' otherwise.
    """
    cosines = ds[0x5200,0x9230][0][0x0020,0x9116][0][0x0020,0x0037].value
    row_axis = getAxis(cosines[0],cosines[1],cosines[2])
    col_axis = getAxis(cosines[3],cosines[4],cosines[5])

    def axis_group(letter):
        # opposite letters of the same axis belong to the same group
        if letter in ('R','L'):
            return 'LR'
        if letter in ('A','P'):
            return 'AP'
        if letter in ('H','F'):
            return 'HF'
        return None

    planes = {
        frozenset(('LR','AP')): 'Axial',
        frozenset(('LR','HF')): 'Coronal',
        frozenset(('AP','HF')): 'Sagittal',
    }
    return planes.get(frozenset((axis_group(row_axis),axis_group(col_axis))),'')

def get_row(directory,subject_record,scan_record):
    """Build the image03 row of one scan, as a list of strings in Ordered_keys order.

    :param directory: submission folder; scan_record['DICOM'] is a path inside it
        (a leading '/' is tolerated, the paths built by download_XNAT_file start with one)
    :param subject_record: row of the subject CSV for the scan's subject
    :param scan_record: scan dict built by get_data_xnat
    :return: list of str, or an empty list when the DICOM file is not available
    """
    row=list()
    if 'DICOM' not in scan_record:
        print('  ---> warning: no dicom found')
        return row
    dicom_rel=scan_record['DICOM']
    # os.path.join would discard `directory` for a path starting with '/'
    if dicom_rel.startswith('/'):
        dicom_rel=dicom_rel[1:]
    dcmpath=os.path.join(directory,dicom_rel)
    if not os.path.isfile(dcmpath):
        print('  ---> warning: dicom missing')
        return row

    ds = dicom.read_file(dcmpath)
    for header in Ordered_keys:
        source=NDAR_csv_dict[header]
        if isinstance(source,tuple):
            # Plain DICOM tag: a missing or unreadable element becomes a blank
            # cell. Exception (not a bare except) so Ctrl-C still stops the run.
            tag=Tag(source[0],source[1])
            try:
                val=ds[tag].value
            except Exception:
                val=' '
        elif isinstance(source,str):
            val=get_value(header,subject_record,scan_record,ds)
        else:
            val=''
        row.append(str(val))
    return row

def get_value(header,subject_record,scan_record,ds):
    """Return the value of an image03 column that is not a plain DICOM tag.

    Resolution order:
      1. columns computed from the enhanced DICOM dataset ``ds``;
      2. columns copied from the subject record or the scan record under another
         field name (only when that field is present);
      3. ``extent4_type``: '' for an 'MR structural (T1)' scan, else 'Number of Volumes';
      4. a scan-record field with the same name as the column;
      5. the default given in NDAR_csv_dict.
    """
    # 1. values taken from the DICOM dataset
    if header=='image_orientation':
        return getLabel(ds)
    if header in ('image_extent3','image_extent4'):
        n_slices,n_vols=getNumSlicesAndVols(ds)
        if header=='image_extent3':
            return n_slices
        return n_vols
    if header in ('acquisition_matrix','image_resolution1','image_resolution2','image_resolution3'):
        first_frame=ds[0x5200,0x9230][0][0x2005,0x140f][0]
        if header=='acquisition_matrix':
            return first_frame[0x0018,0x1310].value
        if header=='image_resolution3':
            return first_frame[0x0018,0x0050].value
        spacing=first_frame[0x0028,0x0030].value
        if header=='image_resolution1':
            return spacing[0]
        return spacing[1]
    shared_elements={'image_resolution4':((0x0018,0x9112),(0x0018,0x0080)),
                     'receive_coil':((0x0018,0x9042),(0x0018,0x1250)),
                     'transmit_coil':((0x0018,0x9049),(0x0018,0x1251))}
    if header in shared_elements:
        sequence_tag,element_tag=shared_elements[header]
        return ds[0x5200,0x9229][0][sequence_tag][0][element_tag].value

    # 2. CSV fields exposed under a different image03 column name
    subject_fields={'subjectkey':'GUID',
                    'src_subject_id':'study_subject_id',
                    'interview_date':'interview_date',
                    'interview_age':'interview_age',
                    'gender':'gender'}
    if header in subject_fields and subject_fields[header] in subject_record:
        return subject_record[subject_fields[header]]
    scan_fields={'data_file2_type':'data_file2_definition',
                 'pulse_seq':'pulse_sequence',
                 'qc_description':'citation',
                 'scan_type':'ndar_scantype'}
    if header in scan_fields and scan_fields[header] in scan_record:
        return scan_record[scan_fields[header]]

    # 3. only non-structural scans have a volume dimension
    if header=='extent4_type':
        if scan_record.get('ndar_scantype')=='MR structural (T1)':
            return ''
        return 'Number of Volumes'

    # 4./5. same-named scan field, else the template default
    if header in scan_record:
        return scan_record[header]
    return NDAR_csv_dict[header]
                    
########################################## CSV FUNCTIONS ##########################################
def read_csv(csvpath,default_header):
    """Read a subject or scan CSV file into a dict of row dicts.

    If the first non-blank row contains exactly the names of ``default_header``
    (in any order) it is used as the header; otherwise the file is assumed to
    have no header and every row, the first included, is mapped with
    ``default_header``.

    Rows are keyed by their first column, except for the scan CSV
    (``default_header[0] == 'xnat_scantype'``), whose rows are keyed by
    '<xnat_scantype>-x-<xnat_series_description>' as get_data_xnat expects.

    :param csvpath: path to the CSV file
    :param default_header: expected column names
    :return: dict {key: {column: value}}; empty for an empty file
    """
    def row_key(row):
        if default_header[0]=='xnat_scantype':
            return row[0]+'-x-'+row[1]
        return row[0]

    csv_dict=dict()
    # the csv module wants binary mode on Python 2 and newline='' on Python 3
    if sys.version_info[0] < 3:
        csvfile=open(csvpath,'rb')
    else:
        csvfile=open(csvpath,'r',newline='')
    with csvfile:
        reader=csv.reader(csvfile,delimiter=',')
        rows=[row for row in reader if row]  # blank lines carry no record

    if not rows:
        return csv_dict
    if set(rows[0])==set(default_header):
        keys=rows[0]
        data_rows=rows[1:]
    else:
        # No header: the first row is data too and must get the same key as
        # the others (it used to be keyed by its first cell only).
        keys=default_header
        data_rows=rows

    for row in data_rows:
        csv_dict[row_key(row)]=dict(zip(keys,row))
    return csv_dict

########################################## XNAT FUNCTIONS ##########################################
def get_data_xnat(options,directory,subjects):
    """Download the data of the usable scans described in the scan CSV.

    For every usable scan of ``options.project`` (restricted to ``subjects``)
    whose type equals a scan CSV ``xnat_scantype`` (case-insensitive) and whose
    series description contains the matching ``xnat_series_description``, the
    resources named in that CSV row and the DICOM are downloaded into
    <directory>/<subject>-x-<session>-x-<scanID>.

    :param options: parsed command-line options (scaninfo, project)
    :param directory: absolute path of the submission folder
    :param subjects: XNAT subject labels to keep
    :return: list of scan dicts: XNAT scan info, scan CSV columns, paths of the
        downloaded files and the QC outcome
    """
    scan_records=list()
    scan_info=read_csv(options.scaninfo,DEFAULT_SCAN_HEADER)

    # Connect before the try: if the connection fails there is nothing to
    # disconnect, and `xnat` would be unbound in the finally clause.
    xnat = XnatUtils.get_interface()
    try:
        #list of scans and assessors for the full project
        scan_list=XnatUtils.list_project_scans(xnat,options.project)
        assessor_list=XnatUtils.list_project_assessors(xnat,options.project)

        #keep only the scan types we need and the assessors used for the qc
        scan_list=filter_scans([v['xnat_scantype'] for v in scan_info.values()],subjects,scan_list)
        assessor_list=filter_assessors(subjects,assessor_list)

        print('INFO: Downloading scans resources from XNAT')
        for scan in sorted(scan_list,key=lambda k: k['subject_label']):
            if scan['quality']!='usable':
                continue
            scan_label=scan['subject_label']+'-x-'+scan['session_label']+'-x-'+scan['ID']
            for type_SD in scan_info.keys():
                parts=type_SD.split('-x-')
                scantype=parts[0]
                part_of_SD=parts[1]
                if scan['type'].lower()!=scantype.lower():
                    continue
                if part_of_SD.lower() not in scan['series_description'].lower():
                    continue
                print(' subject/session/scan: '+scan['subject_label']+'/'+scan['session_label']+'/'+scan['ID']+' found')
                scan_dict=scan.copy()
                scan_path=os.path.join(directory,scan_label)
                if not os.path.exists(scan_path):
                    os.makedirs(scan_path)
                #each column of the scan CSV row for this type of scan:
                for header,value in scan_info[type_SD].items():
                    if header in ['image_file','image_thumbnail_file','data_file2']:
                        scan_dict[header]=download_XNAT_file(xnat,directory,scan,value)
                    elif header=='assessor_type_qc':
                        qc_status,qc_reason=get_qc(scan,assessor_list,scan_info[type_SD],scan_label)
                        scan_dict['qc_outcome']=qc_status
                        scan_dict['qc_fail_quest_reason']=qc_reason
                    else:
                        scan_dict[header]=value
                #the DICOM is always needed to fill the header-derived columns
                scan_dict['DICOM']=download_XNAT_file(xnat,directory,scan,'DICOM')
                scan_records.append(scan_dict)
    finally:
        xnat.disconnect()

    return scan_records

def download_XNAT_file(xnat,directory,scan,resource_label):
    """Make sure a scan resource is present locally and return its submission path.

    Files go to <directory>/<subject>-x-<session>-x-<scanID>/<folder>, where
    <folder> is ``resource_label`` with commas replaced by '_'. Nothing is
    downloaded when that folder already holds a file. Otherwise every label of
    ``resource_label`` (comma separated, e.g. "bval,bvec") is downloaded from
    XNAT; when there are several labels the files are bundled into <folder>.zip.

    :param xnat: XNAT interface
    :param directory: submission folder (absolute path, no trailing '/')
    :param scan: XNAT scan dict (subject_label, session_label, ID, ...)
    :param resource_label: resource name(s), comma separated
    :return: path of the first file of the folder with ``directory`` cut off
        (NDAR wants paths starting from the submitted folder, so it keeps a
        leading '/'), or '' when no file is available or the label is empty
    """
    # An empty label (e.g. no data_file2 for this scan type) would make the scan
    # folder itself the resource folder and return one of its sub-folders.
    if not (resource_label or '').strip():
        return ''
    string_len=len(directory)
    labels=resource_label.split(',')
    rfolder='_'.join(labels)
    res_path=os.path.join(directory,scan['subject_label']+'-x-'+scan['session_label']+'-x-'+scan['ID'],rfolder)

    found=glob.glob(os.path.join(res_path,'*'))
    if found:
        return found[0][string_len:]

    if not os.path.exists(res_path):
        os.makedirs(res_path)
    scan_obj=XnatUtils.get_full_object(xnat,scan)
    for res in labels:
        # bval/bvec resources may be stored lower- or upper-case on XNAT;
        # every other resource is looked up by its exact name.
        if res in ['bval','bvec']:
            candidates=[res.lower(),res.upper()]
        else:
            candidates=[res]
        for name in candidates:
            resource=scan_obj.resource(name)
            if resource.exists():
                XnatUtils.download_biggest_resources(resource,res_path)
                break

    if len(labels)>1:
        _zip_resource_folder(res_path,rfolder+'.zip')

    found=glob.glob(os.path.join(res_path,'*'))
    if found:
        return found[0][string_len:]
    return ''


def _zip_resource_folder(res_path,zip_name):
    """Bundle the regular files of res_path into res_path/zip_name and delete them.

    Uses zipfile instead of the external `zip` command: no shell string built
    from CSV input and no change of the working directory.
    """
    import zipfile
    fnames=sorted(fname for fname in os.listdir(res_path)
                  if fname!=zip_name and os.path.isfile(os.path.join(res_path,fname)))
    if not fnames:
        # nothing was downloaded: do not create an empty archive
        return
    with zipfile.ZipFile(os.path.join(res_path,zip_name),'w',zipfile.ZIP_DEFLATED) as archive:
        for fname in fnames:
            archive.write(os.path.join(res_path,fname),fname)
    for fname in fnames:
        os.remove(os.path.join(res_path,fname))

def get_qc(scan,assessor_list,scan_record,label):
    """Translate the QC status of a scan's QA assessor into NDAR qc values.

    The assessor is found by the label suffix '-x-<label>-x-<proctype>', where
    <proctype> is the scan CSV column ``assessor_type_qc`` (VBMQA when that
    column is missing or empty). The project prefix of the label is not
    compared; the caller passes the assessors of a single project.

    :param scan: XNAT scan dict (unused, kept for the call signature)
    :param assessor_list: assessor dicts with 'label' and 'qcstatus' keys
    :param scan_record: scan CSV row of this scan type
    :param label: '<subject>-x-<session>-x-<scanID>'
    :return: (qc_status, qc_reason); qc_status is 'pass', 'fail' or 'questionable'
    """
    qc_status='questionable'
    qc_reason='Not yet control by user'
    proctype=scan_record.get('assessor_type_qc') or 'VBMQA'
    suffix='-x-'+label+'-x-'+proctype
    matches=[a for a in assessor_list if a['label'].endswith(suffix)]
    if matches:
        status=(matches[0]['qcstatus'] or '').lower()
        # anything else (e.g. 'Needs QA') keeps the questionable default
        if 'pass' in status:
            qc_status='pass'
            qc_reason=''
        elif 'fail' in status:
            qc_status='fail'
            qc_reason=''

    return qc_status,qc_reason

def filter_scans(scantypes,subjects,obj_list):
    """Keep the scans whose type is in scantypes and, if subjects is given, whose subject is in it.

    Returns an empty list when no scan type is requested.
    """
    if not scantypes:
        return list()
    return [scan for scan in obj_list
            if scan['type'] in scantypes
            and (not subjects or scan['subject_label'] in subjects)]

def filter_assessors(subjects,obj_list):
    """Keep the assessors of the given subjects; return obj_list unchanged when subjects is empty."""
    if not subjects:
        return obj_list
    return [assessor for assessor in obj_list if assessor['subject_label'] in subjects]

########################################## CHECK OPTIONS ##########################################
def check_options(options):
    """Validate the command-line options, printing an error for the first problem.

    The subject CSV, the scan CSV, the submission directory and the XNAT project
    are all required by the main script. The CSV files must exist, the parent of
    the submission directory must exist, and the project must exist on XNAT and
    list at least one subject.

    :param options: parsed options (subjectinfo, scaninfo, directory, project)
    :return: True when the script can run, False otherwise
    """
    if not options.subjectinfo:
        print('OPTION ERROR: the CSV file for the subjects information (-S/--subjectinfo) has not been set.')
        return False
    if not os.path.exists(os.path.abspath(options.subjectinfo)):
        print("OPTION ERROR: the CSV file for the subjects information "+options.subjectinfo+" does not exist.")
        return False
    if not options.scaninfo:
        print('OPTION ERROR: the CSV file for the scans information (-s/--scaninfo) has not been set.')
        return False
    if not os.path.exists(os.path.abspath(options.scaninfo)):
        print("OPTION ERROR: the CSV file for the scans information "+options.scaninfo+" does not exist.")
        return False
    if not options.directory:
        print('OPTION ERROR: the submission directory (-d/--directory) has not been set.')
        return False
    if not os.path.exists(os.path.dirname(os.path.abspath(options.directory))):
        print('OPTION ERROR: the directory options: '+options.directory+' can not be created because '+os.path.dirname(os.path.abspath(options.directory))+' does not exist.')
        return False
    if not options.project:
        # the project is None here: it cannot be concatenated into the message
        print('OPTION ERROR: the XNAT project option (-p/--project) has not been set.')
        return False

    # Connect before the try so `xnat` is always bound in the finally clause.
    xnat = XnatUtils.get_interface()
    try:
        proj=xnat.select('/project/'+options.project)
        if not proj.exists():
            print('OPTION ERROR: XNAT Project given '+options.project+' does not exist on XNAT.')
            return False
        if not XnatUtils.list_subjects(xnat,options.project):
            print(" OPTION ERROR: You don't access to the project: "+options.project+".")
            return False
    finally:
        xnat.disconnect()

    return True
########################################## MAIN DISPLAY ##########################################
def Main_display(parser):
    """Print the XNATNDAR banner and a summary of the options that were given.

    When no option at all was given, print the usage help and exit.
    """
    options=parser.parse_args()[0]
    border='################################################################'
    banner=[border,
            '#                           XNATNDAR                           #',
            '#                                                              #',
            '# Developed by the masiLab Vanderbilt University, TN, USA.     #',
            '# If issues, email benjamin.c.yvernault@vanderbilt.edu         #',
            '#                                                              #',
            '# Function:                                                    #',
            '#     prepare a directory for NDAR submission using image03    #',
            '#     template (download data and generate csv)                #',
            '#                                                              #',
            '# Parameters :                                                 #']
    for line in banner:
        print(line)

    if options=={'directory':None, 'project':None,'scaninfo': None, 'subjectinfo': None, 'continu': False}:
        print('#     No Arguments given                                       #')
        print('#     Use "XnatNDAR -h" to see the options                     #')
        print(border)
        parser.print_help()
        sys.exit()

    def show(name,value):
        # one "name -> value" line padded to the banner width
        print('#     %*s -> %*s#' % (-20,name,-33,value))

    if options.directory:
        show('Submission Folder',get_proper_str(options.directory,True))
    if options.project:
        show('XNAT Project',get_proper_str(options.project))
    if options.scaninfo:
        show('CSV scan',get_proper_str(options.scaninfo,True))
    if options.subjectinfo:
        show('CSV subject',get_proper_str(options.subjectinfo,True))
    if options.continu:
        show('Mode Continue','on')
    print(border)

def get_proper_str(str_option,end=False):
    """Shorten str_option to 32 characters for the banner.

    Longer strings keep 29 characters plus '...': the last 29 characters when
    end is True, the first 29 otherwise.
    """
    if len(str_option)<=32:
        return str_option
    return '...'+str_option[-29:] if end else str_option[:29]+'...'
    
########################################## OPTIONS ##########################################
def parse_args():
    """Build the command-line parser (the options are parsed by the caller)."""
    from optparse import OptionParser
    parser = OptionParser(usage="usage: %prog [options] \nWhat is the script doing : Generate a CSV file for NDAR submission for the template image_03 only.")
    option_specs = [
        #Submission folder:
        (("-d","--directory"),
         dict(dest="directory",default=None,metavar="DIR",
              help="Directory to store the data for submission. NDAR_submission.csv will be created in it.")),
        #XNAT information
        (("-p","--project"),
         dict(dest="project",default=None,metavar="PROJECT_ID",
              help="Project ID on Xnat")),
        #Outside information
        (("-s","--scaninfo"),
         dict(dest="scaninfo",default=None,metavar="CSV",
              help="CSV file with the following header: xnat_scantype,xnat_series_description,assessor_type_qc,ndar_scantype,image_num_dimensions,slice_acquisition,pulse_sequence,citation,experiment_id,image_file,image_thumbnail_file,data_file2,data_file2_definition. See script header for more information and an example.")),
        (("-S","--subjectinfo"),
         dict(dest="subjectinfo",default=None,metavar="CSV",
              help="CSV file with the following header xnat_subject_id,GUID,study_subject_id,interview_date,interview_age,gender")),
        #options
        (("-c","--continue"),
         dict(dest="continu",action="store_true",default=False,metavar="FILEPATH",
              help="If the script stopped, use continue to restart the script where it stopped.")),
    ]
    for flags,settings in option_specs:
        parser.add_option(*flags,**settings)
    return parser

###################################################################################################
########################################## MAIN FUNCTION ##########################################
###################################################################################################
if __name__ == '__main__':
    parser=parse_args()
    (options,args) = parser.parse_args()
    #############################
    #Main display:
    Main_display(parser)
    #check options:
    run=check_options(options)
    #############################

    #############################
    # RUN                       #
    #############################
    if run:
        directory=os.path.abspath(options.directory)
        # check_options only guarantees that the parent folder exists, and the
        # CSV below is written even when no scan was downloaded.
        if not os.path.exists(directory):
            os.makedirs(directory)
        #Output file:
        OutputCSV=os.path.join(directory,'NDAR_submission.csv')
        previous_row=[]
        previous_GUID=set()  # tested once per scan below

        #Continue mode: keep the rows already written and skip their subjects
        if options.continu and os.path.exists(OutputCSV):
            with open(OutputCSV,'rb') as csvfile:
                spamreader = csv.reader(csvfile,delimiter=',')
                next(spamreader,None)  #skip the "image,03" line
                next(spamreader,None)  #skip the header line
                for row in spamreader:
                    if row:  #ignore blank lines
                        previous_GUID.add(row[0])
                        previous_row.append(row)

        #Read from CSV files
        subject_records=read_csv(options.subjectinfo,DEFAULT_SUBJECT_HEADER)

        #Print number of object from CSV
        print('INFO: Number of Subject found in the csv:')
        print('-------------------------------------------')
        print('| %*s : %*s |' % (-10,'Subjets',-10,str(len(subject_records))))
        print('-------------------------------------------')

        #Get Xnat info and download files:
        print('INFO: Querying XNAT project '+options.project+' to download data.')
        scan_records=get_data_xnat(options,directory,subject_records.keys())

        #Print number of scans from xnat
        print('INFO: Number of Subject/scan found on XNAT for NDAR submission:')
        print('-------------------------------------------')
        print('| %*s : %*s |' % (-10,'Subjets',-10,str(len(set([scan['subject_label'] for scan in scan_records])))))
        print('| %*s : %*s |' % (-10,'Scans',-10,str(len(scan_records))))
        print('-------------------------------------------')

        #Open the CSV file NDAR_submission.csv:
        with open(OutputCSV,'wb') as csvfile:
            spamwriter = csv.writer(csvfile,delimiter=',')
            spamwriter.writerow(['image','03'])
            spamwriter.writerow(Ordered_keys)

            #rows kept from the previous run (continue mode)
            for row in previous_row:
                spamwriter.writerow(row)

            print('INFO: Writting rows for the csv from DICOM header ...')
            for scan_record in scan_records:
                subject_record=subject_records[scan_record['subject_label']]
                if subject_record['GUID'] not in previous_GUID:
                    print(' Subject/session/scan: '+scan_record['subject_label']+'/'+scan_record['session_label']+'/'+scan_record['ID'])
                    row=get_row(directory,subject_record,scan_record)
                    if row:
                        spamwriter.writerow(row)
    
