#!/usr/bin/env python
"""
Analyze pairs of closely-eluting spectra with similar precursor and fragments,
infer precursor and fragment error, and transform those values into values usable
as tolerance parameters by search engines.

If multiple files are specified, they will be processed together.
"""

import argparse
import logging
from datetime import datetime
from parammedic import errorcalc
from parammedic import ms2_io
from parammedic import mzml_io
from parammedic import __version__
import gzip

__author__ = "Damon May"
__copyright__ = "Copyright (c) 2016 Damon May"
__license__ = "Apache 2.0"

logger = logging.getLogger(__name__)


def declare_gather_args():
    """
    Declare all command-line arguments, parse them, and return the parsed values.
    Does no validation beyond the implicit validation done by argparse and by the
    per-argument type callables.

    Note that the 'infiles' argument uses argparse.FileType('r'), so every input file
    is opened (in text mode) while the arguments are being parsed.

    return: the argparse.Namespace produced by parse_args(); each option is available
    as an attribute (e.g. args.min_precursor_mz)
    """

    # Multi-line help strings are built with implicit string concatenation rather than a
    # backslash continuation inside the literal, so that source indentation does not end
    # up embedded in the help text.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('infiles', type=argparse.FileType('r'), nargs='+',
                        help='input .mzML or .ms2 file(s). If multiple files are specified, '
                             'they will be processed together')

    # m/z ranges and precursor distance: non-negative floats, defaults taken from errorcalc
    parser.add_argument('--min-precursor-mz', type=float_greaterthanequalto0_type,
                        default=errorcalc.DEFAULT_MIN_MZ_FOR_BIN_PRECURSOR,
                        help='minimum precursor m/z value to use')
    parser.add_argument('--max-precursor-mz', type=float_greaterthanequalto0_type,
                        default=errorcalc.DEFAULT_MAX_MZ_FOR_BIN_PRECURSOR,
                        help='maximum precursor m/z value to use')
    parser.add_argument('--min-frag-mz', type=float_greaterthanequalto0_type,
                        default=errorcalc.DEFAULT_MIN_MZ_FOR_BIN_FRAGMENT,
                        help='minimum fragment m/z value to use')
    parser.add_argument('--max-frag-mz', type=float_greaterthanequalto0_type,
                        default=errorcalc.DEFAULT_MAX_MZ_FOR_BIN_FRAGMENT,
                        help='maximum fragment m/z value to use')
    parser.add_argument('--max-precursor-delta-ppm', type=float_greaterthanequalto0_type,
                        default=errorcalc.DEFAULT_MAX_PRECURSORDIST_PPM,
                        help='maximum ppm distance between precursor m/z values to consider two scans potentially '
                             'generated by the same peptide')

    # counts and charge: integers >= 1, defaults taken from errorcalc
    parser.add_argument('--charge', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_CHARGE,
                        help='charge state to consider MS/MS spectra from')
    parser.add_argument('--min-scan-frag-peaks', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_MIN_SCAN_MS2PEAKS,
                        help='Minimum fragment peaks an MS/MS scan must contain to be considered')
    parser.add_argument('--top-n-frag-peaks', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_TOPN_FRAGPEAKS,
                        help='number of most-intense fragment peaks to consider, per MS/MS spectrum')
    parser.add_argument('--min-common-frag-peaks', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_MIN_FRAGPEAKS_INCOMMON,
                        help='number of the most-intense peaks that two spectra must share in order to be '
                             'potentially generated by the same peptide')
    parser.add_argument('--pair-top-n-frag-peaks', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_TOPN_FRAGPEAKS_FOR_ERROR_EST,
                        help='number of fragment peaks per spectrum pair to be used in fragment error estimation')
    parser.add_argument('--max-scan-separation', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_MAX_SCANS_BETWEEN_COMPARESCANS,
                        help='maximum number of scans two spectra can be separated by in order to be considered '
                             'potentially generated by the same peptide')
    parser.add_argument('--min-peak-pairs', type=int_greaterthanequalto1_type,
                        default=errorcalc.DEFAULT_MIN_PEAKPAIRS_FOR_DISTRIBUTION_FIT,
                        help='minimum number of peak pairs (for precursor or fragment) that must be successfully '
                             'paired in order to attempt to estimate error distribution')

    parser.add_argument('--debug', action="store_true", help='Enable debug logging')
    parser.add_argument('--version', action='version', version='%(prog)s {version}'.format(version=__version__))

    return parser.parse_args()


def int_greaterthanequalto1_type(x):
    """
    argparse parameter type: an integer that is at least 1
    :param x: the raw value to convert (anything int() accepts)
    :return: x converted to an int
    :raises argparse.ArgumentTypeError: if the converted value is below 1
    """
    value = int(x)
    if value >= 1:
        return value
    raise argparse.ArgumentTypeError("Minimum value is 1")


def float_greaterthanequalto0_type(x):
    """
    argparse parameter type that is a float >= 0
    :param x: the raw value to convert (anything float() accepts)
    :return: x converted to a float
    :raises argparse.ArgumentTypeError: if the converted value is negative or NaN
    :raises ValueError: if x cannot be parsed as a float
    """
    value = float(x)
    # Test "not value >= 0" rather than "value < 0": every comparison against NaN is
    # False, so only the negated form rejects NaN along with negative numbers.
    if not value >= 0:
        raise argparse.ArgumentTypeError("Minimum value is 0.0")
    return value


def main():
    """
    Handle arguments, create an ErrorCalculator, process all the spectra, and infer parameters.

    Results are printed to stdout. Timing and per-file progress are logged at debug
    level (enabled with --debug).
    :return:
    :raises SystemExit: if a max m/z bound is not strictly greater than the matching min bound
    """
    args = declare_gather_args()
    # Raise SystemExit directly instead of calling quit(): quit() is a helper installed by
    # the site module for interactive use and is not guaranteed to exist in every run.
    if args.max_precursor_mz <= args.min_precursor_mz:
        raise SystemExit("max-precursor-mz must be > min-precursor-mz")
    if args.max_frag_mz <= args.min_frag_mz:
        raise SystemExit("max-frag-mz must be > min-frag-mz")

    # logging
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s: %(message)s")
    if args.debug:
        logger.setLevel(logging.DEBUG)
        # any module-specific debugging goes below
        errorcalc.logger.setLevel(logging.DEBUG)

    script_start_time = datetime.now()
    logger.debug("Start time: %s", script_start_time)

    error_calculator = errorcalc.ErrorCalculator(min_precursor_mz=args.min_precursor_mz,
                                                 max_precursor_mz=args.max_precursor_mz,
                                                 min_frag_mz=args.min_frag_mz,
                                                 max_frag_mz=args.max_frag_mz,
                                                 charge=args.charge,
                                                 min_scan_frag_peaks=args.min_scan_frag_peaks,
                                                 topn_frag_peaks=args.top_n_frag_peaks,
                                                 min_common_frag_peaks=args.min_common_frag_peaks,
                                                 pair_topn_frag_peaks=args.pair_top_n_frag_peaks,
                                                 max_scan_separation=args.max_scan_separation,
                                                 max_precursor_deltappm=args.max_precursor_delta_ppm,
                                                 min_peakpairs=args.min_peak_pairs)

    for infile in args.infiles:
        logger.debug("Processing input file %s...", infile.name)
        try:
            for spectrum in generate_spectra(infile):
                error_calculator.process_spectrum(spectrum)
        finally:
            # argparse.FileType opened this handle during argument parsing; release it as
            # soon as the file has been consumed instead of holding it until exit
            infile.close()
        # clear the bins so we don't end up using pairs across files
        error_calculator.clear_all_bins()

    # calculate mass error distributions
    (failed_precursor, precursor_message, failed_fragment, fragment_message, precursor_sigma_ppm, frag_sigma_ppm,
     precursor_prediction_ppm, fragment_prediction_th) = \
        error_calculator.calc_masserror_dist()

    # take the end timestamp once so the reported end time and elapsed time agree
    script_end_time = datetime.now()
    logger.debug("End time: %s. Elapsed time: %s", script_end_time, script_end_time - script_start_time)
    if not failed_precursor:
        logger.debug('precursor ppm standard deviation: %f', precursor_sigma_ppm)
    if not failed_fragment:
        logger.debug('fragment standard deviation (ppm): %f', frag_sigma_ppm)
    logger.debug('')

    # find the charge with the highest spectrum count; stays 0 if no charge has a count above 0
    charge_counts = error_calculator.charge_spectracount_map
    most_common_charge = 0
    n_with_mostcommon_charge = 0
    for charge in charge_counts:
        n_with_charge = charge_counts[charge]
        if n_with_charge > n_with_mostcommon_charge:
            n_with_mostcommon_charge = n_with_charge
            most_common_charge = charge
    print('most common MS/MS scan charge: %d' % most_common_charge)

    if failed_precursor:
        print('ERROR: failed to calculate precursor error:')
        print(precursor_message)
    else:
        print('precursor ppm standard deviation: %f' % precursor_sigma_ppm)
        print("Precursor error estimate (ppm): %.2f" % precursor_prediction_ppm)
    print('')
    if failed_fragment:
        print('ERROR: failed to calculate fragment error:')
        print(fragment_message)
    else:
        print('fragment standard deviation (ppm): %f' % frag_sigma_ppm)
        print("Fragment bin size estimate (Th): %.4f" % fragment_prediction_th)


def generate_spectra(spectra_file):
    """
    a generator for spectra from a .ms2 or .mzML file, optionally gzipped

    The file type is chosen by looking for '.ms2' or '.mzML' anywhere in the file's name.
    If the name ends with '.gz', the file is reopened by name through gzip and that handle
    is read instead of spectra_file; the gzip handle is closed when the generator finishes
    or is closed. spectra_file itself is never closed here.

    :param spectra_file: an open file object with a .name attribute
    :return: yields each spectrum produced by the matching reader module
    :raises ValueError: if the name contains neither '.ms2' nor '.mzML'
    """
    file_name = spectra_file.name

    # Decide on the reader before opening anything, so an unsupported name fails with the
    # ValueError below rather than leaking (or failing to open) a gzip handle first.
    if '.ms2' in file_name:
        io_module = ms2_io
    elif '.mzML' in file_name:
        io_module = mzml_io
    else:
        raise ValueError('generate_spectra, can\'t determine file type from name. Name=%s' % file_name)

    is_gzipped = file_name.endswith('.gz')
    # NOTE(review): gzip.open defaults to binary mode while the argparse-provided handle is
    # text mode -- confirm the reader modules accept both.
    handle = gzip.open(file_name) if is_gzipped else spectra_file
    try:
        for spectrum in io_module.read_ms2_scans(handle):
            yield spectrum
    finally:
        # only close what this function opened; the caller owns spectra_file
        if is_gzipped:
            handle.close()

if __name__ == "__main__":
    # Run only when executed as a script, so that importing this module does not
    # parse command-line arguments or start processing.
    main()
