#!python
import argparse

import spatialprofilingtoolbox as spt
from spatialprofilingtoolbox.dataset_designs.multiplexed_imaging.halo_cell_metadata_provider import HALOCellMetadata
from spatialprofilingtoolbox.dataset_designs.multiplexed_imaging.halo_cell_metadata_design import HALOCellMetadataDesign

def do_aggregation() -> None:
    """Parse command-line options and write a subsampled aggregate of cell data.

    Steps, in order:

    1. Parse ``--max-per-sample`` (int, default 100) and ``--omit-column``
       (str, default ``None``) from the command line.
    2. Load the configuration via ``spt.get_config_parameters_from_file()``.
    3. Build a ``HALOCellMetadataDesign`` from the configured
       ``elementary_phenotypes_file``, and a ``HALOCellMetadata`` from the
       configured ``input_path`` and ``file_manifest_file``.
    4. Initialize the cell metadata object and call its ``write_subsampled``
       method with the parsed options and the configured ``outcomes_file``.

    The configuration keys read are ``elementary_phenotypes_file``,
    ``input_path``, ``file_manifest_file`` and ``outcomes_file``.
    """
    parser = argparse.ArgumentParser(
        # Join with a single space: the fragments carry no leading/trailing
        # whitespace, so an empty separator would render "...runinto..." in
        # the --help output.
        description=' '.join([
            'This script aggregates all input cell data from an spt-pipeline run',
            'into a single batch, and does subsampling.',
        ]),
    )
    parser.add_argument(
        '--max-per-sample',
        dest='max_per_sample',
        type=int,
        required=False,
        default=100,
        help='The maximum number of cells to draw from each sample/image.',
    )
    parser.add_argument(
        '--omit-column',
        dest='omit_column',
        type=str,
        required=False,
        default=None,
        help='A data column to omit ',
    )
    args = parser.parse_args()

    parameters = spt.get_config_parameters_from_file()
    dataset_design = HALOCellMetadataDesign(
        elementary_phenotypes_file=parameters['elementary_phenotypes_file'],
    )
    cell_data = HALOCellMetadata(
        input_files_path=parameters['input_path'],
        dataset_design=dataset_design,
        file_manifest_file=parameters['file_manifest_file'],
    )
    # initialize() is called before write_subsampled(), as in the original
    # ordering; presumably it loads the data that is then subsampled.
    cell_data.initialize()
    cell_data.write_subsampled(
        max_per_sample=args.max_per_sample,
        outcomes_file=parameters['outcomes_file'],
        omit_column=args.omit_column,
    )

# Script entry point: run the aggregation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    do_aggregation()
