#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function

import os

from labellio import Classifier, Config, ImageLoader, Label


def images(image_dir):
    """Yield the path of every file found beneath ``image_dir``.

    The directory tree is traversed with ``os.walk``, so files in nested
    sub-directories are included. No filtering is applied: every file name
    reported by the walk is joined to its directory and yielded.
    """
    walker = os.walk(image_dir)
    for dirpath, _subdirs, filenames in walker:
        for path in [os.path.join(dirpath, name) for name in filenames]:
            yield path


def exec_batch(batch, classifier, label):
    """Classify one batch and print a tab-separated result line per output.

    ``batch`` is a non-empty sequence of ``(path, data)`` pairs. The data
    items are handed to ``classifier.forward_iter`` and, for each output it
    produces, one line is printed containing the matching path, the
    UTF-8 encoded name returned by ``label.label(output.best)`` and
    ``output.values``.
    """
    paths, data = zip(*batch)
    for index, output in enumerate(classifier.forward_iter(data)):
        path = paths[index]
        # NOTE: on Python 3 the encoded name is ``bytes`` and is therefore
        # formatted as its repr (b'...'); on Python 2 it is a plain str.
        name = label.label(output.best).encode("utf-8")
        line = "{0}\t{1}\t{2}".format(path, name, output.values)
        print(line)


def main(args):
    """Classify every file under ``args.image_dir`` and print the results.

    Args:
        args: Parsed command-line namespace providing ``model_dir``,
            ``image_dir`` and ``batch_size``.

    Each file yielded by ``images`` is loaded with the ``ImageLoader`` and
    appended to the current batch. When ``batch_size`` is positive, the
    batch is classified and then emptied as soon as it reaches that size.
    When it is zero or negative, nothing is flushed inside the loop and all
    images are classified together at the end. Any images left in a partial
    batch after the loop are classified in a final call.
    """
    model_dir = args.model_dir
    image_dir = args.image_dir
    batch_size = args.batch_size

    config = Config(model_dir)
    label = Label(config)
    loader = ImageLoader(config)
    classifier = Classifier(config)

    batch = []
    for image in images(image_dir):
        batch.append((image, loader.load(image)))

        if 0 < batch_size <= len(batch):
            exec_batch(batch, classifier, label)
            # Start a fresh batch. Without this reset every already
            # classified image would be re-run (and re-printed) on each
            # subsequent iteration and again in the final flush below.
            batch = []

    if batch:
        exec_batch(batch, classifier, label)


def get_parser():
    """Build the command-line argument parser for this script.

    The parser accepts two positional arguments, ``model_dir`` and
    ``image_dir``, plus an optional integer ``--batch-size`` that defaults
    to 0.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Labellio Classifier")

    cli.add_argument(
        "model_dir",
        help="Path to a model directory which is exported from Labellio. (Please extract the archive before you use it.)"
    )
    cli.add_argument(
        "image_dir",
        help="Path to an image directory. The directory should contain only images."
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=0,
        help="Batch size. (default=0: Retrieve all images at once.)"
    )

    return cli


# Script entry point: parse the command line, then run the classifier.
if __name__ == "__main__":
    cli_args = get_parser().parse_args()
    main(cli_args)
