#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function

import os

from labellio import Classifier, Config, ImageLoader, Label


def images(image_dir):
    """Lazily yield the path of every file found beneath *image_dir*.

    The tree is traversed with ``os.walk``, so files in nested
    sub-directories are included.  No filtering by file type is done here.
    """
    for dirpath, _subdirs, filenames in os.walk(image_dir):
        # Prefix each bare file name with the directory it was found in.
        for full_path in (os.path.join(dirpath, name) for name in filenames):
            yield full_path


def exec_batch(batch, classifier, label):
    """Classify one batch and print a tab-separated line per result.

    Each printed line has the form ``<path>\\t<label>\\t<values>``.

    Args:
        batch: Sequence of ``(path, data)`` pairs.  ``data`` is handed to
            ``classifier.forward_iter`` unchanged.
        classifier: Object exposing ``forward_iter(data)``; every item it
            yields must have ``best`` and ``values`` attributes.  Results
            are matched to paths by position.
        label: Object exposing ``label(index)``, which is expected to
            return the label text for ``output.best``.
    """
    if not batch:
        # zip(*[]) produces nothing, so the two-name unpacking below would
        # raise ValueError; an empty batch simply has nothing to report.
        return

    paths, data = zip(*batch)
    for i, output in enumerate(classifier.forward_iter(data)):
        name = label.label(output.best)
        if not isinstance(name, str):
            if isinstance(name, bytes):
                # Python 3 bytes: decode so the text is printed, not b'...'.
                name = name.decode("utf-8")
            else:
                # Python 2 unicode: encode to a UTF-8 byte string as before.
                name = name.encode("utf-8")
        # When the label is already a native ``str`` it is printed as-is;
        # encoding it on Python 3 would render as ``b'...'`` in the output.
        print("{0}\t{1}\t{2}".format(paths[i], name, output.values))


def main(args):
    """Classify every file under ``args.image_dir`` and print the results.

    Args:
        args: Parsed command-line namespace providing ``model_dir``,
            ``image_dir``, ``batch_size`` and ``mode`` (a dict of keyword
            arguments forwarded to ``Classifier``).

    Files are loaded one at a time and accumulated; whenever a positive
    ``batch_size`` is reached the accumulated batch is classified and
    printed.  With a non-positive ``batch_size`` everything is accumulated
    and classified in a single batch at the end.
    """
    model_dir, image_dir, batch_size = args.model_dir, args.image_dir, args.batch_size

    config = Config(model_dir)
    label = Label(config)
    loader = ImageLoader(config)
    classifier = Classifier(config, **args.mode)

    pending = []
    for path in images(image_dir):
        pending.append((path, loader.load(path)))

        # Keep accumulating while batching is disabled or the batch is
        # not yet full.
        if batch_size <= 0 or len(pending) < batch_size:
            continue

        exec_batch(pending, classifier, label)
        pending = []

    # Flush whatever is left over (the whole input when batching is off).
    if pending:
        exec_batch(pending, classifier, label)


class _InvalidModeError(RuntimeError, ValueError):
    """Raised by ``mode_type`` for an unrecognised mode string.

    It derives from ``RuntimeError`` so existing ``except RuntimeError``
    handlers keep working, and from ``ValueError`` so that argparse, which
    only converts ``ArgumentTypeError``/``TypeError``/``ValueError`` raised
    by a ``type=`` callable into a usage error, reports a clean message
    instead of letting a traceback escape.
    """


def mode_type(select):
    """Translate a mode name into keyword arguments for ``Classifier``.

    Args:
        select: One of ``"auto"``, ``"gpu"`` or ``"cpu"``.

    Returns:
        dict: ``{}`` for ``"auto"``, ``{"gpu": True}`` for ``"gpu"`` and
        ``{"cpu": True}`` for ``"cpu"``.  A new dict is created per call.

    Raises:
        RuntimeError: If *select* is not a recognised mode.  The concrete
            exception is also a ``ValueError`` so it can be used directly
            as an argparse ``type=`` converter.
    """
    if select == "auto":
        return dict()
    elif select == "gpu":
        return dict(gpu=True)
    elif select == "cpu":
        return dict(cpu=True)
    else:
        raise _InvalidModeError("Invalid Argument")


def get_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser accepting the positional arguments
        ``model_dir`` and ``image_dir`` plus the options ``--batch-size``
        (int, default 0) and ``--mode`` (converted by ``mode_type``,
        default is the ``"auto"`` mode).
    """
    # argparse is imported at call time rather than at module import.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Labellio Classifier")

    # (flags, keyword options) for each argument, in registration order.
    specs = [
        (("model_dir",), {
            "help": "Path to a model directory which is exported from Labellio. (Please extract the archive before you use it.)",
        }),
        (("image_dir",), {
            "help": "Path to an image directory. The directory should contain only images.",
        }),
        (("--batch-size",), {
            "type": int,
            "default": 0,
            "help": "Batch size. (default=0: Retrieve all images at once.)",
        }),
        (("--mode",), {
            "type": mode_type,
            "default": mode_type("auto"),
            "help": "Calculation Mode (gpu, cpu, auto, default=auto: gpu auto-detect)",
        }),
    ]
    for flags, options in specs:
        cli.add_argument(*flags, **options)

    return cli


# Script entry point: parse the command line, then run the classifier.
if __name__ == "__main__":
    cli_args = get_parser().parse_args()
    main(cli_args)
