#!/usr/bin/env python

#    
#    Copyright (C) 2017  Tristram Lett

#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.

#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.

#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
import os
import sys
import numpy as np
import argparse as ap
import nibabel as nib
import matplotlib.colors as colors

from tfce_mediation.pyfunc import write_colorbar

from tmi_viewer.tv_func import check_byteorder, apply_affine_to_scalar_field, nonempty_coordinate_range, display_matplotlib_luts, get_cmap_array, correct_image, concate_images, draw_outline, autothreshold, make_slice_html, write_dict, add_text_to_img
from tmi_viewer.version import __version__

# hopefully safer loading of mayavi
# Keep any QT_API the user already exported; otherwise ask for PyQt first.
os.environ.setdefault('QT_API', 'pyqt')
try:
	from mayavi import mlab
except Exception:
	# Retry with the PySide binding. Only ordinary exceptions trigger the
	# fallback so that Ctrl-C (KeyboardInterrupt) or SystemExit raised while
	# importing is not mistaken for a missing Qt binding.
	print("Trying pyside")
	os.environ['QT_API'] = 'pyside'
	from mayavi import mlab


# Text shown at the top of the command-line help.
DESCRIPTION = "Outputs a webpage with whole brain slices from voxel based neuroimages in native coordinates with optional overlaps."

def getArgumentParser(ap = None):
	"""
	Build the command-line parser for the slice viewer.

	Parameters
	----------
	ap : argparse.ArgumentParser or None
		Parser to add the options to. When omitted, a new parser described by
		DESCRIPTION is created on every call. (Previously a single parser was
		created once as the default argument and shared between calls, so a
		second call failed with conflicting option errors.)

	Returns
	-------
	ap : argparse.ArgumentParser
		The parser with all options registered. Exactly one of
		-ivb/--importvoxelbase or --plotluts must be supplied when parsing.
	"""
	if ap is None:
		# The parameter name shadows the module-level argparse alias, so the
		# module is imported locally under its full name.
		import argparse
		ap = argparse.ArgumentParser(description = DESCRIPTION)

	ap.add_argument("-o", "--outputname",
		help = "[Required] Output basename", 
		nargs = 1,
		type = str,
		metavar = 'str')

	# The base image and --plotluts are alternatives; one of them is required.
	group = ap.add_mutually_exclusive_group(required=True)
	group.add_argument("-ivb", "--importvoxelbase",
		help = "[Required] Base overlay. Input {img | lut | low thr| high thr }", 
		nargs = '+',
		type = str,
		metavar = '*')

	# The two optional overlays take the same values as the base image.
	for number, ordinal in ((1, "First"), (2, "Second")):
		ap.add_argument("-iv%d" % number, "--importvoxel%d" % number,
			help = "[Optional] %s overlay. Input {img | lut | low thr| high thr }" % ordinal, 
			nargs = '+',
			type = str,
			metavar = '*')

	# argparse only accepts a tuple (not a list) as a per-argument metavar.
	ap.add_argument("-imo", "--addmaskoutline",
		help = "Import a volume, binarize it, and paint its outline on each png file. Input volume and lower threshold",
		nargs = '+',
		metavar = ('img', 'lower_threshold'))

	ap.add_argument("-ns", "--numslices",
		help = "The number of slices. Default = %(default)s", 
		nargs = 1,
		type = int,
		default = [12],
		metavar = 'int')
	ap.add_argument("-k", "--keepimages",
		help = "Do not cleanup intermediate png files for each slice",
		action = 'store_true')
	ap.add_argument("-sm", "--savevoxelmask",
		help = "Saves a nifti file with the thresholded mask. Must be used with -imo argument",
		action = 'store_true')
	ap.add_argument("-ttype", "--thesholdingtype",
		help = "Method used to set the the lower threshold if thresholds are not supplied (Default is Oshu).",
		choices = ['otsu', 'otsu_p', 'li', 'li_p', 'yen', 'yen_p', 'zscore', 'zscore_p'])
	ap.add_argument("--slicesize",
		help = "Default height and width of each output slice. Default is %(default)s pixels (i.e. 400x400)",
		nargs = 1,
		type = int,
		default = [400],
		metavar = 'int')
	ap.add_argument("-hc", "--hardcodecoordinates",
		help = "Writes the slices coordinates to the image file.",
		action = 'store_true')
	ap.add_argument("--interpolation",
		help = "Specify the reslice interpolation (choices are: %(choices)s). The default is nearest_neighbour",
		nargs = 1,
		choices = ['cubic','linear', 'nearest_neighbour'],
		default = ['nearest_neighbour'],
		metavar = 'str')
	ap.add_argument("--zthresh",
		help = "The z value to use for thresholding. Default = %(default)s.",
		nargs = 1,
		type = float,
		default = [2.3264],
		metavar = 'float')

	group.add_argument("--plotluts",
		help = "Plots the avalable lookup tables, and exits.",
		action = 'store_true')
	ap.add_argument('-v', '--version', 
		action = 'version', 
		version = '%(prog)s ' + __version__)
	return ap

# Rendering recipe for each axis, as (axis label, plane orientation to apply
# before positioning the slice, name of the scene view method, extra keyword
# arguments for correct_image). The widgets are created with the 'x_axes'
# orientation, so X has no orientation to apply and must be rendered first.
_SLICE_AXES = (
	('X', None, 'x_plus_view', {}),
	('Y', 'y_axes', 'y_plus_view', {'rotate': 90, 'flip': True}),
	('Z', 'z_axes', 'z_plus_view', {}),
)

# Thresholding methods that are handed to autothreshold unchanged.
_PASSTHROUGH_THRESHTYPES = ('otsu', 'li', 'yen', 'otsu_p', 'li_p', 'yen_p', 'zscore_p')


def _select_threshtype(requested):
	"""Return the autothreshold method name: 'otsu' when nothing was requested, 'zscore' for any unlisted value."""
	if requested is None:
		return 'otsu'
	if requested in _PASSTHROUGH_THRESHTYPES:
		return requested
	return 'zscore'


def _read_layer(layer_args, threshtype, opts):
	"""
	Load one image layer and work out its colour map and display thresholds.

	layer_args is the value list of -ivb/-iv1/-iv2: [img, lut] or
	[img, lut, low thr, high thr]; its length must already have been checked.
	Without explicit thresholds they are obtained from autothreshold and
	printed. Returns a dictionary describing the layer.
	"""
	imgname = layer_args[0]
	lutname = layer_args[1]
	invol = nib.as_closest_canonical(nib.load(imgname))
	c_map = get_cmap_array(lutname, 0)
	data = check_byteorder(invol.get_data())
	# NOTE: an earlier revision flipped the data axes by the sign of the affine
	# diagonal at this point; that step was already disabled (reason unknown).

	if len(layer_args) == 4:
		lthres = float(layer_args[2])
		uthres = float(layer_args[3])
	else:
		lthres, uthres = autothreshold(data, threshtype, opts.zthresh[0])
		print("%s Lower Threshold = %1.2f" % (os.path.basename(imgname), lthres))
		print("%s Upper Threshold = %1.2f" % (os.path.basename(imgname), uthres))

	return {'imgname': imgname,
		'lutname': lutname,
		'data': data,
		'affine': invol.affine,
		'c_map': c_map,
		'lthres': lthres,
		'uthres': uthres}


def _add_layer_widget(layer, opts, colorbar_index):
	"""
	Add an image plane widget for a layer from _read_layer and write its colour bar.

	The colour bar image is moved to '<colorbar_index>_colorbar.png'.
	Returns the mayavi image plane widget.
	"""
	data = layer['data']
	c_map = layer['c_map']
	scalar_field = apply_affine_to_scalar_field(data, layer['affine'])
	widget = mlab.pipeline.image_plane_widget(scalar_field,
		plane_orientation = 'x_axes',
		vmin = layer['lthres'],
		vmax = layer['uthres'],
		plane_opacity = 0,
		name = "X_%s" % layer['imgname'],
		slice_index = data.shape[1] // 2)
	widget.module_manager.scalar_lut_manager.lut.table = c_map
	widget.ipw.reslice_interpolate = str(opts.interpolation[0])
	widget.scene.x_plus_view()

	# The first three columns of the colour table are scaled from 0-255 to 0-1.
	rl_cmap = colors.ListedColormap(c_map[:,0:3] / 255.0)
	write_colorbar(np.array((layer['lthres'], layer['uthres'])), rl_cmap, layer['lutname'])
	# NOTE(review): the lut name is interpolated into a shell command; a name
	# containing shell metacharacters would be interpreted by the shell.
	os.system("mv %s_colorbar.png %d_colorbar.png" % (layer['lutname'], colorbar_index))
	return widget


def _render_slices(widgets, positions, file_label, numpixels, after_save = None):
	"""
	Save one png per slice position for the X, Y and Z axes.

	widgets : list of image plane widgets; the first one is the primary
		widget whose scene view and background are set last.
	positions : dictionary mapping 'X', 'Y' and 'Z' to the slice positions.
	file_label : middle part of the file name, '<axis>_<file_label>_<n>.png'.
	numpixels : height and width of the saved images.
	after_save : optional callable(axis, count, position, filename) invoked
		after each image has been saved and corrected.
	"""
	primary = widgets[0]
	for axis, orientation, view_name, fix_kwargs in _SLICE_AXES:
		for count, position in enumerate(positions[axis]):
			for index, widget in enumerate(widgets):
				if orientation is not None:
					widget.ipw.plane_orientation = orientation
				widget.ipw.slice_position = position
				if index > 0:
					getattr(widget.scene, view_name)()
			getattr(primary.scene, view_name)()
			primary.scene.background = (0,0,0)
			filename = '%s_%s_%d.png' % (axis, file_label, count)
			mlab.savefig(filename, size = (numpixels,numpixels))
			correct_image(filename, **fix_kwargs)
			if after_save is not None:
				after_save(axis, count, position, filename)


def run(opts):
	"""
	Render the slices and write the web page described by the parsed options.

	With opts.plotluts set, the lookup tables are displayed and the program
	exits. Otherwise the base image (and up to two overlays) are rendered as
	opts.numslices slices along each axis, an optional binarized mask outline
	is painted onto every slice, the slices are concatenated per axis, and
	'<outputname>.html' plus a settings file are written.

	Parameters
	----------
	opts : argparse.Namespace
		Options as produced by getArgumentParser().parse_args().

	Exits (SystemExit) when the base image or output name is missing, or when
	-ivb/-iv1/-iv2 (2 or 4 values) or -imo (1 or 2 values) receive an
	unsupported number of values.
	"""
	if opts.plotluts:
		display_matplotlib_luts()
		sys.exit()

	if not opts.importvoxelbase:
		sys.exit("Error: a base image (-ivb) is required.")
	if opts.outputname is None:
		print("Output name is required")
		sys.exit()

	# Reject malformed value lists before anything is loaded or rendered.
	# Previously this only printed "Error" and carried on with undefined or
	# stale thresholds.
	for flag, values, allowed in (("-ivb", opts.importvoxelbase, (2, 4)),
			("-iv1", opts.importvoxel1, (2, 4)),
			("-iv2", opts.importvoxel2, (2, 4)),
			("-imo", opts.addmaskoutline, (1, 2))):
		if values and len(values) not in allowed:
			sys.exit("Error: %s received %d value(s); expected %s." % (flag,
				len(values),
				" or ".join(str(n) for n in allowed)))

	threshtype = _select_threshtype(opts.thesholdingtype)
	numslices = opts.numslices[0]
	numpixels = opts.slicesize[0]

	# Base image: the slice ranges are derived from it and shared by all layers.
	base = _read_layer(opts.importvoxelbase, threshtype, opts)
	x_rng, y_rng, z_rng = nonempty_coordinate_range(base['data'], base['affine'])
	widgets = [_add_layer_widget(base, opts, 0)]

	for colorbar_index, overlay_args in ((1, opts.importvoxel1), (2, opts.importvoxel2)):
		if overlay_args:
			overlay = _read_layer(overlay_args, threshtype, opts)
			widgets.append(_add_layer_widget(overlay, opts, colorbar_index))

	# Slice positions along each axis. The Z positions use z_rng (they were
	# previously taken from x_rng, disagreeing with the coordinates written
	# to the web page).
	# NOTE(review): the X range trims 5 from its upper bound while Y and Z
	# trim 1; kept as in the original -- confirm this is intended.
	positions = {
		'X': np.round(np.linspace(x_rng[0]+1, x_rng[1]-5, numslices)),
		'Y': np.round(np.linspace(y_rng[0]+1, y_rng[1]-1, numslices)),
		'Z': np.round(np.linspace(z_rng[0]+1, z_rng[1]-1, numslices)),
	}

	label_slice = None
	if opts.hardcodecoordinates:
		def label_slice(axis, count, position, filename):
			add_text_to_img(filename, "%s = %.0f" % (axis, position))
	_render_slices(widgets, positions, 'Slice', numpixels, label_slice)

	if opts.addmaskoutline:
		mlab.close(all=True)
		# NOTE(review): unlike the image layers, the mask is not passed through
		# nib.as_closest_canonical -- confirm this is intended.
		invol = nib.load(opts.addmaskoutline[0])
		data = check_byteorder(invol.get_data())

		if len(opts.addmaskoutline) == 2:
			lthres = float(opts.addmaskoutline[1])
		else:
			lthres, _ = autothreshold(data, threshtype, opts.zthresh[0])

		# Binarize in place. Both selections are computed before writing so
		# that, for a negative threshold, the zeros just written are not
		# turned into ones by the second assignment.
		below = data <= lthres
		above = data > lthres
		data[below] = 0
		data[above] = 1

		if opts.savevoxelmask:
			nib.save(nib.Nifti1Image(data.astype(np.float32, order = "C"), affine=invol.affine),
				"%s_mask.nii.gz" % opts.outputname[0])

		scalar_field = apply_affine_to_scalar_field(data, invol.affine)
		c_map = get_cmap_array('tm-white', 0)

		mask = mlab.pipeline.image_plane_widget(scalar_field,
				plane_orientation='x_axes',
				vmin = 0,
				vmax = 1,
				plane_opacity = 0,
				name = "Mask",
				slice_index = data.shape[1] // 2)
		mask.module_manager.scalar_lut_manager.lut.table = c_map
		mask.scene.background = (0,0,0)
		mask.scene.x_plus_view()

		def outline_slice(axis, count, position, filename):
			# Paint the outline of the mask image onto the matching slice.
			draw_outline('%s_Slice_%d.png' % (axis, count), filename)
		_render_slices([mask], positions, 'mask', numpixels, outline_slice)

	coordinates = np.stack((positions['X'], positions['Y'], positions['Z'])).T

	# Intermediate slice images are removed unless -k/--keepimages was given.
	cleanup = not opts.keepimages
	for axis in ('X', 'Y', 'Z'):
		concate_images('%s_Slice' % axis, numslices,
			clean = cleanup,
			numpixels = numpixels)

	# Coordinates already drawn onto the images are not repeated on the page.
	write_coordinates = not opts.hardcodecoordinates

	make_slice_html("%s.html" % opts.outputname[0], 
		coordinates,
		opts.importvoxel1,
		opts.importvoxel2,
		write_coordinates = write_coordinates)

	write_dict(".%s.html/settings" % opts.outputname[0], opts)

if __name__ == "__main__":
	# Script entry point: parse the command line and hand the options to run().
	run(getArgumentParser().parse_args())
