#!python

import os, sys, re, argparse, traceback, warnings, textwrap, ConfigParser
import pandas
from collections import OrderedDict
from framed import FBA, load_cbmodel
from framed.community.smetana import *
from framed.community.model import Community
from framed.experimental.medium import minimal_medium
from framed.model.environment import Environment
from shutil import copyfile

# Human-readable labels for the integer status of a solver solution; main()
# looks these up with `sol.status` when filling the "*_status" columns.
status_codes = {
    0: "Unknown",
    1: "Optimal",
    -1: "Suboptimal",
    -4: "Infeasible_or_Unbounded",
    -3: "Infeasible",
    -2: "Unbounded",
}

def row_string(row):
    """Serialize one result row as a single tab-separated line.

    :param row: mapping of column name -> value; values are emitted in the
        mapping's iteration order (callers pass an OrderedDict so that the
        cells line up with the header).
    :return: the values joined by tabs and terminated by a newline; ``None``
        values become empty cells.
    """
    # `values()` (rather than the Python 2 only `itervalues()`) keeps this
    # helper usable with any mapping and on both Python 2 and Python 3.
    return "\t".join("" if value is None else str(value) for value in row.values()) + "\n"


class Logger(object):
    """Accumulate log text in memory and emit it in one write per flush.

    Messages are buffered so that everything logged for one unit of work is
    written to the chosen stream as a single contiguous chunk.
    """

    def __init__(self):
        # Text collected since the last flush.
        self._log = ""

    def log(self, *args, **kwargs):
        """Append a message to the buffer.

        The first positional argument is the message. When further positional
        arguments are given, the message is treated as a ``str.format``
        template and they are its arguments; a lone message is stored verbatim.
        The keyword ``newline`` (default True) controls whether a line break
        is appended after the message.
        """
        terminate_line = kwargs.pop('newline', True)

        message = args[0]
        if len(args) > 1:
            message = message.format(*args[1:])

        self._log += message
        if terminate_line:
            self._log += "\n"

    def flush(self, error=False):
        """Write the buffer followed by a separator line, then reset the buffer.

        :param error: write to stderr when true, otherwise to stdout.
        """
        self._log += "================================================================="
        if error:
            target = sys.stderr
        else:
            target = sys.stdout
        target.write(self._log)
        target.flush()
        # The separator is written without a trailing line break, so the next
        # chunk starts with one.
        self._log = "\n"

def _config_value(config, section, option, default=None, cast=None):
    """Return `option` from `section` (optionally converted by `cast`), or `default` when it is not set."""
    if not config.has_option(section, option):
        return default
    value = config.get(section, option)
    return cast(value) if cast is not None else value


def _load_communities(table_path, separator, models_dir, models_locations):
    """Read the communities table and resolve the model file path of every community member.

    :param table_path: tab separated file with `id` and `species` columns
    :param separator: regular expression used to split the `species` cell
    :param models_dir: directory the model paths are relative to
    :param models_locations: dict species id -> relative model path (may be empty)
    :return: list of dicts with keys 'id', 'species', 'paths' and 'missing_models'
    """
    communities = []
    for _, r in pandas.read_table(table_path, comment="#").iterrows():
        species = re.split(separator, r['species'].strip())
        paths = []
        missing_models = []
        for m_id in species:
            # Without a locations table the species id itself is used as the file name.
            m_path = os.path.join(models_dir, models_locations.get(m_id, m_id))
            # When a locations table is given, a species absent from it counts as missing.
            unmapped = bool(models_locations) and m_id not in models_locations
            if unmapped or not os.path.exists(m_path):
                missing_models.append(m_id)
            paths.append(m_path)

        communities.append({'id': r['id'], 'species': species, 'paths': paths, 'missing_models': missing_models})
    return communities


def _build_row_template(max_community_size, exch_metabolites_list):
    """Create the ordered column -> default value mapping shared by the header and every result row."""
    row_template = OrderedDict(
        [("community_id", ""), ("size", ""), ("mip", ""), ("mro", ""), ("fba_community", ""), ("mmedia_status", ""),
         ("fba_status", "")])
    organisms = list(range(1, max_community_size + 1))
    for i in organisms:
        row_template["org{}".format(i)] = ""
    for i in organisms:
        row_template["fba.org{}".format(i)] = ""
    for m_id in exch_metabolites_list:
        row_template["mmedia.{}".format(m_id)] = 0
    for i in organisms:
        for m_id in exch_metabolites_list:
            row_template["muscore.org{}.{}".format(i, m_id)] = 0
    for i in organisms:
        for m_id in exch_metabolites_list:
            row_template["mpscore.org{}.{}".format(i, m_id)] = 0
    for i in organisms:
        for j in organisms:
            row_template["scscore.org{}.org{}".format(i, j)] = 0
    return row_template


def main():
    """Command-line entry point: compute SMETANA related scores for a list of communities.

    Reads the configuration file given on the command line, loads the
    community descriptions, and for every community in the selected part
    (`--part` of `--parts-total`) writes one tab-separated row with minimal
    media, MIP/MRO, FBA and SMETANA results to the output file. Progress is
    reported on stdout; skipped communities and errors go to stderr. A failure
    while processing one community is logged and does not stop the run.
    Finally the configuration file is copied next to the output as
    `config.cfg` unless such a file already exists.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description="Run SMETANA on list of organisms communities",
                                     epilog=textwrap.dedent('''\
        configuration file (* = optional):

        [communities]
        communities             Path to tab separated file with community description. Two columns are required,
                                community ID (column:id) and list of community species (column:species, separator
                                <communities_separator>)
        communities_separator*  Separator (regular expression) used to separate community members (default: ",")
        exchange_metabolites    Path to tab separated file with list of all exchange metabolites. Two columns are required,
                                metabolite ID (column:id) and metabolite name (column:name)
        models_dir*             Path to directory where SBML models are located. If not provided current folder is searched.
        models_locations*       Path to tab separated file mapping species ID (column:id) in <communities> to
                                SBML file (column:path) relative <models_dir> path. If not provided model ID is assumed to be
                                SBML file name
        inorganic_media*        Path to tab separated file with list of metabolites always present in the environment.
                                Metabolite ID is the only required column (column:id). If not provided the always-present
                                environment is assumed to be empty
        minimize_media*         After finding minimal media (inorganic always present) perform another optimization to find
                                true media

        [framed]
        flavor*                 SBML flavor to use when parsing the models. Allowed values are 'cobra', 'cobra:other',
                                'seed', 'bigg' and 'fbc2'
        biomass_regex*          Regular expression to match biomass reaction (default: "iomass")
        extracellular_comp_id*  Extracellular compartment id (default: Extracellular)
        smetana_n_solutions*    How many unique solutions to calculate for Metabolite Uptake and Species Coupling scores  (default: 50)
        max_uptake*             How much of uptake/exchange should be allowed for active exchange reactions (default: 1000)

    '''))
    parser.add_argument('config', help='Path to configuration file (see epilog)')
    parser.add_argument('output', nargs="?",
                        help='Path to output file (have to include {i} and {total} placeholders. example: "results/output_{i}_{total}.tsv")')
    parser.add_argument('--part', dest='part', type=int, help='Which part is calculated (start from 1)', default=1)
    parser.add_argument('--parts-total', dest='parts_total', type=int, help='Total number of parts', default=1)
    parser.add_argument('--no-warnings', dest='no_warnings', action="store_true", help="Suppress warnings")
    args = parser.parse_args()

    # `output` is declared with nargs="?" so argparse accepts its absence; fail
    # with a usage message instead of crashing later on a None path.
    if not args.output:
        parser.error("output path is required")
    if args.parts_total < 1 or not 1 <= args.part <= args.parts_total:
        parser.error("--part must be between 1 and --parts-total")

    output_dir = os.path.dirname(args.output)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.no_warnings:
        warnings.simplefilter("ignore", UserWarning)

    #
    # Load configuration
    #
    logger = Logger()
    config = ConfigParser.ConfigParser()
    # read() silently skips files it cannot open and returns the list of files it parsed.
    if not config.read(args.config):
        parser.error("cannot read configuration file '{}'".format(args.config))

    communities_sep = _config_value(config, "communities", "communities_separator", ",", lambda v: v.strip('"'))
    flavor = _config_value(config, "framed", "flavor", "fbc2")
    n_solutions = _config_value(config, "framed", "smetana_n_solutions", 50, int)
    minimize_media = _config_value(config, "communities", "minimize_media", 0, int)
    min_mass_weight = _config_value(config, "framed", "min_mass_weight", 0, int)
    max_uptake = _config_value(config, "framed", "max_uptake", 1000.0, float)
    re_biomass = re.compile(_config_value(config, "framed", "biomass_regex", "iomass"))
    extracellular_comp_id = _config_value(config, "framed", "extracellular_comp_id", "Extracellular")
    models_dir = _config_value(config, "communities", "models_dir", os.curdir)

    models_locations_path = _config_value(config, "communities", "models_locations")
    if models_locations_path:
        models_locations = {r['id']: r['path'] for _, r in pandas.read_table(models_locations_path).iterrows()}
    else:
        models_locations = {}

    exch_metabolites = {r['id']: r['name'] for _, r in
                        pandas.read_table(config.get("communities", "exchange_metabolites")).iterrows()}
    exch_metabolites_list = sorted(exch_metabolites)

    inorganic_media_path = _config_value(config, "communities", "inorganic_media")
    if inorganic_media_path:
        rxn_inorganic = {r['id'].strip() for _, r in pandas.read_table(inorganic_media_path, comment="#").iterrows()}
    else:
        # No inorganic media configured: nothing is always present.
        # NOTE(review): assumes Environment.from_reactions accepts an empty collection - confirm.
        rxn_inorganic = set()
    env_inorganic = Environment.from_reactions(rxn_inorganic, max_uptake=max_uptake)

    #
    # Load communities descriptions
    #
    communities = _load_communities(config.get("communities", "communities"), communities_sep,
                                    models_dir, models_locations)
    max_community_size = max([len(com["species"]) for com in communities] or [0])

    # 1-based position of every community in the full list (first occurrence
    # wins for duplicated ids); only used for progress messages.
    community_positions = {}
    for c_i, c in enumerate(communities, start=1):
        community_positions.setdefault(c['id'], c_i)

    #
    # Prepare template for results table and header
    #
    row_template = _build_row_template(max_community_size, exch_metabolites_list)

    # Part k (1-based) handles every parts_total-th community starting at index k-1.
    communities_sample = communities[args.part - 1::args.parts_total]

    #
    # Main loop
    #
    with open(args.output.format(i=args.part, total=args.parts_total), "w") as out_file:
        # Write header
        missing_communities = [m['id'] for m in communities if m['missing_models']]
        missing_communities_sample = [m['id'] for m in communities_sample if m['missing_models']]
        out_file.write("\t".join(row_template) + "\n")
        logger.log("Total number of communities: {} (sample: {})", len(communities), len(communities_sample))
        logger.log("Missing models: {} (sample: {}): {}", len(missing_communities), len(missing_communities_sample), ", ".join(missing_communities_sample))
        logger.flush()

        for community_i, community_data in enumerate(communities_sample, start=1):
            try:
                row = row_template.copy()
                community_all_i = community_positions[community_data['id']]

                logger.log("{}/{:<5} (all: {}/{:<5}) [{}]: {}", community_i, len(communities_sample), community_all_i,
                           len(communities), community_data['id'], ", ".join(community_data["species"]))

                if community_data['missing_models']:
                    logger.log("Skiping because missing models: {}", ", ".join(community_data['missing_models']))
                    logger.flush(error=True)
                    continue

                #
                # Read SBML files representing community organisms
                #
                models = []
                models_i = {}  # sanitized model id -> 1-based organism column index
                biomass_valid = True
                for i, path in enumerate(community_data['paths'], start=1):
                    logger.log("Reading {}:'{}' SBML file...", i, path)
                    model = load_cbmodel(path, flavor=flavor)
                    model.id = re.sub("[^A-Za-z0-9]", "_", community_data['species'][i-1])
                    # search() rather than match(): the pattern may occur anywhere in the
                    # reaction id (the default "iomass" is not anchored at the start).
                    biomass_candidates = [r_id for r_id in model.reactions if re_biomass.search(r_id)]
                    if len(biomass_candidates) != 1:
                        if biomass_candidates:
                            logger.log("Multiple biomass candidates found: {}", ", ".join(biomass_candidates))
                        else:
                            logger.log("Biomass not found!")
                        biomass_valid = False
                        break

                    model.biomass_reaction = biomass_candidates[0]
                    models.append(model)
                    models_i[model.id] = i

                if not biomass_valid:
                    logger.flush(error=True)
                    continue

                #
                # Print community description
                #
                community = Community(community_data['id'], models, extracellular_compartment_id=extracellular_comp_id,
                                      create_biomass=True, interacting=True)
                com_rxn_inorganic = set(community.merged.get_exchange_reactions()) & rxn_inorganic
                logger.log("Inorganic media ({}):", len(com_rxn_inorganic), newline=False)
                logger.log(", ".join(com_rxn_inorganic))
                # Fill the organism names into this community's row (not into the
                # shared template, which would leak names into later rows).
                for model_i, model_id in enumerate(community_data['species'], start=1):
                    row["org{}".format(model_i)] = model_id
                row['community_id'] = community_data['id']
                row['size'] = len(models)

                #
                # Calculate minimal media for community. Predefined inorganic compounds are always present !
                #
                env_inorganic.apply(community.merged, inplace=True)

                candidates_rxn_minimal = set(community.merged.get_exchange_reactions()) - com_rxn_inorganic
                com_rxn_defined, sol = minimal_medium(community.merged, exchange_reactions=candidates_rxn_minimal,
                                                      validate=True, min_mass_weight=min_mass_weight)
                row["mmedia_status"] = status_codes[sol.status]
                if sol.status != Status.OPTIMAL:
                    logger.log("No defined media available")
                    logger.flush(error=True)
                    out_file.write(row_string(row))
                    out_file.flush()
                    continue

                com_rxn_defined = com_rxn_defined | com_rxn_inorganic
                if minimize_media:
                    # Second pass restricted to the defined + inorganic reactions.
                    com_rxn_minimal, sol = minimal_medium(community.merged, exchange_reactions=com_rxn_defined,
                                                          validate=True, min_mass_weight=min_mass_weight)

                    row["mmedia_status"] = status_codes[sol.status]
                    if sol.status != Status.OPTIMAL:
                        logger.log("No minimal media available")
                        logger.flush(error=True)
                        out_file.write(row_string(row))
                        out_file.flush()
                        continue
                else:
                    com_rxn_minimal = com_rxn_defined

                env_minimal = Environment.from_reactions(com_rxn_minimal, max_uptake=max_uptake)
                logger.log("SMETANA media ({}'inorganic + {}'minimal ({}'defined)):",
                           len(com_rxn_minimal & com_rxn_inorganic),
                           len(com_rxn_minimal - com_rxn_inorganic),
                           len(com_rxn_defined - com_rxn_inorganic),
                           newline=False)

                merged_exchanges = community.merged.get_exchange_reactions()
                for r_i, r in enumerate(com_rxn_minimal, start=1):
                    # Map the merged-model exchange metabolite back to the original
                    # metabolite id used in the "mmedia.*" column names.
                    m = merged_exchanges[r][0]
                    orig_m = list({org_m.original_metabolite
                                   for org_exch in community.organisms_exchange_reactions.values()
                                   for org_m in org_exch.values()
                                   if org_m.extracellular_metabolite == m})
                    if not orig_m:
                        raise KeyError("No original metabolite found for exchange reaction '{}'".format(r))

                    row_id = "mmedia.{}".format(orig_m[0])
                    if row_id not in row:
                        raise KeyError("Column '{}' was not found in row".format(row_id))
                    row[row_id] = 1

                    logger.log("{}, ", r, newline=(r_i == len(com_rxn_minimal)))

                #
                # Calculate MIP and MRO
                #
                mip, mip_extras = mip_score(community, env_inorganic, validate=True)
                row["mip"] = mip
                logger.log("MIP: {}", mip)
                mro, mro_extras = mro_score(community, env_inorganic, validate=True)
                row["mro"] = mro
                logger.log("MRO: {}", mro)

                #
                # Apply minimal+inorganic media
                #
                env_minimal.apply(community.merged, inplace=True)
                community_fba = community.copy(create_biomass=False, interacting=False)
                for r_id in com_rxn_minimal:
                    community_fba.merged.reactions[r_id].lb = -1000
                for r_exchanges in community_fba.organisms_exchange_reactions.values():
                    for r_id in r_exchanges:
                        community_fba.merged.reactions[r_id].lb = -1
                sol = FBA(community_fba.merged, get_values=True)
                row["fba_status"] = status_codes[sol.status]
                if sol.status == Status.OPTIMAL:
                    logger.log("Growth on minimal media (uptakes 1.0/org): {:.1f} ({}) = ",
                               sol.fobj, status_codes[sol.status], newline=False)
                    row["fba_community"] = sol.fobj
                    for m_i, model_id in enumerate(community.organisms, start=1):
                        b = community_fba.organisms_biomass_reactions[model_id]
                        org_i = models_i[model_id]
                        logger.log("{:.1f}'{} + ", sol.values[b], community_data['species'][org_i - 1],
                                   newline=(len(community.organisms) == m_i))
                        row_id = "fba.org{}".format(org_i)
                        if row_id not in row:
                            raise KeyError("Column '{}' was not found in row".format(row_id))
                        row[row_id] = sol.values[b]
                else:
                    # Do not format the objective value here: it is only read for optimal
                    # solutions, and a failed format would lose the whole row.
                    logger.log("Growth on minimal media (uptakes 1.0/org): not available ({})",
                               status_codes[sol.status])

                #
                # Calculate SMETANA
                #
                smetana, smetana_extras = smetana_score(community, env_minimal, n_solutions=n_solutions, min_mass_weight=min_mass_weight)
                logger.log("Smetana (sum={}): ", sum(s.score for s in smetana))
                if smetana:
                    logger.log(", ".join(str(s) for s in smetana))
                    for m_receiver, dependants in smetana_extras["species_coupling"]["scores"].items():
                        for m_donor, scscore in dependants.items():
                            row_id = "scscore.org{}.org{}".format(models_i[m_receiver], models_i[m_donor])
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = scscore

                    for model_id, metabolites in smetana_extras["metabolite_uptake"]["scores"].items():
                        for m_id, muscore in metabolites.items():
                            row_id = "muscore.org{}.{}".format(models_i[model_id], m_id)
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = muscore

                    for model_id, metabolites in smetana_extras["metabolite_production"]["scores"].items():
                        for m_id in metabolites:
                            row_id = "mpscore.org{}.{}".format(models_i[model_id], m_id)
                            if row_id not in row:
                                raise KeyError("Column '{}' was not found in row".format(row_id))
                            row[row_id] = 1

                out_file.write(row_string(row))
                out_file.flush()
                logger.flush(error=False)
            except Exception:
                # Best effort: report the failure for this community and go on with the
                # next one. KeyboardInterrupt/SystemExit are deliberately not caught.
                logger.log("-------------------- START_ERROR --------------------")
                logger.log(traceback.format_exc())
                logger.log("-------------------- END_ERROR --------------------")
                logger.flush(error=True)

    # Keep a copy of the configuration next to the results (first run only).
    out_config = os.path.join(os.path.dirname(args.output), "config.cfg")
    if not os.path.exists(out_config):
        copyfile(args.config, out_config)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()