#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2013 - 2016
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System import
from __future__ import print_function
import os
import glob
import shutil
import argparse
from datetime import datetime
import json
from pprint import pprint
import textwrap
from argparse import RawTextHelpFormatter

# Bredala import
try:
    import bredala
    # Switch off the 'USE_PROFILER' flag of bredala for this script.
    bredala.USE_PROFILER = False
except ImportError:
    # bredala is optional: silently continue when it is not installed, but
    # do not swallow unrelated errors (KeyboardInterrupt, SystemExit, ...).
    pass

# Pyfreesurfer import
from pyfreesurfer import __version__ as version
from pyfreesurfer.wrapper import FSWrapper
from pyfreesurfer import DEFAULT_FREESURFER_PATH

# Third party import
import matplotlib.pyplot as plt


# Parameters to keep trace: names of the module-level 'runtime', 'inputs'
# and 'outputs' dictionaries that are built below and dumped as JSON files
# in the 'logs' directory at the end of the script.
# NOTE(review): '__hopla__' is presumably read by the hopla tool - confirm.
__hopla__ = ["runtime", "inputs", "outputs"]


# Script documentation
DOC = """
Freesurfer Euler
~~~~~~~~~~~~~~~~

Compute the Euler number in order to automatically QC the dataset.

From: Neuroimage. 2018 Apr 1;169:407-418. doi: 10.1016/j.neuroimage.2017.12.059.
Epub 2017 Dec 24. Quantitative assessment of structural image quality.
Rosen AFG, et al.

Command:

python $HOME/git/pyfreesurfer/pyfreesurfer/scripts/pyfreesurfer_euler \
    -v 2 \
    -c /i2bm/local/freesurfer/SetUpFreeSurfer.sh \
    -d /neurospin/senior/nsap/data/V4/freesurfer \
    -o /neurospin/senior/nsap/data/V4/qc/freesurfer
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged, when it points to an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing file.
    """
    if not os.path.isfile(filearg):
        # 'ArgumentTypeError' is the exception argparse expects from a
        # 'type' callable: its message is reported in the usage error.
        # 'ArgumentError' needs an (argument, message) pair and cannot be
        # built from a single message.
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged, when it points to an existing directory.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path is not an existing directory.
    """
    if not os.path.isdir(dirarg):
        # 'ArgumentTypeError' is the exception argparse expects from a
        # 'type' callable: its message is reported in the usage error.
        # 'ArgumentError' needs an (argument, message) pair and cannot be
        # built from a single message.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """ Create the command line argument parser and parse the command line.

    Returns
    -------
    args: argparse.Namespace
        the parsed arguments with the 'fsdir', 'outdir', 'verbose' and
        'fsconfig' attributes. When no configuration file is given on the
        command line, 'fsconfig' is set to DEFAULT_FREESURFER_PATH.
    """
    parser = argparse.ArgumentParser(
        # Name displayed in the usage/help messages: must match this script.
        prog="python pyfreesurfer_euler",
        description=textwrap.dedent(DOC),
        formatter_class=RawTextHelpFormatter)

    # Required arguments
    required = parser.add_argument_group("required arguments")
    required.add_argument(
        "-d", "--fsdir",
        required=True, metavar="PATH", type=is_directory,
        help="the FreeSurfer processing home directory.")
    required.add_argument(
        "-o", "--outdir",
        metavar="PATH", required=True, type=is_directory,
        help="the euler number output directory.")

    # Optional arguments
    parser.add_argument(
        "-v", "--verbose",
        type=int, choices=[0, 1, 2], default=0,
        help="increase the verbosity level: 0 silent, [1, 2] verbose.")
    parser.add_argument(
        "-c", "--config", dest="fsconfig",
        metavar="FILE", type=is_file,
        help="the FreeSurfer configuration file.")

    # Parse the command line and fall back to the default FreeSurfer
    # configuration when none is specified
    args = parser.parse_args()
    if args.fsconfig is None:
        args.fsconfig = DEFAULT_FREESURFER_PATH

    return args


"""
Parse the command line.
"""
args = get_cmd_line_args()

# Runtime parameters: describe the tool and the FreeSurfer installation used
# for this run.
tool = "pyfreesurfer_euler"
timestamp = datetime.now().isoformat()
tool_version = version
freesurfer_config = args.fsconfig
freesurfer_version = FSWrapper([], freesurfer_config).version
runtime = {
    "freesurfer_config": freesurfer_config,
    "tool": tool,
    "tool_version": tool_version,
    "freesurfer_version": freesurfer_version,
    "timestamp": timestamp,
}

# Input parameters: the directories received from the command line.
fsdir = args.fsdir
outdir = args.outdir
inputs = {"fsdir": fsdir, "outdir": outdir}

# Output parameters: filled all along the script.
outputs = {}

if args.verbose > 0:
    print("[info] Start FreeSurfer Euler number...")
    print("[info] Directory: {0}.".format(args.fsdir))


"""
Compute the Euler number for all the subjects
"""
# Euler numbers indexed by subject identifier, then by hemisphere:
# {<sid>: {"lh": <int>, "rh": <int>}}
dataset = {}
for hemi in ("lh", "rh"):
    # Select the '<fsdir>/<sid>/surf/<hemi>.orig.nofix' surfaces. The paths
    # are sorted in order to get a reproducible processing order.
    surfs = sorted(glob.glob(os.path.join(
        fsdir, "*", "surf", "{0}.orig.nofix".format(hemi))))
    for path in surfs:
        # The subject identifier is the name of the folder containing 'surf'
        sid = path.split(os.sep)[-3]
        # Check for duplicates before running the command
        if hemi in dataset.get(sid, {}):
            raise ValueError("Subject appears multiple times: {0}.".format(
                path))
        output_file = os.path.join(outdir, "{0}_{1}euler".format(sid, hemi))
        outputs.setdefault("subject_euler_files", []).append(output_file)

        # Run the FreeSurfer command and keep a copy of its report, which
        # is read from the standard error stream of the process
        cmd = ["mris_euler_number", path]
        process = FSWrapper(cmd, shfile=freesurfer_config)
        process()
        reporting = process.stderr
        with open(output_file, "wt") as open_file:
            open_file.write(reporting)

        # The Euler number is read on the first line of the report, between
        # the last '=' and the '-->' marker
        try:
            first_line = reporting.split("\n")[0]
            euler_number = int(
                first_line.rsplit("=", 1)[1].split("-->")[0])
        except (IndexError, ValueError):
            raise ValueError(
                "Unable to parse the Euler number of '{0}' from the "
                "'mris_euler_number' report: {1!r}.".format(path, reporting))
        dataset.setdefault(sid, {})[hemi] = euler_number
# Save a CSV summary with one row per subject: the Euler number of each
# hemisphere and their average.
euler_summary = os.path.join(outdir, "euler_summary.csv")
outputs["euler_summary_file"] = euler_summary
# Legacy misspelled key, kept so that existing readers of the logs still
# find the summary file.
outputs["euler_summray_file"] = euler_summary
# NOTE(review): the first column was previously misspelled
# 'paraticipant_id' - update the readers of this file if any.
header = ["participant_id", "lh_euler_number", "rh_euler_number",
          "average_euler_number"]
rows = [",".join(header)]
# Subjects are sorted in order to get a reproducible file content
for sid, struct in sorted(dataset.items()):
    # Both hemispheres are required to compute the average: fail with an
    # explicit message before writing anything on disk
    missing = [hemi for hemi in ("lh", "rh") if hemi not in struct]
    if missing:
        raise ValueError(
            "No Euler number for hemisphere(s) {0} of subject "
            "'{1}'.".format(", ".join(missing), sid))
    rows.append(",".join([
        sid, str(struct["lh"]), str(struct["rh"]),
        str((struct["lh"] + struct["rh"]) / 2.)]))
with open(euler_summary, "wt") as open_file:
    open_file.write("\n".join(rows))
# Keep the subjects whose average Euler number (mean of the two
# hemispheres) is strictly above the threshold.
# NOTE(review): the -217 cut-off is presumably taken from the Rosen et al.
# paper cited in the script documentation - confirm.
default_threshold = -217
all_subjects = list(dataset)
all_scores = [
    (dataset[subject]["lh"] + dataset[subject]["rh"]) / 2.
    for subject in all_subjects]
selection = [
    (subject, average)
    for subject, average in zip(all_subjects, all_scores)
    if average > default_threshold]
good_data = [subject for subject, _ in selection]
good_data_scores = [average for _, average in selection]
# Subjects rejected by the threshold
extra_subjects = list(set(all_subjects).difference(good_data))
outputs["thres"] = default_threshold
outputs["valid_subjects"] = good_data
# Save side-by-side histograms (50 bins) of the average Euler numbers:
# all the subjects on the left (red) and the subjects retained by the
# threshold on the right (green). The histograms are not normalized, hence
# the y-axis is a number of subjects.
plt.subplot(121)
plt.xlabel("Average Euler number")
plt.ylabel("Number of subjects")
plt.title("Euler Numbers: all scores")
plt.hist(all_scores, 50, facecolor='r', alpha=0.75)
plt.grid(True)
plt.subplot(122)
plt.xlabel("Average Euler number")
plt.ylabel("Number of subjects")
# The retained scores are strictly above the threshold
plt.title("Euler Numbers: above default threshold")
plt.hist(good_data_scores, 50, facecolor='g', alpha=0.75)
plt.grid(True)
euler_hist_file = os.path.join(outdir, "euler_hist.png")
plt.savefig(euler_hist_file, format="png")
plt.close()
outputs["euler_hist_file"] = euler_hist_file


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
# Create the 'logs' directory if necessary
logdir = os.path.join(outdir, "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)

# Dump each traced dictionary in its own '<name>.json' file
traces = (
    ("inputs", inputs),
    ("outputs", outputs),
    ("runtime", runtime),
)
for label, content in traces:
    json_path = os.path.join(logdir, label + ".json")
    with open(json_path, "wt") as stream:
        json.dump(
            content, stream, sort_keys=True, check_circular=True, indent=4)

# Display the outputs at the highest verbosity level only
if args.verbose > 1:
    print("[final]")
    pprint(outputs)
