#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2013-2016
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System import
from __future__ import print_function
import os
import json
from pprint import pprint
import argparse
from datetime import datetime
import shutil
import textwrap
from argparse import RawTextHelpFormatter

# Bredala import
try:
    # Optional instrumentation: when bredala is importable, register the
    # three HCP step functions with it (profiler disabled).
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyfreesurfer.hcp",
                     names=["prefreesurfer_hcp", "freesurfer_hcp",
                            "postfreesurfer_hcp"])
    # bredala.register("pyfreesurfer.wrapper",
    #                  names=["HCPWrapper.__init__"])
except Exception:
    # Best effort only: the script must run without bredala. 'Exception' is
    # caught instead of using a bare 'except' so that SystemExit and
    # KeyboardInterrupt are not silently swallowed.
    pass

# Pyfreesurfer import
from pyfreesurfer import __version__ as version
from pyfreesurfer.wrapper import HCPWrapper
from pyfreesurfer import DEFAULT_FREESURFER_PATH
from pyfreesurfer import DEFAULT_FSL_PATH
from pyfreesurfer import DEFAULT_WORKBENCH_PATH
from pyfreesurfer.hcp import prefreesurfer_hcp
from pyfreesurfer.hcp import freesurfer_hcp
from pyfreesurfer.hcp import postfreesurfer_hcp


# Parameters to keep trace
__hopla__ = ["runtime", "inputs", "outputs"]


# Script documentation
DOC = """
Run the HCP prefreesurfer, freesurfer and postfreesurfer scripts.

Requirements for this module:
  installed versions of:
    - FSL (version 5.0.6),
    - FreeSurfer (version 5.3.0-HCP),
    - gradunwarp (HCP version 1.0.2) if doing gradient distortion correction

  environment:
    - FSLDIR
    - FREESURFER_HOME
    - HCPPIPEDIR
    - CARET7DIR
    - PATH (to be able to find gradient_unwarp.py)

The primary purposes of the PreFreeSurfer Pipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1. To average any image repeats (i.e. multiple T1w or T2w images available)
2. To create a native, undistorted structural volume space for the subject
        * Subject images in this native space will be distortion corrected
          for gradient and b0 distortions and rigidly aligned to the axes
          of the MNI space. "Native, undistorted structural volume space"
          is sometimes shortened to the "subject's native space" or simply
          "native space".
3. To provide an initial robust brain extraction
4. To align the T1w and T2w structural images (register them to the native
   space)
5. To perform bias field correction
6. To register the subject's native space to the MNI space

The primary purposes of the FreeSurfer Pipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1. Make Spline Interpolated Downsample to 1mm
2. Initial recon-all steps (with flags that are part of "-autorecon1", with the
   exception of -skullstrip)
3. Generate brain mask
4. Call recon-all to run most of the "-autorecon2" stages, but turning off
   smooth2, inflate2, curvstats, and segstats stages.
5. High resolution white matter and fine tune T2w to T1w registration.
6. Intermediate Recon-all Steps
7. High resolution pial matter (adjusts the pial surface based on the T2w
   image)
8. Final recon-all steps

The primary purposes of the PostFreeSurfer Pipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1. Conversion of FreeSurfer Volumes and Surfaces to NIFTI and GIFTI and Create
   Caret Files and Registration
2. Create FreeSurfer ribbon file at full resolution
3. Myelin Mapping


Command
~~~~~~~

python $HOME/git/pyfreesurfer/pyfreesurfer/scripts/pyfreesurfer_hcp \\
    --path /volatile/nsap/hcp \\
    --subject 994273 \\
    --t1 /tmp/T1w_MPR1/994273_3T_T1w_MPR1.nii.gz \\
    --t2 /tmp/T2w_SPC1/994273_3T_T2w_SPC1.nii.gz \\
    --fmapmag /tmp/T1w_MPR1/994273_3T_FieldMap_Magnitude.nii.gz \\
    --fmapphase /tmp/T1w_MPR1/994273_3T_FieldMap_Phase.nii.gz \\
    --gdcoeffs /tmp/coeff_SC72C_Skyra.grad \\
    --hcpdir /volatile/git/Pipelines \\
    --wbcommand /usr/bin \\
    --verbose 1
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path given on the command line, or the special value 'NONE'
        which is accepted as is (no file system check).

    Returns
    -------
    filearg: str
        the input value, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the value is not 'NONE' and does not point to an existing file.
    """
    # ArgumentTypeError is the exception argparse expects from a 'type'
    # callable: its message is reported to the user as is.
    if filearg != "NONE" and not os.path.isfile(filearg):
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path given on the command line.

    Returns
    -------
    dirarg: str
        the input value, unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        if the value does not point to an existing directory.
    """
    # ArgumentTypeError is the exception argparse expects from a 'type'
    # callable: its message is reported to the user as is.
    if not os.path.isdir(dirarg):
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """
    Create a command line argument parser and parse the command line.

    When not given on the command line, the FreeSurfer configuration, the
    FSL configuration and the Workbench location fall back to the
    pyfreesurfer defaults (DEFAULT_FREESURFER_PATH, DEFAULT_FSL_PATH and
    DEFAULT_WORKBENCH_PATH).

    The parser exits with a usage error when '--nopre' is set without
    '--t1wdir', or when '--t1'/'--t2' are missing while the PreFreeSurfer
    step has to be run.

    Returns
    -------
    args: argparse.Namespace
        the parsed arguments, accessible as attributes
        (<argument name> -> <argument value>).
    """
    parser = argparse.ArgumentParser(
        prog="python pyfreesurfer_hcp",
        description=textwrap.dedent(DOC),
        formatter_class=RawTextHelpFormatter)

    # Required arguments
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "--path",
        required=True, type=is_directory,
        help="path to study data folder (~ to FreeSurfer home directory). "
             "Used with --subject input to create full path to root directory "
             "for all outputs generated as /path/subject.")
    required.add_argument(
        "--subject",
        required=True,
        help="subject ID. Used with --path input to create full path to root "
             "directory for all outputs generated as path/subject.")
    # --t1/--t2 are not flagged 'required=True' because they are not needed
    # with --nopre: the check is done after the parsing.
    required.add_argument(
        "--t1",
        type=is_file, nargs="+",
        help="list of full paths to T1-weighted structural images for the "
             "subject.")
    required.add_argument(
        "--t2",
        type=is_file, nargs="+",
        help="list of full paths to T2-weighted structural images for the "
             "subject.")
    required.add_argument(
        "--hcpdir",
        required=True,
        help="the path to the HCP project containing the HCP scripts.")

    # Optional arguments
    parser.add_argument(
        "--field",
        choices=("3T", "7T"), default="3T",
        help="magnetic field strength of the acquisition. Only '3T' is "
             "supported for the moment.")
    parser.add_argument(
        "--brainsize",
        type=int, default=150,
        help="brain size estimate in mm, 150 for humans.")
    parser.add_argument(
        "--fmapmag",
        type=is_file, default="NONE",
        help="Siemens Gradient Echo Fieldmap magnitude file.")
    parser.add_argument(
        "--fmapphase",
        type=is_file, default="NONE",
        help="Siemens Gradient Echo Fieldmap phase file.")
    parser.add_argument(
        "--fmapgeneralelectric",
        default="NONE",
        help="General Electric Gradient Echo Field Map file. "
             "Two volumes in one file: 1. field map in deg, 2. magnitude.")
    parser.add_argument(
        "--echodiff",
        default=2.46,
        help="delta TE in ms for field map or 'NONE' if not used.")
    parser.add_argument(
        "--SEPhaseNeg",
        default="NONE",
        help="for spin echo field map, path to volume with a negative phase "
             "encoding direction (LR in HCP data), set to 'NONE' if not using "
             "Spin Echo Field Maps.")
    parser.add_argument(
        "--SEPhasePos",
        default="NONE",
        help="for spin echo field map, path to volume with a positive phase "
             "encoding direction (RL in HCP data), set to 'NONE' if not using "
             "Spin Echo Field Maps.")
    parser.add_argument(
        "--echospacing", dest="echospacing", default="NONE",
        help="echo Spacing or Dwelltime of Spin Echo Field Map or 'NONE' if "
             "not used.")
    parser.add_argument(
        "--seunwarpdir",
        default="NONE",
        help="phase encoding direction of the spin echo field map. (Only "
             "applies when using a spin echo field map.).")
    parser.add_argument(
        "--t1samplespacing", default=0.0000074,
        help="T1 image sample spacing, 'NONE' if not used.")
    parser.add_argument(
        "--t2samplespacing", default=0.0000021,
        help="T2 image sample spacing, 'NONE' if not used.")
    parser.add_argument(
        "--unwarpdir",
        choices=("z", "y", "x"), default="z",
        help="readout direction of the T1w and T2w images (Used with either a "
             "gradient echo field map or a spin echo field map).")
    parser.add_argument(
        "--gdcoeffs",
        default="NONE", type=is_file,
        help="file containing gradient distortion coefficients. "
             "Set to 'NONE' to turn off.")
    parser.add_argument(
        "--avgrdcmethod", default="SiemensFieldMap",
        help="averaging and readout distortion correction method or 'NONE' if "
             "not used.")
    parser.add_argument(
        "--topupconfig",
        default="NONE",
        help="configuration file for topup or 'NONE' if not used.")
    parser.add_argument(
        "--fsldir",
        type=is_file,
        help="the path to the FSL 'fsl.sh' configuration file.")
    parser.add_argument(
        "--fsconfig",
        type=is_file,
        help="the path to the FreeSurfer configuration file.")
    parser.add_argument(
        "--wbcommand",
        type=is_directory,
        help="the path containing the wbcommand.")
    parser.add_argument(
        "--verbose",
        type=int, choices=[0, 1, 2], default=0,
        help="increase the verbosity level: 0 silent, [1, 2] verbose.")
    parser.add_argument(
        "--erase",
        action="store_true",
        help="if activated, clean the result folder.")
    parser.add_argument(
        "--nopre",
        action="store_true",
        help="if activated, do not perform the prefreesurfer step.")
    parser.add_argument(
        "--t1wdir",
        type=is_directory,
        help="if nopre, the prefreesurfer T1w folder.")

    # Parse the command line and fill the unset tool locations with the
    # package defaults
    args = parser.parse_args()
    if args.fsconfig is None:
        args.fsconfig = DEFAULT_FREESURFER_PATH
    if args.fsldir is None:
        args.fsldir = DEFAULT_FSL_PATH
    if args.wbcommand is None:
        args.wbcommand = DEFAULT_WORKBENCH_PATH

    # Reject inconsistent combinations now, with a usage message, instead of
    # failing later on a 'None' path.
    if args.nopre:
        if args.t1wdir is None:
            parser.error("--t1wdir is required when --nopre is set.")
    elif args.t1 is None or args.t2 is None:
        parser.error("--t1 and --t2 are required unless --nopre is set.")

    return args


"""
Parse the command line.
"""
args = get_cmd_line_args()
tool = "pyfreesurfer_hcp"
# Any --field value other than '3T' aborts the script here.
if args.field != "3T":
    raise ValueError("{0} field not yet supported.".format(args.field))
timestamp = datetime.now().isoformat()
tool_version = version
freesurfer_config = args.fsconfig
fsl_config = args.fsldir
wbcommand = args.wbcommand
# The wrapper is only used below to query the versions of the external tools.
wrapper = HCPWrapper(
    env={"PATH": os.environ["PATH"],
         "HCPPIPEDIR": args.hcpdir,
         "CARET7DIR": wbcommand},
    fslconfig=fsl_config,
    fsconfig=freesurfer_config)
freesurfer_version = wrapper.freesurfer_version()
fsl_version = wrapper.fsl_version()
gradunwarp_version = wrapper.gradunwarp_version()
wbcommand_version = wrapper.wbcommand_version()
hcp_version = wrapper.hcp_version()
# At module level locals() is the module namespace: the 'runtime' record is
# built by picking the variables defined above by name.
params = locals()
runtime = dict([(name, params[name])
               for name in ("freesurfer_config", "fsl_config", "tool",
                            "tool_version", "freesurfer_version",
                            "fsl_version", "gradunwarp_version",
                            "wbcommand_version", "hcp_version", "timestamp")])
if args.verbose > 0:
    # Displayed in place of a version that evaluates to false.
    version_error = "<NOT INSTALLED>"
    print("[info] Start FreeSurfer HCP recon_all...")
    print("[info] Directory: {0}.".format(args.path))
    print("[info] Subject: {0}.".format(args.subject))
    print("[info] T1w: {0}.".format(args.t1))
    print("[info] T2w: {0}.".format(args.t2))
    print("[info] FSL version: {0}.".format(fsl_version or version_error))
    print("[info] FreeSurfer version: {0}.".format(
        freesurfer_version or version_error))
    print("[info] GradUnWarp version: {0}.".format(
        gradunwarp_version or version_error))
    print("[info] WBCommand version: {0}.".format(
        wbcommand_version or version_error))
    print("[info] HCP version: {0}.".format(hcp_version or version_error))
workdir = args.path
subject = args.subject
# Root directory of all the outputs of this subject: <path>/<subject>.
subjectdir = os.path.join(workdir, subject)
t1files = args.t1
t2files = args.t2
magfile = args.fmapmag
phasefile = args.fmapphase
nopre = args.nopre
t1wdir = args.t1wdir
# Same name-based selection as for 'runtime', this time for the 'inputs'
# record (the field map files are not part of it).
params = locals()
inputs = dict([(name, params[name])
               for name in ("subject", "t1files", "subjectdir", "t2files",
                            "workdir", "nopre", "t1wdir")])
# Filled at the end of the script, once all the steps have been run.
outputs = None
# With --erase an existing subject directory is removed, then an empty one
# is created if it does not exist.
if args.erase and os.path.isdir(subjectdir):
    shutil.rmtree(subjectdir)
if not os.path.isdir(subjectdir):
    os.mkdir(subjectdir)


"""
Perform or not the PreFreeSurfer pipeline.
"""
if nopre:

    """
    Copy the input T1w/MNINonLinear PreFreeSurfer output directories.
    """
    # The PreFreeSurfer outputs are taken from 't1wdir' and from its sibling
    # 'MNINonLinear' folder: fail with an explicit message when it is not
    # set, rather than with a TypeError in the path manipulations below.
    if t1wdir is None:
        raise ValueError(
            "The '--t1wdir' option is required when '--nopre' is set.")
    t1w_folder = os.path.join(subjectdir, "T1w")
    if not os.path.isdir(t1w_folder):
        shutil.copytree(t1wdir, t1w_folder)
    mninonlinear_folder = os.path.join(subjectdir, "MNINonLinear")
    mninonlineardir = os.path.join(os.path.dirname(t1wdir), "MNINonLinear")
    if not os.path.isdir(mninonlinear_folder):
        shutil.copytree(mninonlineardir, mninonlinear_folder)

    """
    Restore the PreFreeSurfer pipeline outputs.
    """
    t1_img = os.path.join(t1w_folder, "T1w_acpc_dc_restore.nii.gz")
    t1_img_brain = os.path.join(t1w_folder, "T1w_acpc_dc_restore_brain.nii.gz")
    t2_img = os.path.join(t1w_folder, "T2w_acpc_dc_restore.nii.gz")

    """
    Remove FreeSurfer folder if created.
    """
    fssubjdir = os.path.join(t1w_folder, subject)
    if os.path.isdir(fssubjdir):
        shutil.rmtree(fssubjdir)

else:

    """
    Copy/rename the input images in the HCP working directory.
    """
    # Without the structural images the copy below cannot work: report the
    # missing options explicitly.
    if t1files is None or t2files is None:
        raise ValueError(
            "The '--t1' and '--t2' options are required unless '--nopre' "
            "is set.")
    t1_basename = "{0}_{1}_T1w_MPR{2}"
    t2_basename = "{0}_{1}_T2w_SPC{2}"
    mag_basename = "{0}_{1}_FieldMap_Magnitude"
    # NOTE(review): unlike the magnitude pattern, the phase pattern carries
    # a 'T1w_' infix - confirm that this naming is intended.
    phase_basename = "{0}_{1}_T1w_FieldMap_Phase"
    for files, basename in [(t1files, t1_basename), (t2files, t2_basename),
                            (magfile, mag_basename), (phasefile, phase_basename)]:
        # The field map options hold a single path, not a list
        if not isinstance(files, list):
            files = [files]
        for cnt, path in enumerate(files):
            if path == "NONE":
                continue
            # Keep everything after the first dot so that compound
            # extensions such as '.nii.gz' are preserved. A name without
            # any dot has no extension: do not append a dangling '.'.
            extensions = os.path.basename(path).split(".")[1:]
            extension = ("." + ".".join(extensions)) if extensions else ""
            # The image index is only used by the T1w/T2w patterns
            out_file = os.path.join(
                subjectdir,
                basename.format(subject, args.field, cnt + 1) + extension)
            shutil.copy2(path, out_file)

    """
    PreFreeSurfer Pipeline.
    """

    t1w_folder, t1_img, t1_img_brain, t2_img = prefreesurfer_hcp(
        path=args.path,
        subject=args.subject,
        t1=args.t1,
        t2=args.t2,
        fmapmag=args.fmapmag,
        fmapphase=args.fmapphase,
        hcpdir=args.hcpdir,
        brainsize=args.brainsize,
        fmapgeneralelectric=args.fmapgeneralelectric,
        echodiff=args.echodiff,
        SEPhaseNeg=args.SEPhaseNeg,
        SEPhasePos=args.SEPhasePos,
        echospacing=args.echospacing,
        seunwarpdir=args.seunwarpdir,
        t1samplespacing=args.t1samplespacing,
        t2samplespacing=args.t2samplespacing,
        unwarpdir=args.unwarpdir,
        gdcoeffs=args.gdcoeffs,
        avgrdcmethod=args.avgrdcmethod,
        topupconfig=args.topupconfig,
        wbcommand=wbcommand,
        fslconfig=fsl_config,
        fsconfig=freesurfer_config)


"""
FreeSurfer Pipeline
"""

# Run on the T1w folder and the restored images that come either from the
# PreFreeSurfer step or from the copied '--t1wdir' outputs.
freesurfer_hcp(
    subject=subject,
    t1w_folder=t1w_folder,
    t1_img=t1_img,
    t1_img_brain=t1_img_brain,
    t2_img=t2_img,
    hcpdir=args.hcpdir,
    wbcommand=wbcommand,
    fslconfig=fsl_config,
    fsconfig=freesurfer_config)


"""
PostFreeSurfer Pipeline
"""

# 'workdir', 'subject' and 'wbcommand' hold the values of the '--path',
# '--subject' and '--wbcommand' options.
postfreesurfer_hcp(
    path=workdir,
    subject=subject,
    hcpdir=args.hcpdir,
    wbcommand=wbcommand,
    fslconfig=fsl_config,
    fsconfig=freesurfer_config)


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""

logdir = os.path.join(subjectdir, "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)
params = locals()
# The only recorded output is the T1w folder
outputs = {"t1w_folder": params["t1w_folder"]}
# One JSON file per record: inputs.json, outputs.json and runtime.json
records = (("inputs", inputs), ("outputs", outputs), ("runtime", runtime))
for record_name, record in records:
    record_file = os.path.join(logdir, "{0}.json".format(record_name))
    with open(record_file, "wt") as stream:
        json.dump(record, stream, sort_keys=True, check_circular=True,
                  indent=4)
if args.verbose > 1:
    print("[final]")
    pprint(outputs)
