#!/usr/bin/env python
# tool to automate various AWS commands
import calendar
import datetime as dt
import os
import subprocess
import sys
import time

import pytz

from ncluster import aws_util as u
from ncluster.aws_backend import INSTANCE_INFO

# Module-wide switch read by vprint(): diagnostics are printed only when this is True.
VERBOSE = False


def vprint(*args):
  """Forwards ``args`` to ``print`` when the module-level VERBOSE flag is set, otherwise does nothing."""
  if not VERBOSE:
    return
  print(*args)

def toseconds(dt_):
  """Converts datetime object to seconds since the Unix epoch (UTC).

  Args:
    dt_: a ``datetime.datetime``. Aware values are normalized to UTC by
      ``utctimetuple()``; naive values are taken as already being in UTC.

  Returns:
    Whole seconds since 1970-01-01T00:00:00Z as a float (sub-second precision
    is dropped because the conversion goes through a time tuple).
  """
  # utctimetuple() yields a UTC time tuple, so it must be converted with
  # calendar.timegm (UTC -> epoch). time.mktime would interpret the tuple as
  # local time and shift the result by the machine's UTC offset.
  return float(calendar.timegm(dt_.utctimetuple()))


def ls(fragment=''):
  """Prints a summary table of instances whose name matches ``fragment``.

  Output, in order: a link to the EC2 console for the current region, the
  names of matching stopped instances (these are reported but left out of the
  table), one table row per instance returned by
  ``u.lookup_instances(fragment)`` (iterated in reverse order), and finally
  the instance types of spot requests that are not active, cancelled or closed.

  Args:
    fragment: name fragment forwarded unchanged to ``u.lookup_instances``.
  """
  print(f"https://console.aws.amazon.com/ec2/v2/home?region={u.get_region()}")

  stopped_instances = u.lookup_instances(fragment, valid_states=['stopped'])
  stopped_names = [u.get_name(i) for i in stopped_instances]
  if stopped_names:
    print("ignored stopped instances: ", ", ".join(stopped_names))

  instances = u.lookup_instances(fragment)
  print('-' * 80)
  print(
    f"{'name':15s} {'hours_live':>10s} {'cost_in_$':>10s} {'instance_type':>15s} {'ip_address':>15s} "
    f"{'key/owner':>15s}")
  print('-' * 80)

  # current time in UTC zone (default AWS); taken once so all rows share one reference point
  now_sec = toseconds(dt.datetime.utcnow().replace(tzinfo=pytz.utc))
  # key names look like "<prefix><separator><owner>"; same offset as _user_keypair_check uses
  owner_offset = len(u.get_prefix()) + 1

  for instance in reversed(instances):
    elapsed_hours = (now_sec - toseconds(instance.launch_time)) / 3600
    instance_type = instance.instance_type
    if instance_type in INSTANCE_INFO:
      cost = INSTANCE_INFO[instance_type]['cost'] * elapsed_hours
    else:
      cost = -1  # price unknown for this instance type
    # public_ip_address / key_name can be None (no public IP, launched without
    # a key pair); formatting None with an 's' format spec raises TypeError.
    ip_address = instance.public_ip_address or '-'
    owner = (instance.key_name or '')[owner_offset:]
    print(f"{u.get_name(instance):15s} {elapsed_hours:10.1f} {cost:10.0f} {instance_type[:5]:>15s} "
          f"{ip_address:>15s} {owner:>15s}")

  # list spot requests, ignore active ones since they show up already
  client = u.get_ec2_client()
  ignored_states = ('cancelled', 'closed', 'active')
  # TODO(y) also ignore state == 'fulfilled'?
  spot_requests = [
    request['LaunchSpecification']['InstanceType']
    for request in client.describe_spot_instance_requests()['SpotInstanceRequests']
    if request['State'] not in ignored_states
  ]
  if spot_requests:
    print(f"Pending spot instances: {','.join(spot_requests)}")


def etchosts():
  """Prints ``<ip> <name>`` lines, sorted by instance name, for pasting into /etc/hosts.

  Instances that currently have no public IP address are skipped: they would
  otherwise produce an invalid ``None <name>`` hosts entry.
  """
  instances = u.lookup_instances()
  instance_tuples = []
  for instance in instances:
    ip = instance.public_ip_address
    if ip:
      instance_tuples.append((u.get_name(instance), ip))

  print('-' * 80)
  print("paste following into your /etc/hosts")
  print('-' * 80)
  for name, ip in sorted(instance_tuples):
    print(f"{ip} {name}")


def _user_keypair_check(instance):
  """Verifies that the instance was launched with the current user's key pair.

  The launching user is taken from ``instance.key_name`` by dropping the
  first ``len(u.get_prefix()) + 1`` characters, and compared with ``$USER``.

  Args:
    instance: object exposing a ``key_name`` string attribute.

  Raises:
    AssertionError: if the launching user differs from ``$USER`` (including
      when ``$USER`` is not set).
  """
  launching_user = instance.key_name[len(u.get_prefix())+1:]
  # .get() so an unset $USER yields the actionable message below, not a bare KeyError
  current_user = os.environ.get('USER')
  # Raised explicitly rather than via `assert` so the check still runs under `python -O`.
  if launching_user != current_user:
    raise AssertionError(f"Set USER={launching_user} to connect to this machine")


def ssh(fragment=''):
  """Opens an interactive ssh session to the first instance matching ``fragment``.

  Prints the ssh command line before running it. Does nothing except print a
  message when no instance matches or the chosen instance has no public IP.

  Args:
    fragment: name fragment forwarded unchanged to ``u.lookup_instances``.

  Raises:
    AssertionError: from ``_user_keypair_check`` when the instance was
      launched by a different user.
  """
  instances = u.lookup_instances(fragment)
  if not instances:
    # instances[0] below would otherwise fail with an unhelpful IndexError
    print(f"no instances found matching '{fragment}', quitting")
    return

  instance = instances[0]
  print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)} "
        f"launched by {instance.key_name}")

  _user_keypair_check(instance)
  if not instance.public_ip_address:
    # avoids building a bogus "ubuntu@None" target
    print(f"{u.get_name(instance)} has no public IP address, cannot connect")
    return

  cmd = f"ssh -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no ubuntu@{instance.public_ip_address}"
  print(cmd)
  os.system(cmd)


def kill(fragment='', stop_instead_of_kill=False):
  """Terminates (or stops) every instance matching ``fragment`` after a y/N confirmation.

  Each matching instance is listed first. Nothing is changed unless the user
  answers "y" (case-insensitive) at the prompt.

  Args:
    fragment: name fragment forwarded unchanged to ``u.lookup_instances``.
    stop_instead_of_kill: when true, instances are stopped rather than terminated.

  Raises:
    AssertionError: if ``u.is_good_response`` rejects the EC2 response.
  """
  matches = u.lookup_instances(fragment)
  for inst in matches:
    state_name = inst.state['Name']
    # the state column is only filled in for stopped instances
    state_column = '' if state_name != 'stopped' else state_name
    print(u.get_name(inst), inst.instance_type, inst.key_name, state_column)

  if not matches:
    print("no instances found, quitting")
    return

  if stop_instead_of_kill:
    action = 'stopping'
  else:
    action = 'terminating'

  answer = input(f"{len(matches)} instances found, {action} in {u.get_region()}? (y/N) ")

  ec2_client = u.get_ec2_client()
  if answer.lower() != "y":
    print("Didn't get y, doing nothing")
    return

  instance_ids = [inst.id for inst in matches]
  request = ec2_client.stop_instances if stop_instead_of_kill else ec2_client.terminate_instances
  response = request(InstanceIds=instance_ids)

  assert u.is_good_response(response), response
  print(f"{action}: success")


def start(fragment=''):
  """Starts every stopped instance matching ``fragment`` after a y/N confirmation.

  Lists the candidates, asks for confirmation, calls ``start()`` on each one
  and then prints the EFS mount command as a reminder.

  Args:
    fragment: name fragment forwarded unchanged to ``u.lookup_instances``.
  """
  stopped = u.lookup_instances(fragment, valid_states=['stopped'])
  for inst in stopped:
    print(u.get_name(inst), inst.instance_type, inst.key_name)

  if not stopped:
    print("no stopped instances found, quitting")
    return

  answer = input(f"{len(stopped)} instances found, start in {u.get_region()}? (y/N) ")
  if answer.lower() != "y":
    print("Didn't get y, doing nothing")
    return

  for inst in stopped:
    print(f"starting {u.get_name(inst)}")
    inst.start()

  print("Warning, need to manually mount efs on instance: ")
  print_efs_mount_command()


def mosh(fragment=''):
  """Opens a mosh session to the first instance matching ``fragment``.

  Prints the mosh command line before running it. Does nothing except print a
  message when no instance matches or the chosen instance has no public IP.

  Args:
    fragment: name fragment forwarded unchanged to ``u.lookup_instances``.

  Raises:
    AssertionError: from ``_user_keypair_check`` when the instance was
      launched by a different user.
  """
  instances = u.lookup_instances(fragment)
  if not instances:
    # instances[0] below would otherwise fail with an unhelpful IndexError
    print(f"no instances found matching '{fragment}', quitting")
    return

  instance = instances[0]
  _user_keypair_check(instance)
  print(f"Found {len(instances)} instances matching {fragment}, connecting to most recent  {u.get_name(instance)}")

  if not instance.public_ip_address:
    # avoids building a bogus "ubuntu@None" target
    print(f"{u.get_name(instance)} has no public IP address, cannot connect")
    return

  cmd = f"mosh --ssh='ssh -i {u.get_keypair_fn()} -o StrictHostKeyChecking=no' ubuntu@{instance.public_ip_address}"
  print(cmd)
  os.system(cmd)


def print_efs_mount_command():
  """Prints the two shell commands that mount the current prefix's EFS at /ncluster.

  The file system id is looked up in ``u.get_efs_dict()`` under the key
  ``u.get_prefix()``; a missing entry raises KeyError.
  """
  region = u.get_region()
  efs_by_key = u.get_efs_dict()
  efs_id = efs_by_key[u.get_prefix()]
  dns = f"{efs_id}.efs.{region}.amazonaws.com"
  for command in (
      'sudo mkdir -p /ncluster',
      f"sudo mount -t nfs -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 {dns}:/ /ncluster",
  ):
    print(command)


def efs():
  """Prints the EFS mount command, then every EFS file system with its mount targets.

  For each file system: one line with its id and the name derived from its
  tags, a separator, then one line per mount target (availability zone, IP
  address, mount target id, lifecycle state) or ``<no mount targets>``.

  Raises:
    AssertionError: if ``u.is_good_response`` rejects the file-system or tag
      listing response.
  """
  print_efs_mount_command()
  print()
  print()

  efs_client = u.get_efs_client()
  filesystems_response = efs_client.describe_file_systems()
  assert u.is_good_response(filesystems_response), filesystems_response

  # Fetched once; only used to resolve subnet ids to availability zones below.
  ec2 = u.get_ec2_resource()

  for efs_response in filesystems_response['FileSystems']:
    # Example entry:
    #  {'CreationTime': datetime.datetime(2017, 12, 19, 10, 3, 44, tzinfo=tzlocal()),
    # 'CreationToken': '1513706624330134',
    # 'Encrypted': False,
    # 'FileSystemId': 'fs-0f95ab46',
    # 'LifeCycleState': 'available',
    # 'Name': 'nexus01',
    # 'NumberOfMountTargets': 0,
    # 'OwnerId': '316880547378',
    # 'PerformanceMode': 'generalPurpose',
    # 'SizeInBytes': {'Value': 6144}},
    efs_id = efs_response['FileSystemId']
    tags_response = efs_client.describe_tags(FileSystemId=efs_id)
    assert u.is_good_response(tags_response), tags_response
    key = u.get_name(tags_response.get('Tags', ''))
    print("%-16s %-16s" % (efs_id, key))
    print('-' * 40)

    # list mount points (kept in its own variable so the listing being
    # iterated is never rebound mid-loop)
    mount_targets = efs_client.describe_mount_targets(FileSystemId=efs_id)['MountTargets']
    if not mount_targets:
      print("<no mount targets>")
      continue

    for mount_response in mount_targets:
      subnet = ec2.Subnet(mount_response['SubnetId'])
      zone = subnet.availability_zone
      state = mount_response['LifeCycleState']
      id_ = mount_response['MountTargetId']
      ip = mount_response['IpAddress']
      print('%-16s %-16s %-16s %-16s' % (zone, ip, id_, state,))


def terminate_tmux():
  """
  Script to clean-up tmux sessions

  Kills every session reported by ``tmux ls`` except the ones named
  ``tensorboard``, ``jupyter`` or ``dropbox``. Does nothing (apart from a
  message) when tmux cannot be run or ``tmux ls`` reports failure.
  """
  protected_sessions = ('tensorboard', 'jupyter', 'dropbox')

  def run_command(args):
    """Runs command given as an argument list, returns (returncode, lines) where
    lines are the non-blank stdout lines with surrounding whitespace stripped"""
    try:
      # Argument list instead of a shell string: session names reach tmux
      # verbatim even if they contain spaces or shell metacharacters.
      # stderr is captured separately so error text is never parsed as output.
      proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError as e:
      # e.g. the tmux binary is not installed
      print(f"Could not run {args[0]}: {e}")
      return 1, []
    # errors='replace' so a non-ASCII session name cannot abort the clean-up
    output = proc.stdout.decode('utf-8', errors='replace')
    stripped_lines = [line.strip() for line in output.splitlines()]
    return proc.returncode, [line for line in stripped_lines if line]

  returncode, session_lines = run_command(['tmux', 'ls'])
  if returncode != 0:
    # A failing listing (for instance no tmux server running) has no sessions to offer.
    print("tmux ls failed, no sessions to clean up")
    return

  for line in session_lines:
    # listing lines start with "<session name>:"
    session_name = line.split(':', 1)[0]

    if session_name in protected_sessions:
      print("Skipping "+session_name)
      continue
    print("Killing "+session_name)
    run_command(['tmux', 'kill-session', '-t', session_name])


def main():
  """Command-line entry point: ``<script> [mode] [fragment]``.

  ``mode`` defaults to ``ls`` and ``fragment`` to the empty string. Supported
  modes: ls, ssh, mosh, kill, stop, start, efs, reboot (not implemented),
  /etc/hosts, terminate_tmux. ``fragment`` is only used by the modes that
  select instances (ls, ssh, mosh, kill, stop, start).

  Raises:
    AssertionError: if ``mode`` is not one of the supported modes.
  """
  print(f"Region ({u.get_region()}) $USER ({u.get_username()}) account ({u.get_account_number()})")
  mode = sys.argv[1] if len(sys.argv) > 1 else 'ls'
  fragment = sys.argv[2] if len(sys.argv) > 2 else ''

  # Every handler takes the fragment; modes that do not need it ignore it.
  handlers = {
    'ls': ls,
    'ssh': ssh,
    'mosh': mosh,
    'kill': kill,
    'stop': lambda frag: kill(frag, stop_instead_of_kill=True),
    'start': start,
    'efs': lambda frag: efs(),
    # accepted for compatibility, but tell the user instead of silently doing nothing
    'reboot': lambda frag: print("reboot is not implemented, doing nothing"),
    '/etc/hosts': lambda frag: etchosts(),
    'terminate_tmux': lambda frag: terminate_tmux(),
  }

  handler = handlers.get(mode)
  if handler is None:
    # Raised explicitly rather than via `assert False` so an unknown mode is
    # still reported when Python runs with -O.
    raise AssertionError("Unknown mode " + mode)
  handler(fragment)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
  main()
