#!/usr/bin/python3

# Warning
# This script is extremely hacky and special-purpose
# It tries to emulate the old pdf2set code relatively closely
# Error checking is limited.  May eat household pets.

import os
import sys
import numpy
import scipy

from megdata import BTIPDF, BTIConfigFile, BTI_CHANTYPE_MEG, \
                    BTI_CHANTYPE_EEG, BTI_CHANTYPE_REFERENCE, \
                    BTI_CHANTYPE_DERIVED

# Command line: FILENAME (the pdf) is mandatory, OUTPUTFILENAME is optional.
if not sys.argv[1:]:
    print("Usage: %s FILENAME [OUTPUTFILENAME]" % os.path.basename(sys.argv[0]))
    sys.exit(1)

# The config file is expected to sit in the same directory as the pdf.
pdf_fd = os.open(sys.argv[1], os.O_RDONLY)
config_fd = os.open(os.path.join(os.path.dirname(sys.argv[1]), 'config'),
                    os.O_RDONLY)

# None means "write to stdout"; an explicit output file must not already exist
# so that we never clobber anything.
outputfilename = sys.argv[2] if len(sys.argv) > 2 else None
if outputfilename is not None and os.path.exists(outputfilename):
    print("Error: %s already exists" % outputfilename)
    sys.exit(1)

pdf = BTIPDF.from_fd(pdf_fd)
config = BTIConfigFile.from_fd(config_fd)

# The patient/study/session information is not stored in the file contents, so
# we can only work it out from the path; if that fails, we just stick
# placeholder fields in.  Traditionally this was pulled out of the 4D database.
# The last five path components are taken as
# <patient>/<study>/<session>/<num>/<filename>, matching the placeholder names
# used in the fallback below.
try:
    pat, study, session, num, fname = os.path.realpath(sys.argv[1]).split(os.path.sep)[-5:]
    # Map '%' back to '/' and '@' back to ' ' in the session component
    session = session.replace('%', '/').replace('@', ' ')
    description = ';'.join([pat, study, session, num, fname])
except (ValueError, OSError):
    # ValueError: the path has fewer than five components to unpack.
    # OSError: defensive, in case resolving the real path fails.
    description = 'PatientID;StudyID;SessionDate;SessionNum;'
    description += os.path.basename(sys.argv[1])

def _write_chan_group(outfile, prefix, chans, indices):
    """Write the MSI.<prefix>ChanCount/ChanNames/ChanIndex lines; nothing if chans is empty.

    indices are the 1-indexed positions of chans within MSI.ChannelOrder.
    """
    if not chans:
        return
    outfile.write("MSI.%sChanCount: %d\n" % (prefix, len(chans)))
    outfile.write("MSI.%sChanNames: %s\n" % (prefix, ','.join([c.hdr.name for c in chans])))
    outfile.write("MSI.%sChanIndex: %s\n" % (prefix, ','.join([str(i) for i in indices])))


def _write_loop_sections(outfile, prefix, chans):
    """Write the MSI.<prefix>_Position_Information and MSI.<prefix>_Loop_Radius
    sections, one tab-separated line per channel; nothing if chans is empty.
    """
    if not chans:
        return

    # Per channel: its name, then position xyz and orientation xyz of each loop
    outfile.write("MSI.%s_Position_Information.Begin\n" % prefix)
    for c in chans:
        outfile.write("%s" % (c.hdr.name))
        for loop in c.chan.loops:
            outfile.write("\t%.7f\t%.7f\t%.7f" % (loop.position[0, 0], loop.position[0, 1], loop.position[0, 2]))
            outfile.write("\t%.7f\t%.7f\t%.7f" % (loop.orientation[0, 0], loop.orientation[0, 1], loop.orientation[0, 2]))
        outfile.write("\n")
    outfile.write("MSI.%s_Position_Information.End\n" % prefix)

    # This is our own addition: with the default layout we haven't a clue
    # about the loop radii, so write them out explicitly
    outfile.write("MSI.%s_Loop_Radius.Begin\n" % prefix)
    for c in chans:
        outfile.write("%s" % (c.hdr.name))
        for loop in c.chan.loops:
            outfile.write("\t%.7f" % (loop.radius))
        outfile.write("\n")
    outfile.write("MSI.%s_Loop_Radius.End\n" % prefix)


def _nonzero_changes(chan_data):
    """Return (points, values) for every sample where chan_data changes to a non-zero value.

    points are 0-indexed sample positions into chan_data.
    """
    # diff()[i] compares samples i and i+1, so the new value lives at i+1
    pts = numpy.where(numpy.diff(chan_data) != 0)[0] + 1
    vals = chan_data[pts]
    # Only keep changes which don't result in a value of 0
    nonzero = numpy.where(vals != 0)
    return pts[nonzero], vals[nonzero]


# Write the MSI header to outputfilename (or stdout when none was given).
# On any failure, report the error on stderr, remove a partially written
# output file and exit with a non-zero status.
outfile = None
try:
    if outputfilename is None:
        outfile = sys.stdout
    else:
        outfile = open(outputfilename, 'w')

    outfile.write("MSI.FileType: 1\n")
    outfile.write("MSI.FileDescriptor: %s\n" % description)
    outfile.write("MSI.TotalChannels: %d\n" % pdf.hdr.total_chans)
    outfile.write("MSI.TotalEpochs: %d\n" % pdf.hdr.total_epochs)
    outfile.write("MSI.SamplePeriod: %f\n" % pdf.hdr.sample_period)
    outfile.write("MSI.SampleFrequency: %f\n" % (1.0/pdf.hdr.sample_period))
    outfile.write("MSI.FirstLatency: %f\n" % pdf.slice_to_latency(0))
    DATATYPES = {1: 'SHORT', 2: 'LONG', 3: 'FLOAT', 4: 'DOUBLE'}
    outfile.write("MSI.Format: %s\n" % DATATYPES[pdf.hdr.data_format])
    outfile.write("MSI.SlicesPerEpoch: %d\n" % pdf.epochs[0].pts_in_epoch)

    # Build up our channel information
    chan_names = []
    chan_scales = []
    chan_upb = []
    chan_gains = []
    meg_chans = []
    meg_indices = []
    ref_chans = []
    ref_indices = []
    eeg_chans = []
    eeg_indices = []
    deriv_chans = []
    deriv_indices = []
    trigger_idx = None
    response_idx = None

    # For historical reasons, these indices are 1-indexed.
    for pos, chan in enumerate(pdf.channels[0:pdf.hdr.total_chans], start=1):
        # Find the (first) matching channel in the config file
        config_chan = next((candidate for candidate in config.channels
                            if candidate.hdr.chan_no == chan.chan_no), None)
        if config_chan is None:
            raise Exception("Could not find channel number %d" % chan.chan_no)

        chan_names.append(config_chan.hdr.name)
        chan_upb.append(config_chan.hdr.units_per_bit)
        # Hack to be identical with old C++ code: store an int 1 so that it
        # is written as "1" rather than "1.0"
        chan_scales.append(1 if chan.scale == 1.0 else chan.scale)
        chan_gains.append(1 if config_chan.hdr.gain == 1.0 else config_chan.hdr.gain)

        ctype = config_chan.hdr.ctype
        if ctype == BTI_CHANTYPE_MEG:
            meg_chans.append(config_chan)
            meg_indices.append(pos)
        elif ctype == BTI_CHANTYPE_EEG:
            eeg_chans.append(config_chan)
            eeg_indices.append(pos)
        elif ctype == BTI_CHANTYPE_REFERENCE:
            ref_chans.append(config_chan)
            ref_indices.append(pos)
        elif ctype == BTI_CHANTYPE_DERIVED:
            deriv_chans.append(config_chan)
            deriv_indices.append(pos)

        if config_chan.hdr.name == 'TRIGGER':
            trigger_idx = pos
        if config_chan.hdr.name == 'RESPONSE':
            response_idx = pos

    outfile.write("MSI.ChannelOrder: %s\n" % (','.join(chan_names)))
    outfile.write("MSI.ChannelScale: \t%s\n" % ('\t'.join([str(s) for s in chan_scales])))
    outfile.write("MSI.ChannelGain: \t%s\n" % ('\t'.join([str(s) for s in chan_gains])))
    outfile.write("MSI.ChannelUnitsPerBit: \t%s\n" % ('\t'.join(['%.6e' % s for s in chan_upb])))

    _write_chan_group(outfile, 'Meg', meg_chans, meg_indices)
    _write_chan_group(outfile, 'Eeg', eeg_chans, eeg_indices)
    _write_chan_group(outfile, 'Ref', ref_chans, ref_indices)
    _write_chan_group(outfile, 'Der', deriv_chans, deriv_indices)

    if trigger_idx is not None:
        outfile.write("MSI.TriggerIndex: %d\n" % trigger_idx)

    if response_idx is not None:
        outfile.write("MSI.ResponseIndex: %d\n" % response_idx)

    _write_loop_sections(outfile, 'Meg', meg_chans)
    # The Ref sections are our own addition too, as by default we know
    # nothing about the reference channels
    _write_loop_sections(outfile, 'Ref', ref_chans)

    # Raw trigger channel samples, loaded once and reused for the epoch event
    # codes and the trigger/group events.  None when there is no TRIGGER
    # channel (previously this case crashed with a NameError).
    trig_chan = None
    if trigger_idx is not None:
        # Our indices are 1-indexed; convert back to 0-indexed for reading
        trig_chan = pdf.read_raw_data(None, trigger_idx - 1)

    # Write the event codes: one event per epoch
    event_times = []
    event_codes = []

    outfile.write("MSI.TotalEvents: %d\n" % len(pdf.epochs))
    if len(pdf.epochs) > 0:
        trg_slice = pdf.latency_to_slice(0)
        # Assumes all epochs are the same length; fairly sure this is a 4D limitation anyways
        epoch_len = pdf.epochs[0].pts_in_epoch
        for e in range(len(pdf.epochs)):
            sl = int((e * epoch_len) + trg_slice)
            # Slices are written 1-indexed (historical reasons)
            event_times.append(sl + 1)
            # For historical reasons, event codes are only written when the
            # high byte is set, and then only the low byte.  We stick with
            # this to be compatible as we never use the event field anyways,
            # only the trigger and group ones.  Without a trigger channel
            # there is nothing to read, so the code is 0.
            ev_val = int(trig_chan[sl]) if trig_chan is not None else 0
            code = (ev_val & 0x00ff) if (ev_val & 0xff00) else 0
            event_codes.append(code)

        outfile.write("MSI.Events: " + ','.join([str(x) for x in event_times]) + "\n")
        outfile.write("MSI.EventCodes: " + ','.join([str(x) for x in event_codes]) + "\n")

    if trig_chan is not None:
        pts, trigs = _nonzero_changes(trig_chan)
        # Divide into group codes (< 256) and trigger codes (>= 256)
        grp_sel = numpy.where(trigs < 256)
        trig_sel = numpy.where(trigs >= 256)
        # Trigger points are written 1-indexed for historical reasons
        grp_pts = pts[grp_sel] + 1
        grp_codes = trigs[grp_sel]
        trig_pts = pts[trig_sel] + 1
        trig_codes = trigs[trig_sel]

        outfile.write("MSI.TrigEventCount: %d\n" % len(trig_pts))
        if len(trig_pts) > 0:
            outfile.write('MSI.TrigEvents: ' + ','.join([str(x) for x in trig_pts]) + '\n')
            outfile.write('MSI.TrigEventCodes: ' + ','.join([str(x) for x in trig_codes]) + '\n')

        outfile.write("MSI.GroupEventCount: %d\n" % len(grp_pts))
        if len(grp_pts) > 0:
            outfile.write('MSI.GroupEvents: ' + ','.join([str(x) for x in grp_pts]) + '\n')
            outfile.write('MSI.GroupEventCodes: ' + ','.join([str(x) for x in grp_codes]) + '\n')
    else:
        outfile.write("MSI.TrigEventCount: 0\n")
        outfile.write("MSI.GroupEventCount: 0\n")

    if response_idx is not None:
        # Find response events - our indices are 1-indexed
        resp_chan = pdf.read_raw_data(None, response_idx - 1)
        resp_pts, resp_codes = _nonzero_changes(resp_chan)
        # Response points are written 1-indexed for historical reasons
        resp_pts = resp_pts + 1

        outfile.write("MSI.RespEventCount: %d\n" % len(resp_pts))
        if len(resp_pts) > 0:
            outfile.write('MSI.RespEvents: ' + ','.join([str(x) for x in resp_pts]) + '\n')
            outfile.write('MSI.RespEventCodes: ' + ','.join([str(x) for x in resp_codes]) + '\n')
    else:
        outfile.write("MSI.RespEventCount: 0\n")

    # Never close the process's stdout; only close a file we opened
    if outfile is not sys.stdout:
        outfile.close()

except Exception as e:
    # Report on stderr so the message is not mixed into output on stdout
    print("Error: %s" % e, file=sys.stderr)

    # Delete the output file if we didn't manage to fully create it.  Only a
    # file we actually opened is removed: outputfilename was checked not to
    # exist at start-up, so anything there now was created by us.
    if outfile is not None and outfile is not sys.stdout:
        try:
            outfile.close()
        finally:
            os.unlink(outputfilename)

    # Signal failure to the caller instead of exiting with status 0
    sys.exit(1)
