#!/usr/bin/env python3

import argparse
import sys, os

import hyperchamber as hc
import hypergan as hg
import hypergan.cli as cli
from hypergan.gans.aligned_gan import AlignedGAN
from hypergan.gans.autoencoder_gan import AutoencoderGAN
from hypergan.gans.alpha_gan import AlphaGAN

class CommandParser:
    """Builds the argparse parser for the `train`, `sample`, `build` and `new` commands."""

    @staticmethod
    def _str_to_bool(value):
        """Convert a command line string such as "yes"/"no" into a bool.

        `type=bool` must not be used with argparse: it calls `bool("False")`,
        which is True because the string is non-empty.

        Raises:
            argparse.ArgumentTypeError: if `value` is not a recognised
                true/false spelling.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        # The empty string stays falsy, as it was under `type=bool`.
        if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value (yes/no, true/false, 1/0), got %r' % (value,))

    def common(self, parser):
        """Add the positional `directory` argument plus every shared flag to `parser`."""
        parser.add_argument('directory', action='store', type=str, help='The location of your data.  Subdirectories are treated as different classes.  You must have at least 1 subdirectory.')
        self.common_flags(parser)

    def common_flags(self, parser):
        """Add the optional flags shared by the top-level parser and every sub-command."""
        parser.add_argument('--size', '-s', type=str, default='64x64x3', help='Size of your data.  For images it is widthxheightxchannels.')
        parser.add_argument('--batch_size', '-b', type=int, default=32, help='Number of samples to include in each batch.  If using batch norm, this needs to be preserved when in server mode')
        parser.add_argument('--config', '-c', action='store', default=None, type=str, help='The configuration file to load.')
        parser.add_argument('--device', '-d', type=str, default='/gpu:0', help='In the form "/gpu:0", "/cpu:0", etc.  Always use a GPU (or TPU) to train')
        parser.add_argument('--format', '-f', type=str, default='png', help='jpg or png')
        parser.add_argument('--crop', dest='crop', action='store_true', help='If your images are perfectly sized you can skip cropping.')
        parser.add_argument('--resize', dest='resize', action='store_true', help='If your images are perfectly sized you can skip resize.')
        parser.add_argument('--align', dest='align', help='Align classes.  Takes a directory of data to align with.')
        # Parsed with _str_to_bool so that "no"/"False" really disable the flag.
        parser.add_argument('--use_hc_io', type=self._str_to_bool, default=False, help='Set this to no unless you are feeling experimental.')
        parser.add_argument('--save_every', type=int, default=10000, help='Saves the model every n steps.')
        parser.add_argument('--sample_every', type=int, default=5, help='Saves a sample ever X steps.')
        parser.add_argument('--save_samples', action='store_true', help='Saves samples to the local `samples` directory.')
        parser.add_argument('--sampler', type=str, default='static_batch', help='Select a sampler.  Some choices: static_batch, batch, grid, progressive')
        parser.add_argument('--ipython', type=self._str_to_bool, default=False, help='Enables iPython embedded mode.')
        parser.add_argument('--steps', type=int, default=-1, help='Number of steps to train for.  -1 is unlimited (default)')
        parser.add_argument('--noviewer', dest='viewer', action='store_false', help='Disables the display of samples in a window.')
        # NOTE(review): this help text duplicates --noviewer; confirm what --autoencode is meant to do and reword.
        parser.add_argument('--autoencode', dest='autoencode', action='store', help='Disables the display of samples in a window.')
        parser.add_argument('--classloss', dest='classloss', action='store_true', help='Enable class loss.  You must have multiple subfolders, one for each class')
        parser.add_argument('--list-templates', '-l', dest='list_templates', action='store_true', help='List available templates.')

    def get_parser(self):
        """Return the fully configured parser.

        The chosen sub-command is stored in `method` and is mandatory.  The
        shared flags are registered on the top-level parser as well as on
        each sub-command, so they are accepted on either side of the
        sub-command name.
        """
        parser = argparse.ArgumentParser(description='Train, run, and deploy your GANs.', add_help=True)
        subparsers = parser.add_subparsers(dest='method')
        train_parser = subparsers.add_parser('train')
        sample_parser = subparsers.add_parser('sample')
        build_parser = subparsers.add_parser('build')
        new_parser = subparsers.add_parser('new')
        subparsers.required = True
        self.common_flags(parser)
        self.common(sample_parser)
        self.common(train_parser)
        self.common(build_parser)
        self.common(new_parser)

        return parser

# Parse the command line once; everything below is driven by these values.
args = CommandParser().get_parser().parse_args()

# Listing templates needs neither the data size nor a loaded configuration,
# so it is handled first and the script stops afterwards.
if args.list_templates:
    for template_name in hg.Configuration.list():
        template = hg.Configuration.load(template_name+'.json', verbose=False)
        print("%-30s - %-60s" %  (template_name, str(template.description)+"("+str(template.publication)+")"))
    # sys.exit rather than the `exit` helper, which is only installed by the site module.
    sys.exit(0)

# --size is given as "widthxheightxchannels".  Padding with None lets any
# missing trailing component fall back to its default below.
if args.size:
    size = [int(x) for x in args.size.split("x")] + [None, None, None]
else:
    size = [None, None, None]
width = size[0] or 64
height = size[1] or 64
channels = size[2] or 3

# Resolve the configuration name (default: 'default') to a file and load it.
config = args.config or 'default'
config_filename = hg.Configuration.find(config+'.json')
config = hc.Selector().load(config_filename)
class AlignedInput():
    """Pairs two loaders for AlignedGAN.

    `x`, `xa` and `y` come from the first loader; `xb` comes from the second.
    """
    def __init__(self, input1, input2):
        self.x = input1.x
        self.xa = input1.x
        self.xb = input2.x
        self.y = input1.y


def _create_image_loader(directory):
    """Create an ImageLoader for `directory` using the parsed command line options."""
    loader = hg.inputs.image_loader.ImageLoader(args.batch_size)
    loader.create(directory,
                  channels=channels,
                  format=args.format,
                  crop=args.crop,
                  width=width,
                  height=height,
                  resize=args.resize)
    return loader


if args.align:
    # Aligned mode: one loader for the main directory, one for the --align directory.
    inputs = _create_image_loader(args.directory)
    inputs2 = _create_image_loader(args.align)
    aligned_input = AlignedInput(inputs, inputs2)
    gan = AlignedGAN(config=config, inputs=aligned_input)
elif args.method == 'new':
    # The `new` command builds the GAN without any input data.
    gan = hg.GAN(config=config, inputs=None)
else:
    inputs = _create_image_loader(args.directory)
    gan = hg.GAN(config=config, inputs=inputs)


# Hand the GAN and the raw option dictionary to the CLI driver and run it.
gancli = cli.CLI(gan, args=vars(args))
gancli.run()
