#!/usr/bin/env python

#    
#    Copyright (C) 2017  Tristram Lett

#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.

#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.

#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
import os
import sys
import numpy as np
import argparse as ap
import nibabel as nib
import matplotlib.colors as colors

from tfce_mediation.pyfunc import write_colorbar

from tmi_viewer.tv_func import check_byteorder, apply_affine_to_scalar_field, nonempty_coordinate_range, display_matplotlib_luts, get_cmap_array, correct_image, concate_images, draw_outline, autothreshold, make_slice_html, write_dict, add_text_to_img
from tmi_viewer.version import __version__

# Load mayavi with a Qt binding fallback: honour a user-supplied QT_API,
# otherwise try 'pyqt' first and fall back to 'pyside' if the import fails.
if 'QT_API' not in os.environ:
	os.environ['QT_API'] = 'pyqt'
try:
	from mayavi import mlab
except Exception:
	# Catch Exception rather than using a bare except so that
	# KeyboardInterrupt / SystemExit are not swallowed by the fallback.
	print("Trying pyside")
	os.environ['QT_API'] = 'pyside'
	from mayavi import mlab


# Text passed as the ``description`` of the default ArgumentParser built for getArgumentParser().
DESCRIPTION = "Outputs a webpage with whole brain slices from voxel based neuroimages in native coordinates with optional overlaps."

def getArgumentParser(ap = None):
	"""
	Build the command line parser for this script.

	Parameters
	----------
	ap : argparse.ArgumentParser or None
		Parser to populate. When None (the default) a new parser is created
		with DESCRIPTION as its description. A fresh parser is built on every
		call so that calling this function repeatedly does not try to register
		the same options twice on a shared default instance.

	Returns
	-------
	ap : argparse.ArgumentParser
		The parser with every option of this script registered. Exactly one of
		-iv / --plotluts is required (mutually exclusive group).
	"""
	if ap is None:
		# The parameter name shadows the module-level argparse alias, so the
		# module is imported locally under its real name.
		import argparse
		ap = argparse.ArgumentParser(description = DESCRIPTION)

	ap.add_argument("-o", "--outputname",
		help = "[Required] Output basename", 
		nargs = 1,
		type = str,
		metavar = 'str')
	group = ap.add_mutually_exclusive_group(required=True)
	group.add_argument("-iv", "--importvoxelimage",
		help = "Input -iv for each image. The image order depends on the input order. Required: (Image | LUT). Optional: (Lower threshold | high threshold). E.g., -iv MNI_T1_1mm_brain.nii.gz binary_r -iv stats.nii.gz tm-sunset_r .95 1 -iv negstats.nii.gz tm-breeze_r .95 1", 
		nargs = '+',
		type = str,
		action = 'append',
		metavar = '*')
	# metavar must be a tuple (not a list) for argparse to render one name per
	# argument in the usage/help text.
	ap.add_argument("-imo", "--addmaskoutline",
		help = "Import a volume, binarize it, and paint its outline on each png file. Input volume, and optionally the lower threshold.",
		nargs = '+',
		metavar = ('img','lower_threshold'))
	ap.add_argument("-sn", "--slicenumber",
		help = "[Optional] Select the slice number for every imported voxel image (starting at 0). The number of inputs must match '-iv's. If the image is 3D, use 0. E.g., -sn 0 0 4", 
		nargs = '+',
		type = int,
		metavar = 'int')
	ap.add_argument("--sign",
		help = "Set the sign (direction of effect) for input volume. The number of inputs must match '-iv's. E.g., --sign 1 1 -1", 
		nargs = '+',
		type = int,
		choices = [1,-1],
		metavar = 'INT')
	ap.add_argument("--alpha",
		help = "Set the alpha (transparency) for input volume ranging from 0 to 1. The number of inputs must match '-iv's. E.g., --alpha 1 0.5 1", 
		nargs = '+',
		type = float,
		metavar = 'FLOAT')
	ap.add_argument("-ns", "--numslices",
		help = "The number of slices. Default = %(default)s", 
		nargs = 1,
		type = int,
		default = [12],
		metavar = 'int')
	ap.add_argument("-k", "--keepimages",
		help = "Do not cleanup intermediate png files for each slice",
		action = 'store_true')
	ap.add_argument("-sm", "--savevoxelmask",
		help = "Saves a nifti file with the thresholded mask. Must be used with -imo argument",
		action = 'store_true')
	# No default here: run() falls back to 'otsu' when the option is absent.
	ap.add_argument("-ttype", "--thesholdingtype",
		help = "Method used to set the the lower threshold if thresholds are not supplied (Default is otsu).",
		choices = ['otsu', 'otsu_p', 'li', 'li_p', 'yen', 'yen_p', 'zscore', 'zscore_p'])
	ap.add_argument("--slicesize",
		help = "Default height and width of each output slice. Default is %(default)s pixels (i.e. 400x400)",
		nargs = 1,
		type = int,
		default = [400],
		metavar = 'int')
	ap.add_argument("-hc", "--hardcodecoordinates",
		help = "Writes the slices coordinates to the image file.",
		action = 'store_true')
	ap.add_argument("--interpolation",
		help = "Specify the reslice interpolation (choices are: %(choices)s). The default is nearest_neighbour",
		nargs = 1,
		choices = ['cubic','linear', 'nearest_neighbour'],
		default = ['nearest_neighbour'],
		metavar = 'str')
	ap.add_argument("--zthresh",
		help = "The z value to use for thresholding. Default = %(default)s.",
		nargs = 1,
		type = float,
		default = [2.3264],
		metavar = 'float')

	group.add_argument("--plotluts",
		help = "Plots the avalable lookup tables (LUT), and exits. Note, all LUTs can be reversed by adding '_r' after the name. For exaple, tm-storm_r is the reverse of the LUT tm-storm.",
		action = 'store_true')
	ap.add_argument('-v', '--version', 
		action = 'version', 
		version = '%(prog)s ' + __version__)
	return ap

def _require_matching_length(values, num_images, flag):
	"""Exit with status 1 unless *values* is None or holds one entry per -iv image."""
	if values is not None and len(values) != num_images:
		print("Error: the number of -iv (%d) must match the length of %s (%d)" % (num_images, flag, len(values)))
		sys.exit(1)

def run(opts):
	"""
	Render X/Y/Z slice images for every -iv volume and write the HTML page.

	Parameters
	----------
	opts : argparse.Namespace
		Parsed options as produced by getArgumentParser().parse_args().

	Notes
	-----
	* With --plotluts the lookup tables are displayed and the process exits.
	* Invalid option combinations print an error and exit with status 1.
	* PNG files named X_Slice_<n>.png, Y_Slice_<n>.png, Z_Slice_<n>.png and
	  <i>_colorbar.png are written to the current working directory, then
	  "<outputname>.html" is produced through make_slice_html and the options
	  are stored with write_dict.
	* The slice positions are derived from the first -iv image only and are
	  reused for every other image and for the optional -imo mask.
	"""
	if opts.plotluts:
		display_matplotlib_luts()
		sys.exit()

	if not opts.importvoxelimage:
		print("Error: at least one -iv input is required")
		sys.exit(1)
	if opts.outputname is None:
		print("Output name is required")
		sys.exit(1)

	# Default to otsu; anything outside the known names falls back to zscore.
	if opts.thesholdingtype is None:
		threshtype = 'otsu'
	elif opts.thesholdingtype in ('otsu', 'li', 'yen', 'otsu_p', 'li_p', 'yen_p', 'zscore_p'):
		threshtype = opts.thesholdingtype
	else:
		threshtype = 'zscore'

	# Validate every per-image option before any volume is loaded so that a
	# malformed command line fails fast with a readable message.
	num_images = len(opts.importvoxelimage)
	_require_matching_length(opts.alpha, num_images, "--alpha")
	_require_matching_length(opts.sign, num_images, "--sign")
	_require_matching_length(opts.slicenumber, num_images, "-sn")
	for opts_vi in opts.importvoxelimage:
		if len(opts_vi) not in (2, 4):
			print("Error: there must be two or four inputs with -iv")
			sys.exit(1)

	surf = []
	numslices = opts.numslices[0]
	numpixels = opts.slicesize[0]

	for i, opts_vi in enumerate(opts.importvoxelimage):

		# reorientate input volume
		imgname = opts_vi[0]
		invol = nib.as_closest_canonical(nib.load(imgname))

		if opts.alpha is not None:
			c_map = get_cmap_array(opts_vi[1], background_alpha = 0, image_alpha = float(opts.alpha[i]))
		else:
			c_map = get_cmap_array(opts_vi[1], background_alpha = 0)

		# Load and check data
		data = check_byteorder(invol.get_data())
		if opts.sign is not None:
			data *= int(opts.sign[i])
		if data.ndim == 4:
			if opts.slicenumber is not None:
				# Pick the requested volume along the fourth axis.
				data = data[:,:,:,int(opts.slicenumber[i])]
			else:
				print("Detected %s as a 4D volume. Extracting the first volume." % imgname)
				data = data[:,:,:,0]

		# Set thresholds
		if len(opts_vi) == 4:
			lthres = float(opts_vi[2])
			uthres = float(opts_vi[3])
		else:
			lthres, uthres = autothreshold(data, threshtype, opts.zthresh[0])
			print("%s Lower Threshold = %1.2f" % (os.path.basename(opts_vi[0]), lthres))
			print("%s Upper Threshold = %1.2f" % (os.path.basename(opts_vi[0]), uthres))

		# Create surfaces
		x_rng, y_rng, z_rng = nonempty_coordinate_range(data, invol.affine)
		if i == 0:
			# Slice positions come from the first image only; each axis uses its own range.
			x_space = np.round(np.linspace(x_rng[0]+5, x_rng[1]-5, numslices))
			y_space = np.round(np.linspace(y_rng[0]+5, y_rng[1]-5, numslices))
			z_space = np.round(np.linspace(z_rng[0]+1, z_rng[1]-1, numslices))
		scalar_field = apply_affine_to_scalar_field(data, invol.affine)

		# NOTE(review): the initial slice_index uses shape[1] although the plane is
		# x-oriented; slice_position is overwritten before any image is saved.
		surf.append(mlab.pipeline.image_plane_widget(scalar_field,
			plane_orientation = 'x_axes',
			vmin = lthres,
			vmax = uthres,
			plane_opacity = 0,
			name = "X_%s" % imgname,
			slice_index = int(data.shape[1]/2)))
		surf[i].module_manager.scalar_lut_manager.lut.table = c_map
		surf[i].ipw.reslice_interpolate = str(opts.interpolation[0])
		surf[i].scene.x_plus_view()
		rl_cmap = colors.ListedColormap(c_map[:,0:3]/255)
		write_colorbar(np.array((lthres,uthres)), rl_cmap, opts_vi[1])
		# Rename without a shell so that LUT names are never interpreted by it.
		colorbar_src = "%s_colorbar.png" % opts_vi[1]
		colorbar_dst = "%d_colorbar.png" % i
		try:
			os.rename(colorbar_src, colorbar_dst)
		except OSError as err:
			# Best effort, like the previous shell 'mv': report and carry on.
			print("Warning: could not rename %s to %s (%s)" % (colorbar_src, colorbar_dst, err))

	for count, xval in enumerate(x_space):
		for widget in surf:
			widget.ipw.slice_position = xval
			widget.scene.x_plus_view()
		surf[0].scene.background = (0,0,0)
		mlab.savefig('X_Slice_%d.png' % count, size = (numpixels,numpixels))
		correct_image('X_Slice_%d.png' % count)
		if opts.hardcodecoordinates:
			add_text_to_img('X_Slice_%d.png' % count, "X = %.0f" % xval)

	for count, yval in enumerate(y_space):
		for widget in surf:
			widget.ipw.plane_orientation = 'y_axes'
			widget.ipw.slice_position = yval
			widget.scene.y_plus_view()
		surf[0].scene.y_plus_view()
		surf[0].scene.background = (0,0,0)
		mlab.savefig('Y_Slice_%d.png' % count, size = (numpixels,numpixels))
		correct_image('Y_Slice_%d.png' % count, rotate = 90, flip = True)
		if opts.hardcodecoordinates:
			add_text_to_img('Y_Slice_%d.png' % count, "Y = %.0f" % yval)

	for count, zval in enumerate(z_space):
		for widget in surf:
			widget.ipw.plane_orientation = 'z_axes'
			widget.ipw.slice_position = zval
			widget.scene.z_plus_view()
		surf[0].scene.z_plus_view()
		surf[0].scene.background = (0,0,0)
		mlab.savefig('Z_Slice_%d.png' % count, size = (numpixels,numpixels))
		correct_image('Z_Slice_%d.png' % count)
		if opts.hardcodecoordinates:
			add_text_to_img('Z_Slice_%d.png' % count, "Z = %.0f" % zval)

	cleanup = not opts.keepimages

	if opts.addmaskoutline:
		mlab.close(all=True)
		invol = nib.as_closest_canonical(nib.load(opts.addmaskoutline[0]))
		data = check_byteorder(invol.get_data())

		if len(opts.addmaskoutline) == 2:
			lthres = float(opts.addmaskoutline[1])
		elif len(opts.addmaskoutline) == 1:
			lthres, _ = autothreshold(data, threshtype, opts.zthresh[0])
		else:
			# Without a threshold nothing below can work, so stop here.
			print("Error: -imo takes a volume and optionally a lower threshold")
			sys.exit(1)

		# Binarize from a single comparison so that a negative threshold does
		# not turn the freshly zeroed voxels back into ones.
		above = data > lthres
		data[above] = 1
		data[~above] = 0

		if opts.savevoxelmask:
			nib.save(nib.Nifti1Image(data.astype(np.float32, order = "C"), affine=invol.affine),
				"%s_mask.nii.gz" % opts.outputname[0])

		scalar_field = apply_affine_to_scalar_field(data, invol.affine)
		c_map = get_cmap_array('tm-white', 0)

		mask = mlab.pipeline.image_plane_widget(scalar_field,
				plane_orientation='x_axes',
				vmin = 0,
				vmax = 1,
				plane_opacity = 0,
				name = "Mask",
				slice_index = int(data.shape[1]/2))
		mask.module_manager.scalar_lut_manager.lut.table = c_map
		mask.scene.background = (0,0,0)
		mask.scene.x_plus_view()

		# Render the mask at the same positions as the slices and paint its
		# outline onto the matching slice image.
		for count, xval in enumerate(x_space):
			mask.ipw.slice_position = xval
			mask.scene.x_plus_view()
			mask.scene.background = (0,0,0)
			mlab.savefig('X_mask_%d.png' % count, size = (numpixels,numpixels))
			correct_image('X_mask_%d.png' % count)
			draw_outline('X_Slice_%d.png' % count,'X_mask_%d.png' % count)
		for count, yval in enumerate(y_space):
			mask.ipw.plane_orientation = 'y_axes'
			mask.ipw.slice_position = yval
			mask.scene.y_plus_view()
			mask.scene.background = (0,0,0)
			mlab.savefig('Y_mask_%d.png' % count, size = (numpixels,numpixels))
			correct_image('Y_mask_%d.png' % count, rotate = 90, flip = True)
			draw_outline('Y_Slice_%d.png' % count,'Y_mask_%d.png' % count)
		for count, zval in enumerate(z_space):
			mask.ipw.plane_orientation = 'z_axes'
			mask.ipw.slice_position = zval
			mask.scene.z_plus_view()
			mask.scene.background = (0,0,0)
			mlab.savefig('Z_mask_%d.png' % count, size = (numpixels,numpixels))
			correct_image('Z_mask_%d.png' % count)
			draw_outline('Z_Slice_%d.png' % count,'Z_mask_%d.png' % count)

	# One row per slice index: (x, y, z) position.
	coordinates = np.stack((x_space, y_space, z_space)).T
	for prefix in ('X_Slice', 'Y_Slice', 'Z_Slice'):
		concate_images(prefix, numslices,
			clean = cleanup,
			numpixels = numpixels)

	# Coordinates already burnt into the images are not repeated in the page.
	write_coordinates = not opts.hardcodecoordinates

	make_slice_html("%s.html" % opts.outputname[0], 
		coordinates,
		opts.importvoxelimage,
		write_coordinates = write_coordinates)

	write_dict(".%s.html/settings" % opts.outputname[0], opts)

if __name__ == "__main__":
	# Script entry point: parse the command line and hand the options to run().
	run(getArgumentParser().parse_args())
