#!/Users/boydb1/anaconda/bin/python

#==============================================================================
# Xnatmirror
# Only copies the default XNAT datatypes, Project/Subject/MR_Session/MR_Scan.
#
#==============================================================================

import os
from pyxnat import Interface
import sys
from datetime import datetime

# Define attributes to be copied
# Attribute paths handled for objects of type xnat:projectData.
PROJ_ATTRS = [
    # General project description
    'xnat:projectData/secondary_ID',
    'xnat:projectData/name',
    'xnat:projectData/description',
    'xnat:projectData/keywords',
    # Principal investigator
    'xnat:projectData/PI/title',
    'xnat:projectData/PI/firstname',
    'xnat:projectData/PI/lastname',
    'xnat:projectData/PI/institution',
    'xnat:projectData/PI/department',
]

# Attribute paths handled for objects of type xnat:subjectData.
SUBJ_ATTRS = [
    # Fields stored directly on the subject
    'xnat:subjectData/group',
    'xnat:subjectData/src',
    'xnat:subjectData/investigator/firstname',
    'xnat:subjectData/investigator/lastname',
    # Fields stored in the subject's xnat:demographicData element
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/dob',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/yob',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/age',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/gender',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/handedness',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/ses',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/education',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/educationDesc',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/race',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/ethnicity',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/weight',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/height',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/gestational_age',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/post_menstrual_age',
    'xnat:subjectData/demographics[@xsi:type=xnat:demographicData]/birth_weight',
]

# Attribute paths handled for objects of type xnat:mrSessionData.
MR_EXP_ATTRS = [
    # Generic experiment fields
    'xnat:experimentData/date',
    'xnat:experimentData/visit_id',
    'xnat:experimentData/time',
    'xnat:experimentData/note',
    'xnat:experimentData/investigator/firstname',
    'xnat:experimentData/investigator/lastname',
    # Image-session fields
    'xnat:imageSessionData/scanner/manufacturer',
    'xnat:imageSessionData/scanner/model',
    'xnat:imageSessionData/operator',
    'xnat:imageSessionData/dcmAccessionNumber',
    'xnat:imageSessionData/dcmPatientId',
    'xnat:imageSessionData/dcmPatientName',
    'xnat:imageSessionData/session_type',
    'xnat:imageSessionData/modality',
    'xnat:imageSessionData/UID',
    # MR-session fields
    'xnat:mrSessionData/coil',
    'xnat:mrSessionData/fieldStrength',
    'xnat:mrSessionData/marker',
    'xnat:mrSessionData/stabilization',
]

# Attribute paths handled for objects of type xnat:mrScanData.
MR_SCAN_ATTRS = [
    # Generic image-scan fields
    'xnat:imageScanData/type',
    'xnat:imageScanData/UID',
    'xnat:imageScanData/note',
    'xnat:imageScanData/quality',
    'xnat:imageScanData/condition',
    'xnat:imageScanData/series_description',
    'xnat:imageScanData/documentation',
    'xnat:imageScanData/frames',
    'xnat:imageScanData/scanner/manufacturer',
    'xnat:imageScanData/scanner/model',
    # MR acquisition parameters
    'xnat:mrScanData/parameters/flip',
    'xnat:mrScanData/parameters/orientation',
    'xnat:mrScanData/parameters/tr',
    'xnat:mrScanData/parameters/ti',
    'xnat:mrScanData/parameters/te',
    'xnat:mrScanData/parameters/sequence',
    'xnat:mrScanData/parameters/imageType',
    'xnat:mrScanData/parameters/scanSequence',
    'xnat:mrScanData/parameters/seqVariant',
    'xnat:mrScanData/parameters/scanOptions',
    'xnat:mrScanData/parameters/acqType',
    'xnat:mrScanData/parameters/pixelBandwidth',
    'xnat:mrScanData/parameters/voxelRes/x',
    'xnat:mrScanData/parameters/voxelRes/y',
    'xnat:mrScanData/parameters/voxelRes/z',
    'xnat:mrScanData/parameters/fov/x',
    'xnat:mrScanData/parameters/fov/y',
    'xnat:mrScanData/parameters/matrix/x',
    'xnat:mrScanData/parameters/matrix/y',
    'xnat:mrScanData/parameters/partitions',
    # Remaining MR-scan fields
    'xnat:mrScanData/fieldStrength',
    'xnat:mrScanData/marker',
    'xnat:mrScanData/stabilization',
    'xnat:mrScanData/coil',
]

# Attribute paths handled for objects of type xnat:scScanData
# (generic image-scan fields only).
SC_SCAN_ATTRS = [
    'xnat:imageScanData/type',
    'xnat:imageScanData/UID',
    'xnat:imageScanData/note',
    'xnat:imageScanData/quality',
    'xnat:imageScanData/condition',
    'xnat:imageScanData/series_description',
    'xnat:imageScanData/documentation',
    'xnat:imageScanData/frames',
    'xnat:imageScanData/scanner/manufacturer',
    'xnat:imageScanData/scanner/model',
]

#==============================================================================
# Functions

def check_attributes(src_obj, dest_obj, dtype):
    """Re-sync attributes of an existing destination object with its source.

    The attribute list is chosen from ``dtype``. Each attribute is read from
    both objects and written to the destination again whenever the two
    values differ.

    Args:
        src_obj: source object; its ``attrs.get`` is called per attribute.
        dest_obj: destination object; its ``attrs.get`` / ``attrs.set`` are
            called per attribute.
        dtype: XNAT datatype name selecting the attribute list. An
            unrecognised name prints a debug line and nothing is compared.
    """
    attrs_by_type = {
        'xnat:projectData': PROJ_ATTRS,
        'xnat:subjectData': SUBJ_ATTRS,
        'xnat:mrSessionData': MR_EXP_ATTRS,
        'xnat:mrScanData': MR_SCAN_ATTRS,
        'xnat:scScanData': SC_SCAN_ATTRS,
    }
    if dtype not in attrs_by_type:
        print('DEBUG:Unknown Type:%s' % dtype)
        return

    for path in attrs_by_type[dtype]:
        # Backslashes in the source value are turned into pipes before the
        # comparison, so the destination is compared against (and set to)
        # the converted form.
        wanted = src_obj.attrs.get(path).replace("\\", "|")
        current = dest_obj.attrs.get(path)
        if wanted == current:
            continue
        print('Attribute mismatch, setting again:%s:src_v=%s,dest_v=%s' % (path, wanted, current))
        dest_obj.attrs.set(path, wanted)

def copy_attrs(src_obj, dest_obj, attr_list):
    """Copy a list of attributes from a source object to a destination object.

    Args:
        src_obj: object read through ``attrs.mget`` (and ``attrs.get`` for
            the echo-time attribute).
        dest_obj: object written through a single ``attrs.mset`` call.
        attr_list: attribute paths to copy.

    Returns:
        Always 0.
    """
    fetched = src_obj.attrs.mget(attr_list)
    payload = {path: value for path, value in zip(attr_list, fetched)}

    # NOTE: te is fetched again on its own because, per the original author,
    # a bug somewhere sets te to the sequence name in the bulk result.
    echo_time = 'xnat:mrScanData/parameters/te'
    if echo_time in payload:
        payload[echo_time] = src_obj.attrs.get(echo_time)

    dest_obj.attrs.mset(payload)
    return 0

def copy_attributes(src_obj, dest_obj):
    """Copy the attributes appropriate for the source object's datatype.

    The attribute list is selected from ``src_obj.datatype()`` and handed to
    ``copy_attrs``. Secondary-capture scans (``xnat:scScanData``) use
    ``SC_SCAN_ATTRS``, matching ``check_attributes``; previously they fell
    through to the MR scan list. Any other unrecognised datatype is still
    treated as an MR scan.

    Args:
        src_obj: source object; must provide ``datatype()`` and ``attrs``.
        dest_obj: destination object receiving the attributes.
    """
    # TODO: confirm that dest_obj has the same datatype as src_obj
    attrs_by_type = {
        'xnat:projectData': PROJ_ATTRS,
        'xnat:subjectData': SUBJ_ATTRS,
        'xnat:mrSessionData': MR_EXP_ATTRS,
        'xnat:mrScanData': MR_SCAN_ATTRS,
        'xnat:scScanData': SC_SCAN_ATTRS,
    }
    src_type = src_obj.datatype()
    # Fallback keeps the historical behaviour for unlisted datatypes.
    copy_attrs(src_obj, dest_obj, attrs_by_type.get(src_type, MR_SCAN_ATTRS))

def subj_compare(item1, item2):
    """Three-way comparison of two objects by their ``label()``.

    Returns a negative, zero or positive integer when ``item1`` sorts
    before, equal to, or after ``item2``. Written without the ``cmp``
    builtin so it behaves the same on Python 2 and Python 3.
    """
    label1 = item1.label()
    label2 = item2.label()
    return (label1 > label2) - (label1 < label2)

def copy_file(src_f, dest_r, cache_d):
    """Copy one file from a source resource to a destination resource.

    The file is downloaded into ``cache_d`` (unless a file of that name is
    already there), uploaded to ``dest_r`` under the same label, and the
    local copy is removed after a successful upload. Failures are reported
    on stdout and not re-raised, so one bad file does not stop the run.

    Args:
        src_f: source file object; ``label()``, ``get(path)`` and
            ``attributes()`` are used.
        dest_r: destination resource; ``file(label).put(...)`` is used.
        cache_d: existing local directory used as download cache.
    """
    f_label = src_f.label()
    loc_f = cache_d + '/' + f_label

    try:
        # Download file, re-using a copy left in the cache directory.
        # NOTE(review): a partially downloaded file from an earlier failed
        # run would be re-used as-is -- confirm whether that can happen.
        if not os.path.exists(loc_f):
            src_f.get(loc_f)

        # Get file attributes
        f_in_attrs = src_f.attributes()
        f_content = f_in_attrs.get('file_content')
        f_format = f_in_attrs.get('file_format')
        f_tags = f_in_attrs.get('file_tags')

        # Choose the optional upload arguments. Tags are now forwarded when
        # format, content and tags are all present; every other combination
        # behaves as before.
        if f_format and f_content and f_tags:
            put_args = (f_format, f_content, f_tags)
        elif f_format and f_content:
            put_args = (f_format, f_content)
        elif f_format and not f_tags:
            put_args = (f_format,)
        else:
            put_args = ()

        # Upload file
        dest_r.file(f_label).put(loc_f, *put_args)

        # Delete local copy
        os.remove(loc_f)
    except Exception as err:
        # Best effort: report and carry on. KeyboardInterrupt and SystemExit
        # are no longer swallowed, and the cause is included in the message.
        print("ERROR:failed to copy file:%s (%s)" % (f_label, err))

def copy_resource(src_r, dest_r, cache_d):
    """Copy a whole resource as a single zip archive.

    The resource is downloaded as a zip into ``cache_d``, uploaded to
    ``dest_r`` with extraction enabled, and the cached zip is removed.

    Args:
        src_r: source resource; ``get(cache_d, extract=False)`` must return
            the path of the downloaded zip.
        dest_r: destination resource; ``put_zip(path, extract=True)`` is used.
        cache_d: local directory that receives the zip.

    Raises:
        IndexError: re-raised after printing a hint about the XNAT session
            timeout.
    """
    # Bound before the try so the error handler can always format a message,
    # even when the download itself is what raised.
    cache_z = None
    try:
        # Download zip of resource
        print('        Downloading resource:...')
        cache_z = src_r.get(cache_d, extract=False)

        # Upload zip of resource
        print('        Uploading resource:...')
        dest_r.put_zip(cache_z, extract=True)

        # Delete cached zip
        os.remove(cache_z)

    except IndexError:
        # Fall back to the cache directory when no zip path is known yet.
        print("ERROR:failed to copy: %s. Increase XNAT session timeout and try again." % (cache_z or cache_d))
        raise

def is_empty_resource(res):
    """Return True when the resource lists no files at all.

    Only the first entry of ``res.files().fetchall('obj')`` is looked at.
    """
    return not any(True for _ in res.files().fetchall('obj'))
    
def parse_args():
    """Parse the command line.

    Returns:
        Namespace with ``project`` and ``directory`` (positional) and the
        boolean switches ``cf`` and ``ca``.
    """
    import argparse

    positionals = (
        ('project', "XNAT Project Name to look for Subjects"),
        ('directory', "Directory to temporarily hold data during processing"),
    )
    switches = (
        ('-cf', "Check Files of Existing Resources, copy any not found."),
        ('-ca', "Check Attributes of Existing Data, recopy any that don't match"),
    )

    cli = argparse.ArgumentParser()
    for name, text in positionals:
        cli.add_argument(name, help=text)
    for flag, text in switches:
        cli.add_argument(flag, help=text, action='store_true', default=False)
    return cli.parse_args()

# ******************************************************************************



#==============================================================================
# Main
# This file only works as a command-line script: leave immediately (exit
# status 1) when it is imported instead of executed.
if __name__ != '__main__':
    sys.exit(1)

# Command-line options. The destination project carries the same name as
# the source project.
args = parse_args()
SRC_PROJECT = DEST_PROJECT = args.project
CHECK_ATTRS, CHECK_FILES = args.ca, args.cf
SRC_CACHEDIR = '/'.join([args.directory, 'src', SRC_PROJECT])
DEST_CACHEDIR = '/'.join([args.directory, 'dest', DEST_PROJECT])

try:
    # Hosts and credentials of both servers come from the environment. The
    # variables are looked up in the listed order, so the first missing one
    # is the one reported.
    (DEST_XNAT_USER, DEST_XNAT_PASS, DEST_XNAT_HOST,
     SRC_XNAT_HOST, SRC_XNAT_USER, SRC_XNAT_PASS) = map(
        os.environ.__getitem__,
        ('DEST_XNAT_USER', 'DEST_XNAT_PASS', 'DEST_XNAT_HOST',
         'SRC_XNAT_HOST', 'SRC_XNAT_USER', 'SRC_XNAT_PASS'))
except KeyError as missing:
    print("You must set the environment variable %s" % str(missing))
    sys.exit(1)

# Show which servers and projects are involved (passwords are not printed)
print('Src XNAT Host:\t%s' % SRC_XNAT_HOST)
print('Src XNAT Proj:\t%s' % SRC_PROJECT)
print('Dest XNAT Host:\t%s' % DEST_XNAT_HOST)
print('Dest XNAT User:\t%s' % DEST_XNAT_USER)
print('Dest XNAT Proj:\t%s' % DEST_PROJECT)

# Time stamp marking the start of this run
print(' '.join(['\n\nTime: ', str(datetime.now()), '========================================================']))

# One pyxnat Interface per server
src_xnat = Interface(server=SRC_XNAT_HOST, user=SRC_XNAT_USER, password=SRC_XNAT_PASS)
dest_xnat = Interface(server=DEST_XNAT_HOST, user=DEST_XNAT_USER, password=DEST_XNAT_PASS)

# Cache clearing is currently disabled:
#print 'Clearing caches...'
#dest_xnat.cache.clear()
#src_xnat.cache.clear()

# Look up the source project and make sure the destination project exists
print('Getting projects...')
proj = src_xnat.select('/project/' + SRC_PROJECT)
p_16 = dest_xnat.select.project(DEST_PROJECT)
if not p_16.exists():
    p_16.create()
    # Project attributes are currently neither copied nor checked:
    #copy_attributes(proj, p_16)
#elif CHECK_ATTRS: check_attributes(proj, p_16, proj.datatype())

# Process each subject of the project, in label order.
#
# For every subject / MR session / scan the destination object is created
# when missing; subjects and sessions get their attributes copied on
# creation, and existing objects are re-checked when -ca was given. Of each
# scan only the NIFTI/BVEC/BVAL resources are mirrored, file by file.
print('Getting and sorting subject list...')
all_subjects = src_xnat.select('/project/'+SRC_PROJECT+'/subjects/*')
# key= asks each subject for its label once, instead of once per comparison
# as the former cmp= callback did.
sorted_subjects = sorted(all_subjects, key=lambda subj: subj.label())
for subj_i, subject in enumerate(sorted_subjects, 1):
    subject_label = subject.label()

    print("Processing subject %s (%d)..." % (subject_label, subj_i))
    s_16 = p_16.subject(subject_label)
    if not s_16.exists():
        s_16.create()
        copy_attributes(subject, s_16)
    elif CHECK_ATTRS:
        check_attributes(subject, s_16, subject.datatype())

    # Process each experiment of subject
    for experiment in subject.experiments().fetchall('obj'):
        exp_label = experiment.label()

        if experiment.datatype() != 'xnat:mrSessionData':
            print('Skipping, experiment is not MR Session')
            continue

        print("Processing experiment %s:%s..." % (subject_label, exp_label))

        e_16 = s_16.experiment(exp_label)
        if not e_16.exists():
            e_16.create()
            copy_attributes(experiment, e_16)
        elif CHECK_ATTRS:
            check_attributes(experiment, e_16, experiment.datatype())

        # Process each scan of experiment
        for scan in experiment.scans().fetchall('obj'):
            scan_label = scan.label()

            # Scan 99 is never mirrored
            if scan_label == '99':
                print('Skipping:'+scan_label)
                continue

            print("    Processing scan %s:%s:%s..." % (subject_label, exp_label, scan_label))
            scan_16 = e_16.scan(scan_label)

            # NOTE(review): unlike subjects and experiments, a newly created
            # scan does not get its attributes copied here -- confirm that
            # this is intended.
            if not scan_16.exists():
                scan_16.create()
            elif CHECK_ATTRS:
                check_attributes(scan, scan_16, 'xnat:mrScanData')

            # Process each resource of scan
            for res in scan.resources().fetchall('obj'):
                res_label = res.label()
                res_upper = res_label.upper()

                if not any(tag in res_upper for tag in ('NIFTI', 'BVEC', 'BVAL')):
                    print('Skipping:'+res_label)
                    continue

                print("        Processing resource:%s..." % (res_label))

                # The mixed-case 'NIfTI' label is stored as 'NIFTI' on the
                # destination; every other label is kept as it is.
                if res_label == 'NIfTI':
                    r_16 = scan_16.resource('NIFTI')
                else:
                    r_16 = scan_16.resource(res_label)

                # Create cache dir (named after the source label)
                cache_dir = DEST_CACHEDIR+'/%s/%s/%s/%s' % (subject_label, exp_label, scan_label, res_label)
                if not os.path.exists(cache_dir):
                    os.makedirs(cache_dir)

                # Prepare resource and check for empty
                is_empty = False
                if not r_16.exists():
                    r_16.create()
                    is_empty = True
                elif is_empty_resource(r_16):
                    is_empty = True

                if is_empty:
                    # Nothing on the destination yet: copy every file
                    copy_count = 0
                    for f in res.files():
                        print('        Copying file: %s...' % f.label())
                        copy_count += 1
                        copy_file(f, r_16, cache_dir)
                    print('        Finished copying resource, %d files copied' % copy_count)
                elif CHECK_FILES:
                    # Copy files that don't exist already
                    copy_count = 0
                    for f in res.files():
                        f_label = f.label()
                        if not r_16.file(f_label).exists():
                            print('        Copying file: %s...' % f_label)
                            copy_count += 1
                            copy_file(f, r_16, cache_dir)
                    print('        Finished checking resource, %d new files copied' % copy_count)

# Wrap up
print('DONE')
