#!python
# -*- coding: utf-8 -*-
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tf_slim as slim
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import optimize_for_inference_lib

import argparse
import json
import nazoru.core as nazoru
import numpy as np
import os
import sys
import tensorflow.compat.v1 as tf
import zipfile

FLAGS = None

def load_data(kanas):
  """Loads keydown strokes and their one-hot labels from the stroke data.

  The records are read from a newline-delimited JSON file that sits next to
  the zip archive named by FLAGS.stroke_data (same directory, same base name,
  '.ndjson' extension). If that file does not exist yet, the archive is
  extracted into its own directory first.

  A record is kept only if its lower-cased 'kana' value is found in `kanas`
  and it has at least five 'down' events.

  Args:
    kanas: Ordered collection of accepted kana. It must support `in` and
      `.index()`; the index of a kana is used as its class label.

  Returns:
    A tuple (data, labels). `data` is a list with one entry per kept record,
    each entry being a list of (event[0], event[2]) tuples built from the
    record's 'down' events. `labels` is a numpy array of shape
    (len(data), len(kanas)) with the one-hot encoded class of each entry.
  """
  # io.open accepts `encoding` on both Python 2 and Python 3.
  import io

  def get_ndjson_path(zip_path):
    # '<dir>/<name>.zip' -> ('<dir>', '<dir>/<name>.ndjson')
    dir_path, filename = os.path.split(zip_path)
    body, _ = os.path.splitext(filename)
    return dir_path, os.path.join(dir_path, '%s.ndjson' % (body))

  data = []
  labels = []
  dir_path, ndjson_path = get_ndjson_path(FLAGS.stroke_data)
  if not os.path.exists(ndjson_path):
    with zipfile.ZipFile(FLAGS.stroke_data, 'r') as f:
      f.extractall(dir_path)
  # JSON text is UTF-8; do not depend on the platform's default encoding.
  with io.open(ndjson_path, encoding='utf-8') as f:
    for raw_line in f:  # Stream the file instead of loading it all at once.
      if not raw_line.strip(): continue  # Tolerate blank/trailing lines.
      record = json.loads(raw_line)
      kana = record['kana'].lower()
      if kana not in kanas: continue
      keydowns = [(e[0], e[2]) for e in record['events'] if e[1] == 'down']
      # TODO (tushuhei) Relax this condition to accept shorter strokes.
      if len(keydowns) < 5: continue  # Ignore if too short.
      data.append(keydowns)
      labels.append(kanas.index(kana))
  labels = np.eye(len(kanas))[labels]  # Idiom to one-hot encode.
  return data, labels


def convert_data_to_tensors(x, y):
  """Wraps the input and target arrays as TensorFlow constants.

  Args:
    x: Indexable collection of input examples; each example must expose
      `.shape`. Stored as a float32 constant.
    y: Indexable collection of targets; each target must expose `.shape`.
      Stored with the dtype TensorFlow infers for it.

  Returns:
    A tuple (inputs, outputs) of constant tensors. On each of them
    `set_shape` has been called with [None] followed by the shape of the
    first element.
  """

  def as_batched_constant(values, **constant_kwargs):
    tensor = tf.constant(values, **constant_kwargs)
    per_example_shape = list(values[0].shape)
    tensor.set_shape([None] + per_example_shape)
    return tensor

  inputs = as_batched_constant(x, dtype=tf.float32)
  outputs = as_batched_constant(y)
  return inputs, outputs


def save_inference_graph(checkpoint_dir, input_shape, num_classes, conv_defs,
    output_graph, optimized_output_graph,
    input_node_name=nazoru.lib.INPUT_NODE_NAME,
    output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  """Freezes the trained network and writes it out as GraphDef files.

  Builds a fresh inference graph (batch size 1, is_training=False,
  dropout_keep_prob=1), obtains a session from
  tf.train.Supervisor(logdir=checkpoint_dir), converts the variables to
  constants and writes two serialized GraphDefs: the frozen graph and the
  same graph after optimize_for_inference.

  Args:
    checkpoint_dir: Directory handed to tf.train.Supervisor as `logdir`
      (presumably where training stored its checkpoints).
    input_shape: Shape of one input example, without the batch dimension.
    num_classes: Number of output classes of the network.
    conv_defs: Layer definitions forwarded to nazoru.nazorunet.
    output_graph: Path of the frozen (unoptimized) graph file to write.
    optimized_output_graph: Path of the inference-optimized graph file to
      write.
    input_node_name: Name given to the input placeholder.
    output_node_name: Name of the node to keep as the graph output.
  """
  with tf.Graph().as_default():
    inputs = tf.placeholder(
        tf.float32, shape=[1] + list(input_shape), name=input_node_name)
    # Only the graph structure is needed here; the returned tensors are
    # looked up by node name below.
    nazoru.nazorunet(
        inputs, num_classes=num_classes, conv_defs=conv_defs, is_training=False,
        dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE)
    sv = tf.train.Supervisor(logdir=checkpoint_dir)
    with sv.managed_session() as sess:
      frozen_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(), [output_node_name])
      # TODO (tushuhei) Maybe we don't need to export unoptimized graph.
      # GFile replaces the deprecated FastGFile and matches the file API used
      # by export_saved_model_from_pb.
      with gfile.GFile(output_graph, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

      optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
          frozen_graph_def, [input_node_name], [output_node_name],
          dtypes.float32.as_datatype_enum)

      with gfile.GFile(optimized_output_graph, 'wb') as f:
        f.write(optimized_graph_def.SerializeToString())

def export_saved_model_from_pb(saved_model_dir, graph_name,
    input_node_name=nazoru.lib.INPUT_NODE_NAME,
    output_node_name=nazoru.lib.OUTPUT_NODE_NAME):
  """Wraps a serialized GraphDef file into a SavedModel directory.

  Args:
    saved_model_dir: Directory passed to SavedModelBuilder as export target.
    graph_name: Path of the serialized GraphDef (.pb) file to import.
    input_node_name: Name of the input node; its ':0' tensor becomes the
      input of the default serving signature, keyed by this same name.
    output_node_name: Name of the output node; its ':0' tensor becomes the
      output of the default serving signature, keyed by this same name.
  """
  # The builder is created first, before the graph file is touched.
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

  with tf.gfile.GFile(graph_name, 'rb') as pb_file:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(pb_file.read())

  signature_map = {}
  with tf.Session(graph=tf.Graph()) as sess:
    tf.import_graph_def(graph_def, name='')
    graph = tf.get_default_graph()

    signature_inputs = {
        input_node_name: graph.get_tensor_by_name('%s:0' % (input_node_name)),
    }
    signature_outputs = {
        output_node_name: graph.get_tensor_by_name(
            '%s:0' % (output_node_name)),
    }
    serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    signature_map[serving_key] = (
        tf.saved_model.signature_def_utils.predict_signature_def(
            signature_inputs, signature_outputs))

    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signature_map)

  builder.save()


def main(_):
  """Trains NazoruNet on the stroke data and exports the trained model.

  Configuration is read from the module-level FLAGS. The steps are:
    1. Load the strokes, render each one with nazoru.lib.keydowns2image and
       split the result with nazoru.lib.split_data.
    2. Build the training and validation towers (shared variables), the
       losses, the accuracies and their summaries.
    3. Train with slim.learning.train, logging validation accuracy every
       FLAGS.n_steps_to_log steps.
    4. Export the frozen graphs and a SavedModel.

  Args:
    _: Unused; receives the argv list that tf.app.run passes to main.
  """
  kanas = nazoru.lib.KANAS
  keydown_strokes, labels = load_data(kanas)
  x = np.array([
    nazoru.lib.keydowns2image(
      keydown_stroke,
      directional_feature=not FLAGS.no_directional_feature,
      temporal_feature=not FLAGS.no_temporal_feature,
      scale=FLAGS.image_size,
      stroke_width=FLAGS.stroke_width)
    for keydown_stroke in keydown_strokes])
  # test_x and test_t are not used below.
  train_x, train_t, val_x, val_t, test_x, test_t = nazoru.lib.split_data(
      x, labels, 0.2, 0.2)
  conv_defs = [
    nazoru.Conv(kernel=[3, 3], stride=2, depth=32),
    nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=64),
    nazoru.DepthSepConv(kernel=[3, 3], stride=2, depth=128),
    nazoru.DepthSepConv(kernel=[3, 3], stride=1, depth=128),
  ]

  tf.disable_eager_execution()
  with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # From here on the names refer to constant tensors, not numpy arrays.
    train_x, train_t = convert_data_to_tensors(train_x, train_t)
    val_x, val_t = convert_data_to_tensors(val_x, val_t)

    # Make the model. Both towers share variables through AUTO_REUSE.
    train_logits, train_endpoints = nazoru.nazorunet(
        train_x, num_classes=len(kanas), conv_defs=conv_defs,
        dropout_keep_prob=FLAGS.dropout_keep_prob, scope=nazoru.lib.SCOPE,
        reuse=tf.AUTO_REUSE, is_training=True)
    # NOTE(review): the validation tower is also built with is_training=True
    # (only dropout is disabled via dropout_keep_prob=1). Confirm against
    # nazorunet whether is_training=False is intended here.
    val_logits, val_endpoints = nazoru.nazorunet(
        val_x, num_classes=len(kanas), conv_defs=conv_defs,
        dropout_keep_prob=1, scope=nazoru.lib.SCOPE, reuse=tf.AUTO_REUSE,
        is_training=True)

    # Add the loss function to the graph.
    tf.losses.softmax_cross_entropy(train_t, train_logits)
    train_total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('train_total_loss', train_total_loss)

    # The validation loss must come from the validation logits. It is kept
    # out of the LOSSES collection (loss_collection=None) so that it is only
    # reported and never becomes part of the loss being optimized.
    val_loss = tf.losses.softmax_cross_entropy(
        val_t, val_logits, loss_collection=None)
    val_total_loss = val_loss + tf.losses.get_regularization_loss()
    tf.summary.scalar('val_total_loss', val_total_loss)

    train_accuracy = tf.reduce_mean(tf.cast(tf.equal(
      tf.argmax(train_logits, axis=1), tf.argmax(train_t, axis=1)), tf.float32))
    tf.summary.scalar('train_accuracy', train_accuracy)

    val_accuracy = tf.reduce_mean(tf.cast(tf.equal(
      tf.argmax(val_logits, axis=1), tf.argmax(val_t, axis=1)), tf.float32))
    tf.summary.scalar('val_accuracy', val_accuracy)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.005)
    train_op = slim.learning.create_train_op(train_total_loss, optimizer)

    def train_step_fn(sess, *args, **kwargs):
      # Runs the regular slim train step and periodically reports the
      # validation accuracy. The step counter lives on the function object.
      total_loss, should_stop = slim.learning.train_step(sess, *args, **kwargs)
      if train_step_fn.step % FLAGS.n_steps_to_log == 0:
        val_acc = sess.run(val_accuracy)
        tf_logging.info('step: %d, validation accuracy: %.3f',
                        train_step_fn.step, val_acc)
      train_step_fn.step += 1
      return [total_loss, should_stop]
    train_step_fn.step = 0

    # Run the training inside a session.
    final_loss = slim.learning.train(
        train_op,
        logdir=FLAGS.checkpoint_dir,
        number_of_steps=FLAGS.n_train_steps,
        train_step_fn=train_step_fn,
        save_summaries_secs=5,
        log_every_n_steps=FLAGS.n_steps_to_log)
    # %s rather than %f: stays printable whatever train() returned.
    tf_logging.info('final training loss: %s', final_loss)

    save_inference_graph(checkpoint_dir=FLAGS.checkpoint_dir,
        input_shape=train_x.shape[1:], num_classes=len(kanas),
        conv_defs=conv_defs, output_graph=FLAGS.output_graph,
        optimized_output_graph=FLAGS.optimized_output_graph)
    export_saved_model_from_pb(saved_model_dir=FLAGS.saved_model_dir,
        graph_name=FLAGS.optimized_output_graph)

def build_arg_parser():
  """Builds the command line parser of the training script.

  Returns:
    An argparse.ArgumentParser with one positional argument (`stroke_data`)
    and the optional flags that control rendering, training and the output
    locations.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_size',
      type=int,
      default=16,
      help='Size of the square canvas to render input strokes.'
  )
  parser.add_argument(
      '--stroke_width',
      type=int,
      default=2,
      help='Stroke width of rendered strokes.'
  )
  parser.add_argument(
      '--n_train_steps',
      type=int,
      default=500,
      help='Number of training steps to run.'
  )
  parser.add_argument(
      '--n_steps_to_log',
      type=int,
      default=10,
      help='How often to print log during training.'
  )
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default=os.path.join(os.sep, 'tmp', 'nazorunet'),
      help='Where to save checkpoint files.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='nazoru_custom.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--optimized_output_graph',
      type=str,
      default='optimized_nazoru_custom.pb',
      help='Where to save the trained graph optimized for inference.'
  )
  parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='nazoru_saved_model_custom',
      help='Where to save the exported graph.'
  )
  parser.add_argument(
      '--dropout_keep_prob',
      type=float,
      default=0.8,
      help='The fraction (0-1) of activation values that are retained in '
           'dropout.'
  )
  parser.add_argument(
      'stroke_data',
      type=str,
      help='Path to zipped stroke data to input. You can download the '
           'default stroke data at '
           'https://github.com/google/mozc-devices/mozc-nazoru/data/strokes.zip.'
  )
  parser.add_argument(
      '--no_directional_feature',
      action='store_true',
      help='Not to use directional feature.'
  )
  parser.add_argument(
      '--no_temporal_feature',
      action='store_true',
      help='Not to use temporal feature.'
  )
  return parser


if __name__ == '__main__':
  # Flags this script does not know are handed on to tf.app.run untouched.
  FLAGS, unparsed = build_arg_parser().parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
