#!/usr/bin/env python
import matplotlib
import os

import matplotlib.pyplot as plt

import argparse
import datetime
import json
import logging
import time

from copy import deepcopy

import numpy as np
import scipy.special

from tart.operation import settings

from tart_tools import api_handler
from tart_tools import api_imaging
from tart.imaging import elaz

import dask.array as da
from dask.distributed import Client, progress

from disko import DiSkO, get_source_list, TelescopeOperator, vis_to_real, MultivariateGaussian, create_fov


# Module-level logger. It emits INFO and above, and carries only a do-nothing
# handler of its own; attach real handlers elsewhere when using this module as
# a library.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.NullHandler())


def handle_bayes(ARGS):
    """Infer a posterior sky from visibilities and write the requested outputs.

    Visibilities are read either from a JSON snapshot file (``ARGS.file``) or
    from a measurement set (``ARGS.ms``). A prior (loaded from ``ARGS.prior``
    or obtained from the telescope operator) is transformed with ``to.Vh``,
    its first ``to.rank`` components are updated with the measured
    visibilities, and the combined posterior is transformed back with
    ``to.V``. The posterior is optionally stored as HDF5 and rendered as
    mean / variance / point-covariance / sample images.

    Args:
        ARGS: Parsed command-line namespace. Attributes read: ``nside``,
            ``fov``, ``arcmin``, ``file``, ``ms``, ``nvis``, ``channel``,
            ``field``, ``show_sources``, ``uv``, ``title``, ``prior``,
            ``sigma_v``, ``posterior``, ``dir``, ``FITS``, ``SVG``, ``PNG``,
            ``PDF``, ``mu``, ``var``, ``pcf`` and ``nsamples``.

    Raises:
        ValueError: If neither ``ARGS.file`` nor ``ARGS.ms`` is supplied, or
            if the JSON file contains no snapshots.
    """
    # Fail fast, before any expensive set-up, when there is nothing to read.
    if not ARGS.file and ARGS.ms is None:
        raise ValueError("No input data: supply either --file or --ms")

    sphere = create_fov(ARGS.nside, ARGS.fov, ARGS.arcmin)

    if ARGS.file:
        logger.info("Getting Data from file: {}".format(ARGS.file))
        # Load data from a JSON file
        with open(ARGS.file, 'r') as json_file:
            calib_info = json.load(json_file)

        info = calib_info['info']
        ant_pos = calib_info['ant_pos']
        config = settings.from_api_json(info['info'], ant_pos)

        # Passed straight to vis_calibrated; presumably antenna indices to
        # exclude (e.g. [4, 5, 14, 22]). Empty: nothing is flagged.
        flag_list = []

        gains_json = calib_info['gains']
        gains = np.asarray(gains_json['gain'])
        phase_offsets = np.asarray(gains_json['phase_offset'])

        snapshots = calib_info['data']
        if not snapshots:
            raise ValueError("No visibility data in file: {}".format(ARGS.file))

        # NOTE: every snapshot is calibrated, but only the values from the
        # last one (visibilities, timestamp, sources) are used below.
        for vis_json, source_json in snapshots:
            cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains, phase_offsets, flag_list)
            src_list = elaz.from_json(source_json, 0.0)
        disko = DiSkO.from_cal_vis(cv)
    else:
        logger.info("Getting Data from MS file: {}".format(ARGS.ms))
        disko = DiSkO.from_ms(ARGS.ms, ARGS.nvis, res_arcmin=sphere.res_arcmin, channel=ARGS.channel, field_id=ARGS.field)
        timestamp = disko.timestamp
        src_list = None

    if not ARGS.show_sources:
        src_list = None

    # Used as part of every output file name.
    time_repr = "{:%Y_%m_%d_%H_%M_%S_%Z}".format(timestamp)

    # Processing

    real_vis = vis_to_real(disko.vis_arr)

    ##
    #
    # Do the inference, get the SVD.
    #
    ##

    to = TelescopeOperator(disko, sphere)

    if ARGS.uv:
        to.plot_uv(ARGS.title)

    if ARGS.prior is not None:
        prior = MultivariateGaussian.from_hdf5(ARGS.prior)
    else:
        prior = to.get_prior()  # in the image space.

    # Transform to the natural basis.
    prior = prior.linear_transform(to.Vh)

    n_v = real_vis.shape[0]

    # TODO create a proper covariance that ensures the real and imaginary components are linked.
    if ARGS.sigma_v is None:
        diag = np.diag(disko.rms**2)
        logger.info("Using measurement set sigma {}".format(np.percentile(disko.rms, [5,50,95])))
    else:
        # real_vis holds n_v entries; the diagonal block covers half of them.
        diag = np.diag(np.ones(n_v // 2)*(ARGS.sigma_v)**2)
        logger.info("Using supplied sigma {}".format(ARGS.sigma_v))

    sigma_vis = np.block([[diag, 0.5*diag],[0.5*diag, diag]])

    # now invert sigma_vis
    sigma_precision = MultivariateGaussian.sp_inv(sigma_vis)
    del sigma_vis

    # Pull the block from the natural_prior that is the range_space prior;
    # the remaining block is carried over to the posterior unchanged.
    # NOTE: to.sequential_inference(prior, real_vis, sigma_precision) is an
    # alternative way of obtaining the posterior that is not used here.
    prior_r = prior.block(0, to.rank)
    prior_n = prior.block(to.rank, to.n_s)

    A_r = to.A_r
    V = to.V

    # Large intermediates are deleted as soon as they are no longer needed
    # to keep peak memory down.
    del to
    posterior_r = prior_r.bayes_update(sigma_precision, real_vis, A_r)
    posterior_n = prior_n

    del A_r
    del sigma_precision
    del prior_r
    del prior_n
    del prior

    posterior = MultivariateGaussian.outer(posterior_r, posterior_n)

    del posterior_r
    del posterior_n

    logger.info("Transforming posterior")

    posterior = posterior.linear_transform(V)

    del V

    # Now save the files.
    if ARGS.posterior is not None:
        posterior.to_hdf5(ARGS.posterior)

    def path(ending, image_title):
        """Return ``<ARGS.dir>/<image_title>.<ending>``, creating the directory if needed."""
        os.makedirs(ARGS.dir, exist_ok=True)
        fname = '{}.{}'.format(image_title, ending)
        return os.path.join(ARGS.dir, fname)

    def save_images(image_title, source_list):
        """Write the current contents of ``sphere`` in every requested format."""
        # Save as a FITS file
        if ARGS.FITS:
            sphere.to_fits(fname=path('fits', image_title), fov=ARGS.fov, info=disko.info)

        if ARGS.SVG:
            fname = path('svg', image_title)
            sphere.to_svg(fname=fname, show_grid=True, src_list=source_list, fov=ARGS.fov, title=image_title, show_cbar=True)
            logger.info("Generating {}".format(fname))
        if ARGS.PNG:
            fname = path('png', image_title)
            sphere.plot(plt, source_list)
            plt.title(image_title)
            plt.tight_layout()
            plt.savefig(fname, dpi=300)
            plt.close()
            logger.info("Generating {}".format(fname))
        if ARGS.PDF:
            fname = path('pdf', image_title)
            sphere.plot(plt, source_list)
            plt.title(image_title)
            plt.savefig(fname, dpi=600)
            plt.close()
            logger.info("Generating {}".format(fname))

    # Nothing below produces output unless at least one image format is on.
    if ARGS.PDF or ARGS.PNG or ARGS.SVG or ARGS.FITS:

        if ARGS.mu:
            logger.info("Computing pixels")
            tic = time.perf_counter()
            # Materialise the mean inside the timed region so that the
            # reported duration covers the actual computation.
            mu = np.array(posterior.mu)
            logger.info(f"    Took {time.perf_counter() - tic:0.4f} seconds")
            stat = sphere.set_visible_pixels(mu, scale=False)
            stat['sigma-v'] = ARGS.sigma_v
            logger.info(json.dumps(stat, sort_keys=True))
            save_images('{}_{}_mu'.format(ARGS.title, time_repr), source_list=src_list)

        if ARGS.var:
            tic = time.perf_counter()
            logger.info("Computing variance...")
            variance = np.array(posterior.variance())
            logger.info(f"    Took {time.perf_counter() - tic:0.4f} seconds")
            sphere.set_visible_pixels(variance, scale=False)
            save_images('{}_{}_var'.format(ARGS.title, time_repr), source_list=None)

        if ARGS.pcf:
            tic = time.perf_counter()
            logger.info("Computing point covariance...")

            # Row of the posterior covariance belonging to the pixel with
            # the largest posterior mean.
            brightest_pixel = np.argmax(posterior.mu)
            pix_cov = np.array(posterior.sigma()[brightest_pixel, :])
            logger.info(f"    Took {time.perf_counter() - tic:0.4f} seconds")

            sphere.set_visible_pixels(pix_cov, scale=False)
            save_images('{}_{}_pcf'.format(ARGS.title, time_repr), source_list=None)

        for i in range(ARGS.nsamples):
            sphere.set_visible_pixels(posterior.sample(), scale=False)
            save_images(image_title='{}_{}_s{:0>5}'.format(ARGS.title, time_repr, i), source_list=None)


if __name__ == '__main__':

    # Seed NumPy's global RNG with a fixed value (presumably so that the
    # posterior samples are repeatable between runs -- confirm that
    # posterior.sample() draws from this generator).
    np.random.seed(42)

    parser = argparse.ArgumentParser(description='DiSkO: Bayesian inference of a posterior sky',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Input selection.
    parser.add_argument('--ms', required=False, default=None, help="visibility file")
    parser.add_argument('--file', required=False, default=None, help="Snapshot observation saved JSON file (visibilities, positions and more).")
    parser.add_argument('--channel', type=int, default=0, help="Use this frequency channel.")
    parser.add_argument('--field', type=int, default=0, help="Use this FIELD_ID from the measurement set.")

    parser.add_argument('--dir', required=False, default='.', help="Output directory.")
    parser.add_argument('--nvis', type=int, default=1000, help="Number of visibilities to use.")
    parser.add_argument('--arcmin', type=float, default=None, help="Highest allowed res of the sky in arc minutes.")

    parser.add_argument('--fov', type=float, default=180.0, help="Field of view in degrees")
    parser.add_argument('--nside', type=int, default=None, help="Healpix nside parameter for display purposes only.")

    parser.add_argument('--sigma-v', type=float, default=None, help="Diagonal components of the visibility covariance. If not supplied use measurement set values")

    # Output formats.
    parser.add_argument('--PNG', action="store_true", help="Generate a PNG format image.")
    parser.add_argument('--PDF', action="store_true", help="Generate a PDF format image.")
    parser.add_argument('--SVG', action="store_true", help="Generate a SVG format image.")
    parser.add_argument('--FITS', action="store_true", help="Generate a FITS format image.")
    parser.add_argument('--show-sources', action="store_true", help="Show known sources on images (only works on PNG & SVG).")

    parser.add_argument('--prior', type=str, default=None, help="Load the prior from an HDF5 file.")
    parser.add_argument('--posterior', type=str, default=None, help="Store the posterior in HDF5 format file.")

    # What to render.
    parser.add_argument('--uv', action="store_true", help="Plot the UV coverage.")
    parser.add_argument('--mu', action="store_true", help="Save the mean image.")
    parser.add_argument('--pcf', action="store_true", help="Save the point covariance function image.")
    parser.add_argument('--var', action="store_true", help="Save the pixel variance image.")
    parser.add_argument('--nsamples', type=int, default=0, help="Number of samples to save from the posterior.")

    parser.add_argument('--title', required=False, default="disko", help="Prefix the output files.")

    # Parse before touching the filesystem, so that --help or an invalid
    # argument exits without leaving a log file behind.
    args = parser.parse_args()

    # basicConfig sets up logging on the root logger at INFO level with this
    # format.
    log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=log_fmt, level=logging.INFO)

    # Additionally write the same records, in the same format, to disko.log.
    fh = logging.FileHandler('disko.log')
    fh.setFormatter(logging.Formatter(log_fmt))
    logging.getLogger().addHandler(fh)

    handle_bayes(args)
