#!/usr/bin/env python3
from __future__ import print_function
import subprocess
import argparse
import botocore
import getpass
import shutil
import random
import socket
import base64
import errno
import boto3
import time
import json
import sys
import os
import re

# The interactive prompts in this file call raw_input(); alias it to the
# Python 3 built-in input().
raw_input = input
# True only when stdout looks like an interactive terminal; print_() consults
# this before emitting any ANSI escape sequence.
supports_color = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
# ANSI escape sequences keyed by the style names accepted by print_().
STYLES = {
    'error': '\033[1;31m',
    'success': '\033[1;37m',
    'dbname': '\033[1;37m',
    'blue': '\033[1;34m',
    'green': '\033[1;32m',
    'red': '\033[1;31m',
    'normal': '\033[0m',
    'bold': '\033[1m'
}


def print_(msg, file=sys.stderr, style=None, end="\n"):
    """Print ``msg`` to ``file`` (stderr by default), optionally colored.

    The message is wrapped in the escape sequence registered under ``style``
    in STYLES, followed by a reset sequence, but only when color output is
    supported and ``style`` is a known style name. Otherwise the message is
    printed unchanged.
    """
    colored = supports_color and style in STYLES
    if not colored:
        print(msg, file=file, end=end)
        return

    # Two positional arguments: print() puts its default separator between
    # the styled message and the reset sequence.
    print(STYLES[style] + msg, '\033[0m', file=file, end=end)


# Configuration state shared by get_config() and maybe_save_config().
_config = None  # contents of ~/.getbox as a dict; None until the first get_config() call
_config_updated = False  # set to True once a value has been entered interactively
_args_config = None  # key=value overrides parsed from sys.argv by get_argv_config_dict()


def get_argv_config_dict():
    """Collect ``key=value`` configuration overrides from ``sys.argv``.

    Every command-line item containing ``=`` is split at the *first* ``=``
    only, so values may themselves contain ``=`` (for example
    ``ssh_flags=-o StrictHostKeyChecking=no``). Items without ``=`` are
    ignored.

    Returns:
        dict mapping each key to its (string) value.
    """
    overrides = {}
    for item in sys.argv:
        key, separator, value = item.partition('=')
        if separator:
            overrides[key] = value
    return overrides


def get_config(k, defaultval, fromstr=str, required=False, guess=lambda: ''):
    """Look up configuration value ``k``.

    Lookup order: a ``k=value`` command-line override (converted with
    ``fromstr``), then the ``~/.getbox`` JSON file, then ``defaultval``.
    The file and the command line are parsed on the first call only.

    When ``required`` is set and the value found is falsy, the user is
    prompted for it. ``guess()`` supplies a suggestion that is used when the
    user just presses enter. The answer is converted with ``fromstr``, stored
    in the in-memory config and the config is flagged as updated. An empty
    result terminates the program with exit status 1.
    """
    global _config, _args_config, _config_updated

    if _config is None:
        path = os.path.expanduser("~/.getbox")
        if not os.path.exists(path):
            _config = {}
        else:
            with open(path) as fh:
                _config = json.loads(fh.read())

        _args_config = get_argv_config_dict()

    if k in _args_config:
        value = fromstr(_args_config[k])
    else:
        value = _config.get(k, defaultval)

    if not required or value:
        return value

    # Required but missing: ask interactively, offering the guess as default.
    suggestion = guess()
    answer = raw_input('Required config parameter %s not set. Please input %s [%s]> ' % (k, k, suggestion))
    accept_suggestion = suggestion and not answer.strip()
    value = fromstr(suggestion) if accept_suggestion else fromstr(answer)

    if not value:
        print_("%s is required, exiting" % k)
        sys.exit(1)

    _config[k] = value
    _config_updated = True
    return value


def guess_subnets():
    """Suggest a value for the ``subnets`` config parameter.

    Returns up to three ``zone:subnet-id`` pairs joined by spaces (the format
    subnets_fromstr() parses), or ``''`` when the EC2 call fails with a
    ClientError.
    """
    ec2 = boto3.client('ec2')
    try:
        response = ec2.describe_subnets()
        first_three = response['Subnets'][:3]
        pairs = [':'.join((sn['AvailabilityZone'], sn['SubnetId'])) for sn in first_three]
        return ' '.join(pairs)
    except botocore.exceptions.ClientError:
        return ''


def guess_keyname():
    """Suggest a value for the ``ssh_key_name`` config parameter.

    Returns the name of the first EC2 key pair reported by the API, or ``''``
    when the call fails with a ClientError or the account has no key pairs
    (the latter used to raise IndexError).
    """
    client = boto3.client('ec2')
    try:
        key_pairs = client.describe_key_pairs().get('KeyPairs', [])
    except botocore.exceptions.ClientError:
        return ''

    if not key_pairs:
        return ''
    return key_pairs[0]['KeyName']


def guess_ami():
    """Suggest a value for the ``ami_id`` config parameter.

    Queries images owned by account 099720109477 (x86_64, hvm, ebs root,
    simple SR-IOV) and returns the id of the most recently created one whose
    name starts with the Ubuntu bionic 18.04 server prefix.

    Returns ``''`` when the API call fails with a ClientError or no image
    matches, in line with the other guess_* helpers, so that get_config()
    falls back to prompting without a suggestion instead of crashing.
    """
    client = boto3.client('ec2')
    try:
        images = client.describe_images(Owners=['099720109477'],
                                        Filters=[{'Name': 'architecture', 'Values': ['x86_64']},
                                                 {'Name': 'virtualization-type', 'Values': ['hvm']},
                                                 {'Name': 'root-device-type', 'Values': ['ebs']},
                                                 {'Name': 'sriov-net-support', 'Values': ['simple']}])['Images']
    except botocore.exceptions.ClientError:
        return ''

    prefix = 'ubuntu/images/hvm-ssd/ubuntu-bionic-18.04-amd64-server-'
    candidates = [i for i in images if i.get('Name', '').startswith(prefix)]
    if not candidates:
        # max() on an empty sequence would raise ValueError.
        return ''
    return max(candidates, key=lambda x: x['CreationDate'])['ImageId']


def get_latest_dlami_id():
    """Return the id of the newest matching image owned by account 898082745236.

    Only x86_64 / hvm / ebs-root images whose name contains ``Ubuntu`` are
    considered; "newest" is decided by the ``CreationDate`` field.
    """
    ec2 = boto3.client('ec2')
    image_filters = [
        {'Name': 'architecture', 'Values': ['x86_64']},
        {'Name': 'virtualization-type', 'Values': ['hvm']},
        {'Name': 'root-device-type', 'Values': ['ebs']},
    ]
    response = ec2.describe_images(Owners=['898082745236'], Filters=image_filters)

    def created(image):
        return image['CreationDate']

    ubuntu_images = [image for image in response['Images'] if 'Ubuntu' in image['Name']]
    newest = max(ubuntu_images, key=created)
    return newest['ImageId']


def subnets_fromstr(s):
    """Parse ``"zone:subnet zone:subnet ..."`` into a ``{zone: subnet}`` dict.

    Items are separated by whitespace; each item must contain exactly one
    ``:``, otherwise ValueError is raised. A zone listed twice keeps its last
    subnet.
    """
    pairs = (entry.split(':') for entry in s.split())
    return {zone: subnet for zone, subnet in pairs}


# Module-level configuration, resolved once at import time through
# get_config(): a key=value command-line override wins over ~/.getbox, which
# wins over the default given here. Parameters marked required=True are
# prompted for interactively when missing.
CONFIG_EXTRA_TAGS = get_config('extra_tags', {})
CONFIG_OWNER_TAG_NAME = get_config('owner_tag_name', 'owner')
CONFIG_SSH_USER = get_config('ssh_user', 'ubuntu')
# fromstr=int keeps a command-line override numeric like the default;
# wait_for_ssh() compares this value against elapsed seconds, and comparing a
# float with a str raises TypeError.
CONFIG_SSH_WAIT_TIMEOUT_SECONDS = get_config('ssh_wait_timeout_seconds', 600, fromstr=int)
CONFIG_ROOT_SIZE_GB = get_config('root_size_gb', 80, fromstr=int)
CONFIG_AMI_ID = get_config('ami_id', '', required=True, guess=guess_ami)
CONFIG_CUDA_AMI_ID = get_config('cuda_ami_id', None)
CONFIG_SSH_KEY_NAME = get_config('ssh_key_name', '', required=True, guess=guess_keyname)
CONFIG_IAM_ROLE_ARN = get_config('iam_role_arn', '', required=True)
CONFIG_SSH_FLAGS = get_config('ssh_flags', '-A')
CONFIG_SUBNETS = get_config('subnets', {}, required=True, guess=guess_subnets, fromstr=subnets_fromstr)
CONFIG_SECURITY_GROUPS = get_config('security_groups', [])
CONFIG_SPOT_PRICE = get_config('spot_price', None)
CONFIG_EBS_IOPS = get_config('ebs_iops', None, int)
CONFIG_EBS_VOLTYPE = get_config('ebs_volume_type', 'gp2')
CONFIG_EXTRA_INIT_COMMANDS = get_config('extra_init_commands', [])


def maybe_save_config():
    """Persist the in-memory configuration to ``~/.getbox`` when appropriate.

    If the file does not exist yet it is written unconditionally. If it
    exists and a value was entered interactively, the user is asked before it
    is overwritten. Otherwise nothing happens.
    """
    path = os.path.expanduser("~/.getbox")

    def write_config():
        with open(path, 'w') as fh:
            json.dump(_config, fh)

    if not os.path.exists(path):
        print_("Saving config to %s" % path)
        write_config()
        return

    if not _config_updated:
        return

    print_("Overwrite config at %s? [y/N] " % path, end='')
    if raw_input('') in ('Y', 'y'):
        write_config()


maybe_save_config()

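# Ephemeral (instance store) storage per instance type, consumed by
# all_instance_types() and get_ephemeral_info(). Entries are separated by ';'
# and have the form "<instance type> <total size>/<number of drives>", or
# "<instance type> 0" for types without ephemeral drives. Line breaks inside
# the table come from `fold -s` and carry no meaning.
#
# Generated with
#   curl  "https://instaguide.io/data/instances/regionCode=us-west-2&tenancyCode=shared&platform=linux.json" | jq -r '.rows|.[]|[.instanceType,.storage]|@csv' | sort | perl -ne 's/"//g; @f = split ","; if ($f[1] =~ /(\d+) x (\d+)/) { print $f[0]." ".$1*$2."/".$1.";" } else { print $f[0]." 0;"}' | fold -s  ; echo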
ephemeral_info = """
a1.2xlarge 0;a1.4xlarge 0;a1.large 0;a1.medium 0;a1.xlarge 0;c1.medium
350/1;c1.xlarge 1680/4;c3.2xlarge 160/2;c3.4xlarge 320/2;c3.8xlarge
640/2;c3.large 32/2;c3.xlarge 80/2;c4.2xlarge 0;c4.4xlarge 0;c4.8xlarge
0;c4.large 0;c4.xlarge 0;c5.18xlarge 0;c5.2xlarge 0;c5.4xlarge 0;c5.9xlarge
0;c5d.18xlarge 1800/1;c5d.2xlarge 200/1;c5d.4xlarge 400/1;c5d.9xlarge
900/1;c5d.large 50/1;c5d.xlarge 100/1;c5.large 0;c5n.18xlarge 0;c5n.2xlarge
0;c5n.4xlarge 0;c5n.9xlarge 0;c5n.large 0;c5n.xlarge 0;c5.xlarge 0;cc2.8xlarge
3360/4;cr1.8xlarge 240/2;d2.2xlarge 12000/6;d2.4xlarge 24000/12;d2.8xlarge
48000/24;d2.xlarge 6000/3;f1.16xlarge 3760/4;f1.2xlarge 470/1;f1.4xlarge
940/1;g2.2xlarge 60/1;g2.8xlarge 240/2;g3.16xlarge 0;g3.4xlarge 0;g3.8xlarge
0;g3s.xlarge 0;h1.16xlarge 16000/8;h1.2xlarge 2000/1;h1.4xlarge
4000/2;h1.8xlarge 8000/4;hs1.8xlarge 48000/24;i2.2xlarge 1600/2;i2.4xlarge
3200/4;i2.8xlarge 6400/8;i2.xlarge 800/1;i3.16xlarge 15200/8;i3.2xlarge
1900/1;i3.4xlarge 3800/2;i3.8xlarge 7600/4;i3.large 475/1;i3.metal
15200/8;i3.xlarge 950/1;m1.large 840/2;m1.medium 410/1;m1.xlarge
1680/4;m2.2xlarge 850/1;m2.4xlarge 1680/2;m2.xlarge 420/1;m3.2xlarge
160/2;m3.large 32/1;m3.medium 4/1;m3.xlarge 80/2;m4.10xlarge 0;m4.16xlarge
0;m4.2xlarge 0;m4.4xlarge 0;m4.large 0;m4.xlarge 0;m5.12xlarge 0;m5.24xlarge
0;m5.2xlarge 0;m5.4xlarge 0;m5a.12xlarge 0;m5a.24xlarge 0;m5a.2xlarge
0;m5a.4xlarge 0;m5a.large 0;m5a.xlarge 0;m5d.12xlarge 1800/2;m5d.24xlarge
3600/4;m5d.2xlarge 300/1;m5d.4xlarge 600/2;m5d.large 75/1;m5d.xlarge
150/1;m5.large 0;m5.xlarge 0;p2.16xlarge 0;p2.8xlarge 0;p2.xlarge 0;p3.16xlarge
0;p3.2xlarge 0;p3.8xlarge 0;p3dn.24xlarge 1800/2;r3.2xlarge 160/1;r3.4xlarge
320/1;r3.8xlarge 640/2;r3.large 32/1;r3.xlarge 80/1;r4.16xlarge 0;r4.2xlarge
0;r4.4xlarge 0;r4.8xlarge 0;r4.large 0;r4.xlarge 0;r5.12xlarge 0;r5.24xlarge
0;r5.2xlarge 0;r5.4xlarge 0;r5a.12xlarge 0;r5a.24xlarge 0;r5a.2xlarge
0;r5a.4xlarge 0;r5a.large 0;r5a.xlarge 0;r5d.12xlarge 1800/2;r5d.24xlarge
3600/4;r5d.2xlarge 300/1;r5d.4xlarge 600/2;r5d.large 75/1;r5d.xlarge
150/1;r5.large 0;r5.xlarge 0;t1.micro 0;t2.2xlarge 0;t2.large 0;t2.medium
0;t2.micro 0;t2.nano 0;t2.small 0;t2.xlarge 0;t3.2xlarge 0;t3.large 0;t3.medium
0;t3.micro 0;t3.nano 0;t3.small 0;t3.xlarge 0;x1.16xlarge 1920/1;x1.32xlarge
3840/2;x1e.16xlarge 1920/1;x1e.2xlarge 240/1;x1e.32xlarge 3840/2;x1e.4xlarge
480/1;x1e.8xlarge 960/1;x1e.xlarge 120/1;z1d.12xlarge 1800/2;z1d.2xlarge
300/1;z1d.3xlarge 450/1;z1d.6xlarge 900/1;z1d.large 75/1;z1d.xlarge 150/1
"""


def get_ami_id(instance_type):
    """Choose the AMI to launch for ``instance_type``.

    GPU families (g2, g3, p2, p3) use CONFIG_CUDA_AMI_ID when it is set;
    every other case uses CONFIG_AMI_ID. Previously a GPU type with no
    ``cuda_ami_id`` configured returned None, which is not a usable image id;
    it now falls back to CONFIG_AMI_ID.

    The special values ``dlami`` / ``latest_dlami`` (case-insensitive) are
    resolved through get_latest_dlami_id().
    """
    family = instance_type.split('.')[0]

    ami = CONFIG_AMI_ID
    if family in ('g2', 'g3', 'p2', 'p3') and CONFIG_CUDA_AMI_ID:
        ami = CONFIG_CUDA_AMI_ID

    if ami.lower() in ('dlami', 'latest_dlami'):
        return get_latest_dlami_id()
    return ami


def all_instance_types():
    """Return the instance type names listed in ``ephemeral_info``.

    The result is a dict keys view, in table order.
    """
    table = {}
    for entry in ephemeral_info.strip().split(";"):
        type_name, storage = entry.strip().split()
        table[type_name] = storage
    return table.keys()


def get_ephemeral_info(type_):
    """Return ``(total_size, drive_count)`` of ephemeral storage for a type.

    Values come from the ``ephemeral_info`` table: ``"0"`` means no ephemeral
    drives and yields ``(0, 0)``; a bare number is treated as a single drive;
    ``"size/count"`` yields both numbers.

    The result is always a tuple of two ints (one branch used to return a
    lazy ``map`` object).

    Raises:
        KeyError: ``type_`` is not in the table (callers rely on this to
            report unknown instance types).
    """
    info = dict(entry.split() for entry in ephemeral_info.split(";"))
    spec = info[type_]

    if spec == '0':
        return 0, 0
    if '/' not in spec:
        return int(spec), 1

    size, count = spec.split('/')
    return int(size), int(count)


# Bash prologue shared by every init script below. It appends a setting that
# turns off APT unattended upgrades (when the periodic config file exists),
# asks the EC2 metadata service for the instance's ephemeral block devices
# and, if any device node is found, initialises them as LVM physical volumes
# in volume group VolGroup01 with one logical volume "ephemeralvol" striped
# across all of them. Docker is stopped first when its init script exists.
INIT_EPHEMERALS = """#!/bin/bash
set -e

yum install -y mdadm || true

test -e /etc/apt/apt.conf.d/10periodic && echo 'APT::Periodic::Unattended-Upgrade "0";' >>/etc/apt/apt.conf.d/10periodic

umount /media/ephemeral{0..29} 2>&1 || true

EPHEMERALS=`curl http://169.254.169.254/latest/meta-data/block-device-mapping/ | grep ephemeral || true`

EPHEMERAL_DEVICES=''
for d in $EPHEMERALS ; do
    path=`curl -w "\n" http://169.254.169.254/latest/meta-data/block-device-mapping/$d 2>/dev/null | sed -e 's|^sd|/dev/xvd|'`
    if [ -e "$path" ]; then
        EPHEMERAL_DEVICES+=" $path"
    else
        path=$(echo $d | sed -e 's|^ephemeral|/dev/nvme|')
        if [ -e "${path}n1" ]; then
            EPHEMERAL_DEVICES+=" ${path}n1"
        fi
    fi
done


if [ "$EPHEMERAL_DEVICES" != "" ]; then

    while true; do
        if [ -e /etc/init.d/docker ] ; then
            /etc/init.d/docker stop  || true
            killall docker-containerd || true
        fi

        if pvcreate -y --dataalignment 1m $EPHEMERAL_DEVICES ; then
            break
        else
            sleep 1
        fi
    done

    vgcreate VolGroup01 $EPHEMERAL_DEVICES
    lvcreate -i $(echo ${EPHEMERAL_DEVICES} | wc -w) -l +100%FREE VolGroup01 -n ephemeralvol
fi
"""

# INIT_EPHEMERALS plus LVM setup of the device /dev/xvdg: physical volume,
# volume group VolGroup00 and a logical volume "datavol" spanning all of it.
BASIC_INIT = INIT_EPHEMERALS + """
pvcreate -y --dataalignment 1m /dev/xvdg
vgcreate VolGroup00 /dev/xvdg
lvcreate -i 1 -l +100%FREE VolGroup00 -n datavol
"""

# Init script used by get_inst() for a new volume in mirror mode: builds the
# mdadm mirror /dev/md0 from ephemeralvol and datavol (datavol flagged
# --write-mostly), formats it as ext4, mounts it on /mnt and starts docker
# when its init script exists.
EBS_MIRROR_INIT = BASIC_INIT + """
echo "y" | mdadm --create --verbose /dev/md0 --level=mirror --raid-devices=2 /dev/VolGroup01/ephemeralvol --write-mostly /dev/VolGroup00/datavol
mkfs.ext4 /dev/md0
mount /dev/md0 /mnt

if [ -e /etc/init.d/docker ] ; then
    /etc/init.d/docker start
fi
"""

# Init script used by get_inst() when re-attaching an existing volume in
# mirror mode: stops /dev/md127 if present, verifies that /dev/dm-0 is
# VolGroup00-datavol (the script runs under `set -e`, so a failed check
# aborts it), assembles /dev/md0 from it, mounts it on /mnt and then adds
# the fresh ephemeralvol to the array.
EBS_MIRROR_ATTACH_INIT = INIT_EPHEMERALS + """
mdadm -S /dev/md127 || true

# check if dm-0 is indeed the volume we need
[ $(dmsetup info /dev/dm-0 | grep Name | awk '{print $2}') == VolGroup00-datavol ]

mdadm --assemble --run /dev/md0 /dev/dm-0
mount /dev/md0 /mnt
mdadm --manage /dev/md0 -a /dev/VolGroup01/ephemeralvol

if [ -e /etc/init.d/docker ] ; then
    /etc/init.d/docker start
fi
"""

# Init script used by get_inst() for a new volume without mirroring: formats
# datavol as ext4 and mounts it on /mnt; when ephemeral devices were found,
# formats ephemeralvol and mounts it on /scratch. Starts docker when its init
# script exists.
EBS_FORMAT_INIT = BASIC_INIT + """
mkfs.ext4 /dev/VolGroup00/datavol
mount /dev/VolGroup00/datavol /mnt/

if [ "$EPHEMERAL_DEVICES" != "" ]; then
    mkdir /scratch
    mkfs.ext4 /dev/VolGroup01/ephemeralvol
    mount /dev/VolGroup01/ephemeralvol /scratch/
fi

if [ -e /etc/init.d/docker ] ; then
    /etc/init.d/docker start
fi
"""

# Init script used by get_inst() when re-attaching an existing volume without
# mirroring: polls once per second until /dev/VolGroup00/datavol shows up as
# a block device, then mounts it on /mnt (without formatting). Ephemeral
# storage, when present, is formatted and mounted on /scratch. Starts docker
# when its init script exists.
EBS_ATTACH_INIT = INIT_EPHEMERALS + """
while true  ; do
if [ -b /dev/VolGroup00/datavol ] ; then
    mount /dev/VolGroup00/datavol /mnt/
    break
else
    sleep 1
fi
done

if [ "$EPHEMERAL_DEVICES" != "" ]; then
    mkdir /scratch
    mkfs.ext4 /dev/VolGroup01/ephemeralvol
    mount /dev/VolGroup01/ephemeralvol /scratch/
fi

if [ -e /etc/init.d/docker ] ; then
    /etc/init.d/docker start
fi
"""

# Bash completion script written by install_comp(). It completes the first
# word with the sub-command names, the word after "get" with the instance
# types from all_instance_types() (baked in when this module is loaded), and
# the word after "ssh" with the output of
# `getbox ls --names-only --only instances` filtered by the current word.
COMPLETION_SCRIPT = """
_getbox()
{
    local cur
    INSTANCE_TYPES='""" + " ".join(all_instance_types()) + """'

    COMPREPLY=()

    _get_comp_words_by_ref cur

    prev="${COMP_WORDS[COMP_CWORD-1]}"

    if [ $COMP_CWORD -eq 1 ]; then
        COMPREPLY=( $(compgen -W "get kill keep unkeep ssh list rename" -- ${cur}) )
        return 0
    fi

    if [ "$prev" == "get" ]; then
        COMPREPLY=( $(compgen -W "$INSTANCE_TYPES" -- ${cur}) )
        return 0
    fi

    if [ "$prev" == "ssh" ]; then
        COMPREPLY=( $(getbox ls --names-only --only instances | grep "${cur}") )
        return 0
    fi
} &&
complete -F _getbox getbox
"""


class Spinner:
    """Single-line progress spinner drawn through print_().

    ``message`` is the text shown in front of the spinner frame; ``spinner``
    holds the index of the frame drawn last.
    """

    def __init__(self, message):
        self.message = message
        self.spinner = 0

    def tick(self):
        """Advance to the next frame and redraw the line in place."""
        frames = '|/-\\'
        self.spinner += 1
        self.spinner %= len(frames)
        # The trailing carriage return moves the cursor back to column 0 so
        # the next tick overwrites this one.
        line = ''.join((self.message, frames[self.spinner], "\r"))
        print_(line, end='')

    def done(self, done_msg='done'):
        """Replace the spinner frame with ``done_msg`` and end the line."""
        line = ' '.join((self.message, done_msg))
        print_(line + "\n", end='')


def table_head(*columns):
    """Print a bold table header row and return the column definitions.

    Each column is a ``(title, width, style)`` tuple; titles are truncated or
    padded to ``width``. The returned tuple is what table_row() expects as
    its first argument.
    """
    print_(' ', end='')
    for column in columns:
        title, width, _style = column
        cell = title[:width].ljust(width)
        print_(cell, end="", style='bold')
    print_(' ')
    return columns


def table_row(columns, *args):
    """Print one table row in green.

    ``columns`` are the ``(title, width, style)`` tuples returned by
    table_head(); each value in ``args`` is truncated or padded to the width
    of the column at the same position.
    """
    print_(' ', end='')
    for column, value in zip(columns, args):
        _title, width, _style = column
        cell = value[:width].ljust(width)
        print_(cell, end="", style='green')
    print_(' ')


def table_footer():
    """Close a table by printing an empty line."""
    print_('')


def kill_volume(tags, ebs_id_prefix):
    """Delete a detached EBS volume after interactive confirmation.

    The volume is located with get_detached_ebs_vol() and shown as a one-row
    table. The process always exits: status 0 after deletion, status 1 when
    the user does not answer ``y``/``Y``.
    """
    ec2 = boto3.client('ec2')
    volume = get_detached_ebs_vol(ec2, tags, ebs_id_prefix)

    header = volume_table_head()
    table_row(header, *format_volume_row(volume))

    print_("Delete this volume? [y/N] ", end='')
    answer = raw_input('')
    if answer not in ('Y', 'y'):
        print_("Cancelled")
        sys.exit(1)

    ec2.delete_volume(VolumeId=volume['VolumeId'])
    sys.exit(0)


def find_instance(client, tags, partId):
    """Find a running instance of the current owner by partial id or name.

    Args:
        client: boto3 EC2 client to query with. It is now actually used (it
            was previously discarded in favour of a fresh client); ``None``
            is accepted and creates a new client.
        tags: tag dict every candidate must carry; the entry under
            CONFIG_OWNER_TAG_NAME is also used as the server-side filter.
        partId: substring to look for in the instance id or its Name tag.

    Returns:
        The first matching instance description.

    Exits with status 1 after printing an error when nothing matches.
    """
    if client is None:
        client = boto3.client('ec2')

    response = client.describe_instances(Filters=[
        {
            'Name': 'tag:' + CONFIG_OWNER_TAG_NAME,
            'Values': [tags[CONFIG_OWNER_TAG_NAME]]
        }
    ])

    for rsr in response['Reservations']:
        for inst in rsr['Instances']:
            if not tags_match(inst['Tags'], tags):
                continue

            if inst['State']['Name'] != 'running':
                continue

            name = instance_name(inst)

            if partId in inst['InstanceId'] or (name and partId in name):
                return inst

    print_('Instance id not found: %s. Try getbox list to see the list of instances' % partId, style='error')
    sys.exit(1)


def kill_instance(tags, partId):
    """Terminate an instance after interactive confirmation.

    The instance is located with find_instance() and shown as a one-row
    table. The process always exits: status 0 after the terminate request,
    status 1 when the user does not answer ``y``/``Y``.
    """
    ec2 = boto3.client('ec2')
    target = find_instance(ec2, tags, partId)

    header = instance_table_head()
    table_row(header, *instance_table_row(target))

    print_("Terminate this instance? [y/N] ", end='')
    answer = raw_input('')
    if answer not in ('Y', 'y'):
        print_("Cancelled")
        sys.exit(1)

    ec2.terminate_instances(InstanceIds=[target['InstanceId']])
    sys.exit(0)


def rename_instance(tags, partId, name):
    """Set the Name tag of an instance after interactive confirmation.

    The instance is located with find_instance() and shown as a one-row
    table. The process always exits: status 0 after tagging, status 1 when
    the user does not answer ``y``/``Y``.

    Only the Name tag is sent to create_tags, which adds or overwrites just
    the tags it is given. Re-sending every existing tag (as before) was
    unnecessary and could include reserved ``aws:``-prefixed keys.
    """
    client = boto3.client('ec2')

    inst = find_instance(client, tags, partId)

    t = instance_table_head()
    table_row(t, *instance_table_row(inst))

    print_("Rename this instance? [y/N] ", end='')
    if raw_input('') in ('Y', 'y'):
        client.create_tags(Resources=[inst['InstanceId']],
                           Tags=[{'Key': 'Name', 'Value': name}])
        sys.exit(0)
    else:
        print_("Cancelled")
        sys.exit(1)


def get_instance_ssh_address(inst):
    """Return the address to ssh to: public DNS name, else private IP.

    The private address is only looked up when ``PublicDnsName`` is empty.
    """
    public_dns = inst['PublicDnsName']
    if public_dns:
        return public_dns
    return inst['PrivateIpAddress']


def ssh(tags, partId):
    """Log into (or print the login string of) a running instance.

    The instance is located with find_instance(), which prints an error and
    exits with status 1 when nothing matches ``partId``. A warning is printed
    when the instance's key does not appear in ``ssh-add -l``.

    When stdout is a terminal the process is replaced by an ssh session
    (CONFIG_SSH_FLAGS, CONFIG_SSH_USER). Otherwise ``user@address`` is
    written to stdout so the output can be used in command substitution, and
    the process exits with status 0.
    """
    client = boto3.client('ec2')
    inst = find_instance(client, tags, partId)

    dns = get_instance_ssh_address(inst)
    key = inst['KeyName']
    try:
        keys = subprocess.check_output(["ssh-add", "-l"]).decode().splitlines()
    except (subprocess.CalledProcessError, OSError):
        # ssh-add -l returns an error if no identities are loaded; OSError
        # covers ssh-add not being installed. Either way the check below is
        # only advisory.
        keys = []

    if not any((key in k) for k in keys):
        print_('WARNING: You don\'t seem to have key %s in your ssh agent' % key, style='bold')

    if hasattr(sys.stdout, 'isatty') and sys.stdout.isatty():
        # Run through bash so that CONFIG_SSH_FLAGS is split into words.
        os.execv('/bin/bash', ['/bin/bash', '-c', 'ssh %s %s@%s' % (CONFIG_SSH_FLAGS, CONFIG_SSH_USER, dns)])
    else:
        print_('%s@%s' % (CONFIG_SSH_USER, dns), file=sys.stdout)
    sys.exit(0)


def tags_match(Tags, tagdict):
    """Tell whether an AWS tag list carries every key/value of ``tagdict``.

    ``Tags`` is a list of ``{'Key': ..., 'Value': ...}`` dicts. Tags whose
    key is not in ``tagdict`` are ignored.
    """
    relevant = {}
    for tag in Tags:
        if tag['Key'] in tagdict:
            relevant[tag['Key']] = tag['Value']
    return relevant == tagdict


def keep_ebs(tags, ebs_id_prefixes, keep=True):
    """Set or clear delete-on-termination on attached EBS volumes.

    Args:
        tags: tag dict the volumes must carry; the entry under
            CONFIG_OWNER_TAG_NAME is also used as the server-side filter.
        ebs_id_prefixes: volume id prefixes selecting the volumes to change.
        keep: True to keep the volumes after instance termination
            (DeleteOnTermination=False), False to have them deleted.

    For every attachment of every matching volume the instance's block
    device mapping is modified. Exits with status 1 when some prefix matched
    no volume.
    """
    not_found = set(ebs_id_prefixes)

    client = boto3.client('ec2')
    response = client.describe_volumes(Filters=[
        {
            'Name': 'tag:' + CONFIG_OWNER_TAG_NAME,
            'Values': [tags[CONFIG_OWNER_TAG_NAME]]
        }
    ])

    for vol in response['Volumes']:
        if not tags_match(vol['Tags'], tags):
            continue
        for pref in ebs_id_prefixes:
            if not vol['VolumeId'].startswith(pref):
                continue

            # discard(), not remove(): one prefix can match several volumes,
            # and remove() raised KeyError on the second match.
            not_found.discard(pref)

            for a in vol['Attachments']:
                client.modify_instance_attribute(InstanceId=a['InstanceId'],
                                                 Attribute='blockDeviceMapping',
                                                 BlockDeviceMappings=[
                                                 {
                                                     'DeviceName': a['Device'],
                                                     'Ebs': {'DeleteOnTermination': not keep}
                                                 }])
            if keep:
                print_("Volume %s will be kept around after instance termination" % vol['VolumeId'])
            else:
                print_("Volume %s will be deleted after instance termination" % vol['VolumeId'])

    if not_found:
        print_('Couldn\'t find attached EBS with id %s' % repr(not_found), style='error')
        sys.exit(1)


def volume_table_head():
    """Print the header of the volume table and return its column layout."""
    columns = (
        ('VolumeId', 20, None),
        ('Size', 10, None),
        ('Type', 10, None),
        ('Delete on termination', 27, None),
        ('State', 60, None),
    )
    return table_head(*columns)


def format_volume_row(vol):
    """Turn a volume description into the cells of a volume table row.

    Returns a 5-tuple of strings: volume id, size, volume type, ``Yes``/``No``
    for delete-on-termination (``Yes`` if any attachment has it set) and the
    state, which for ``in-use`` volumes also lists the attached instance ids.
    """
    attachments = vol['Attachments']
    instance_ids = ','.join(att['InstanceId'] for att in attachments)
    deleted_with_instance = any(att['DeleteOnTermination'] for att in attachments)

    state = vol['State']
    if state == 'in-use':
        state = 'in-use(%s)' % (instance_ids)

    delete_cell = 'Yes' if deleted_with_instance else 'No'
    return (vol['VolumeId'], str(vol['Size']), vol['VolumeType'], delete_cell, state)


def list_volumes(tags, names_only=False):
    """List the current owner's EBS volumes.

    By default a table is printed through print_(). With ``names_only`` just
    the volume ids are written to stdout, one per line. Nothing at all is
    printed when the owner filter returns no volumes.
    """
    ec2 = boto3.client('ec2')
    owner_filter = {
        'Name': 'tag:' + CONFIG_OWNER_TAG_NAME,
        'Values': [tags[CONFIG_OWNER_TAG_NAME]],
    }
    volumes = ec2.describe_volumes(Filters=[owner_filter])['Volumes']
    if not volumes:
        return

    as_table = not names_only
    if as_table:
        header = volume_table_head()

    for volume in volumes:
        if not tags_match(volume['Tags'], tags):
            continue

        row = format_volume_row(volume)
        if as_table:
            table_row(header, *row)
        else:
            print_(row[0], file=sys.stdout)

    if as_table:
        table_footer()


def ebs_desc(inst, block_device_name='/dev/sdg'):
    """Describe the EBS volume(s) mapped at ``block_device_name``.

    Returns a comma-separated list of ``<volume id>(delete)`` or
    ``<volume id>(keep)`` entries, depending on each mapping's
    DeleteOnTermination flag, or the string ``'None'`` when the instance has
    no mapping for that device.
    """
    described = []
    for mapping in inst['BlockDeviceMappings']:
        if mapping['DeviceName'] != block_device_name:
            continue
        volume_id = mapping['Ebs']['VolumeId']
        fate = 'delete' if mapping['Ebs']['DeleteOnTermination'] else 'keep'
        described.append('%s(%s)' % (volume_id, fate))

    if not described:
        return 'None'
    return ','.join(described)


def instance_table_head():
    """Print the header of the instance table and return its column layout."""
    columns = (
        ('Id', 20, None),
        ('Type', 10, None),
        ('Started', 18, None),
        ('EBS', 30, None),
        ('Address', 53, None),
    )
    return table_head(*columns)


def instance_name(inst):
    """Return the value of the instance's first ``Name`` tag, or None."""
    name_values = (tag['Value'] for tag in inst['Tags'] if tag['Key'] == 'Name')
    return next(name_values, None)


def instance_table_row(inst):
    """Turn an instance description into the cells of an instance table row.

    Returns a 5-tuple: Name tag (or the instance id when there is none),
    instance type, launch time as ``YYYY-MM-DD HH:MM``, the ebs_desc()
    summary and the ssh address.
    """
    label = instance_name(inst) or inst['InstanceId']
    instance_type = inst['InstanceType']
    started = inst['LaunchTime'].strftime('%Y-%m-%d %H:%M')
    volumes = ebs_desc(inst)
    address = get_instance_ssh_address(inst)
    return (label, instance_type, started, volumes, address)


def list_instances(tags, block_device_name='/dev/sdg', names_only=False):
    """List the current owner's running instances, oldest launch first.

    By default a table is printed through print_() (the header and footer
    appear even when no instance qualifies). With ``names_only`` just the
    first table cell, i.e. the Name tag or the instance id, is written to
    stdout, one per line. ``block_device_name`` is accepted but not used by
    this function.
    """
    ec2 = boto3.client('ec2')
    owner_filter = {
        'Name': 'tag:' + CONFIG_OWNER_TAG_NAME,
        'Values': [tags[CONFIG_OWNER_TAG_NAME]],
    }
    response = ec2.describe_instances(Filters=[owner_filter])

    as_table = not names_only
    if as_table:
        header = instance_table_head()

    found = []
    for reservation in response['Reservations']:
        found.extend(reservation['Instances'])
    found.sort(key=lambda inst: inst['LaunchTime'])

    for inst in found:
        if not tags_match(inst['Tags'], tags) or inst['State']['Name'] != 'running':
            continue

        row = instance_table_row(inst)
        if as_table:
            table_row(header, *row)
        else:
            print_(row[0], file=sys.stdout)

    if as_table:
        table_footer()


def wait_for_ssh(client, dnsName):
    """Block until TCP port 22 on ``dnsName`` accepts a connection.

    A connection is attempted roughly once per second with a one-second
    timeout. Timeouts and "connection refused" are retried; any other socket
    error, or exceeding CONFIG_SSH_WAIT_TIMEOUT_SECONDS, prints an error and
    exits with status 1. ``client`` is accepted but not used by this function.
    """
    spin = Spinner("Waiting for ssh %s.. " % dnsName)
    spin.tick()
    started = time.time()

    while True:
        elapsed = time.time() - started
        if elapsed > CONFIG_SSH_WAIT_TIMEOUT_SECONDS:
            spin.done()
            print_("Failed to connect to %s" % dnsName, style='error')
            sys.exit(1)

        try:
            connection = socket.create_connection((dnsName, 22), timeout=1)
            connection.close()
            spin.done()
            return
        except socket.timeout:
            # Not reachable yet: retry below.
            pass
        except socket.error as err:
            if err.errno != errno.ECONNREFUSED:
                spin.done()
                print_("Failed to connect: %s" % err, style='error')
                sys.exit(1)
            # Connection refused: sshd is not up yet, retry below.

        time.sleep(1)
        spin.tick()


def wait_for_init(client, instanceId):
    """Poll until the instance is running and return its ssh address.

    The instance is looked up by id once per second while it is ``pending``
    or not yet returned by describe_instances.

    Returns:
        get_instance_ssh_address() of the instance once its state is
        ``running``.

    Exits with status 1 when the instance is ``terminated`` or reaches any
    other state. Previously an unexpected state printed "failed" and kept
    polling without any delay, and an instance missing from the response
    caused the same tight loop.
    """
    spin = Spinner("Waiting for instance %s to start.. " % instanceId)
    spin.tick()
    while True:
        response = client.describe_instances(Filters=[{'Name': 'instance-id', 'Values': [instanceId]}])
        instances = [inst
                     for rsr in response['Reservations']
                     for inst in rsr['Instances']]

        for inst in instances:
            st = inst['State']['Name']
            if st == 'running':
                spin.done()
                return get_instance_ssh_address(inst)
            if st == 'pending':
                continue

            spin.done("failed")
            if st != 'terminated':
                print_('Unexpected state while waiting for instance start: %s' % st)
            sys.exit(1)

        # Still pending, or not visible through the API yet: wait and retry.
        spin.tick()
        time.sleep(1)


def get_detached_ebs_vol(client, tags, ebs_id_prefix):
    """Return the owner's detached volume whose id starts with the prefix.

    Only the first volume matching both ``tags`` and ``ebs_id_prefix`` is
    considered: it is returned when its state is ``available``; otherwise an
    error is printed and the process exits with status 1. The process also
    exits with status 1 when no volume matches.
    """
    owner_filter = {
        'Name': 'tag:' + CONFIG_OWNER_TAG_NAME,
        'Values': [tags[CONFIG_OWNER_TAG_NAME]],
    }
    response = client.describe_volumes(Filters=[owner_filter])

    for volume in response['Volumes']:
        matches = tags_match(volume['Tags'], tags) and volume['VolumeId'].startswith(ebs_id_prefix)
        if not matches:
            continue

        if volume['State'] != 'available':
            print_('EBS volume %s exists, but it needs to be detached (currently %s)' % (volume['VolumeId'], volume['State']), style='error')
            sys.exit(1)
        return volume

    print_('Couldn\'t find EBS with id %s' % ebs_id_prefix, style='error')
    sys.exit(1)


def get_inst(instance_type, tags, ebs_id_prefix, volume_size=None,
             mirror=None, name=None, keep=False, root_ebs_id=None):
    """Build a launch specification and start a spot instance with it.

    Args:
        instance_type: EC2 instance type. When None it is taken from the
            ``getbox`` tag of the volume being attached, falling back to
            ``m4.4xlarge``.
        tags: tag dict identifying the owner's resources.
        ebs_id_prefix: id prefix of an existing detached volume to attach;
            falsy to create a new volume at /dev/sdg.
        volume_size: size of the new volume (default 100). Not allowed in
            mirror mode, where it is set to the ephemeral storage size.
        mirror: whether to mirror the EBS volume with ephemeral storage; when
            None it may be taken from the volume's ``getbox`` tag.
        name: optional Name tag for the instance.
        keep: keep a newly created volume after instance termination.
        root_ebs_id: accepted but not used by this function.

    The function hands over to do_get_spot(), which launches the instance and
    opens the ssh session. Invalid combinations (unknown instance type,
    mirror mode without ephemeral drives or with an explicit size, no subnet
    configured for the volume's availability zone) print an error and exit
    with status 1.
    """
    client = boto3.client('ec2')

    if ebs_id_prefix:
        vol = get_detached_ebs_vol(client, tags, ebs_id_prefix)
        vol_tags = {t['Key']: t['Value'] for t in vol['Tags']}

        # We store instance type in the volume metadata, so if it was not
        # specified explicitly, just take it from there. Same with the mirror
        # flag.
        if vol_tags.get('getbox'):
            tag_metadata = json.loads(vol_tags.get('getbox'))
            if ('instance_type' in tag_metadata) and instance_type is None:
                instance_type = tag_metadata['instance_type']
                print_("The volume was last used with instance type %s, using the same type" % instance_type)
            if ('mirror' in tag_metadata) and mirror is None:
                mirror = tag_metadata['mirror']

    if instance_type is None:
        instance_type = 'm4.4xlarge'
        print_("Using default instance type %s" % instance_type)

    launchSpec = {
        'ImageId': get_ami_id(instance_type),
        'KeyName': CONFIG_SSH_KEY_NAME,
        'InstanceType': instance_type,
        # Overridden below when an existing volume dictates the zone.
        'SubnetId': random.choice(list(CONFIG_SUBNETS.values())),
        'EbsOptimized': is_ebs_optimized_type(instance_type),
        'IamInstanceProfile': {
            'Arn': CONFIG_IAM_ROLE_ARN.replace(':role/', ':instance-profile/'),
        },
        'Monitoring': {
            'Enabled': True
        },
        'SecurityGroupIds': CONFIG_SECURITY_GROUPS
    }

    try:
        size_e, num_e = get_ephemeral_info(instance_type)
    except KeyError:
        print_("Unknown instance type: %s" % (instance_type,), style='error')
        sys.exit(1)

    if mirror:
        if volume_size is not None:
            print_("Cannot specify size in mirror mode, it is chosen automatically to match ephemeral storage capacity", style='error')
            sys.exit(1)

        if size_e == 0:
            print_("Cannot use mirror mode for instance type %s as it has no ephemeral drives" % instance_type, style='error')
            sys.exit(1)

        volume_size = size_e

    def build_userdata(init):
        # Wrap the init script plus the configured extra commands into a
        # MIME multipart message with a single init.sh attachment.
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText
        combined_message = MIMEMultipart()
        contents = (init + "\n" + "\n".join(CONFIG_EXTRA_INIT_COMMANDS)).encode()
        filename = 'init.sh'
        sub_message = MIMEText(contents, 'text/x-shellscript', sys.getdefaultencoding())
        sub_message.add_header('Content-Disposition', 'attachment; filename="%s"' % (filename,))
        combined_message.attach(sub_message)
        return str(combined_message)

    if ebs_id_prefix:
        # attach existing
        init = EBS_MIRROR_ATTACH_INIT if mirror else EBS_ATTACH_INIT
        ebs_vol_id, zone = vol['VolumeId'], vol['AvailabilityZone']
        if zone not in CONFIG_SUBNETS:
            # The instance must be launched in the volume's zone; fail with
            # a readable message instead of a KeyError traceback.
            print_("No subnet configured for availability zone %s of volume %s, add it to the subnets config parameter"
                   % (zone, ebs_vol_id), style='error')
            sys.exit(1)
        launchSpec['UserData'] = build_userdata(init)
        launchSpec['SubnetId'] = CONFIG_SUBNETS[zone]
        launchSpec['BlockDeviceMappings'] = []
    else:
        # create new EBS
        ebs_vol_id = None
        init = EBS_MIRROR_INIT if mirror else EBS_FORMAT_INIT
        launchSpec['UserData'] = build_userdata(init)
        launchSpec['BlockDeviceMappings'] = [
            {
                'DeviceName': '/dev/sdg',
                'Ebs': {
                    'VolumeSize': volume_size or 100,
                    'DeleteOnTermination': not keep,
                    'VolumeType': CONFIG_EBS_VOLTYPE
                }
            }
        ]

        if CONFIG_EBS_IOPS:
            launchSpec['BlockDeviceMappings'][0]['Ebs']['Iops'] = CONFIG_EBS_IOPS

    # skip /dev/sdg since we're using it for EBS (not tested on d2.8xl)
    drive_names = ['/dev/sd' + chr(ord('b') + i) for i in range(25) if i != 5]

    # root volume
    launchSpec['BlockDeviceMappings'].append({
        'DeviceName': '/dev/sda1',
        'Ebs': {'VolumeSize': CONFIG_ROOT_SIZE_GB,
                'DeleteOnTermination': True,
                'VolumeType': 'gp2'
                }
    })

    # one mapping per ephemeral drive of the instance type
    for n in range(num_e):
        launchSpec['BlockDeviceMappings'].append({
            'DeviceName': drive_names[n],
            'VirtualName': 'ephemeral%d' % n
        })

    do_get_spot(client, launchSpec, tags, name, ebs_vol_id, mirror, instance_type)


def do_get_spot(client, launchSpec, tags, name, ebs_vol_id, mirror, instance_type):
    """Launch a one-time spot instance and ssh into it.

    Args:
        client: boto3 EC2 client.
        launchSpec: keyword arguments for run_instances, as built by get_inst().
        tags: owner tags applied to the instance and to the volumes created
            with it.
        name: optional Name tag.
        ebs_vol_id: id of an existing volume to attach at /dev/sdg once the
            instance is running, or None.
        mirror: mirror-mode flag recorded on the volumes (None: not recorded).
        instance_type: instance type recorded on the volumes.

    Volume tags: besides the plain ``mirror`` / ``instance_type`` tags, the
    same information is stored as JSON under the ``getbox`` tag, which is the
    tag get_inst() reads back when a volume is re-attached. The ``mirror``
    tag value is converted to a string because tag values must be strings.
    """
    vol_tags = [{'Key': k, 'Value': v} for k, v in tags.items()]
    metadata = {'instance_type': instance_type}
    if mirror is not None:
        metadata['mirror'] = mirror
        vol_tags.append({'Key': 'mirror', 'Value': str(mirror)})
    vol_tags.append({'Key': 'instance_type', 'Value': instance_type})
    if 'getbox' not in tags:
        # Guarded so that a caller-supplied tag of the same name is never
        # duplicated in the request.
        vol_tags.append({'Key': 'getbox', 'Value': json.dumps(metadata)})
    if name:
        vol_tags.append({'Key': 'Name', 'Value': name})

    instance_tags = [{'Key': k, 'Value': v} for k, v in tags.items()]
    if name:
        instance_tags.append({'Key': 'Name', 'Value': name})

    instanceMarketOptions = {
        'MarketType': 'spot',
        'SpotOptions': {
            'SpotInstanceType': 'one-time',
            'InstanceInterruptionBehavior': 'terminate'
        }
    }

    if CONFIG_SPOT_PRICE:
        instanceMarketOptions['SpotOptions']['MaxPrice'] = str(CONFIG_SPOT_PRICE)

    response = client.run_instances(
        TagSpecifications=[
            {'ResourceType': 'instance', 'Tags': instance_tags},
            {'ResourceType': 'volume', 'Tags': vol_tags}
        ],
        InstanceMarketOptions=instanceMarketOptions,
        MinCount=1,
        MaxCount=1,
        **launchSpec
    )

    instanceId = response['Instances'][0]['InstanceId']

    dnsName = wait_for_init(client, instanceId)

    if ebs_vol_id is not None:
        print_("Attaching volume %s" % ebs_vol_id)
        client.attach_volume(VolumeId=ebs_vol_id, InstanceId=instanceId, Device='/dev/sdg')

    wait_for_ssh(client, dnsName)
    ssh(get_tags(), instanceId)


def install_check():
    """Print a hint when the installed boto3 is older than 1.4.

    Only the major and minor components of ``boto3.__version__`` are
    compared. The hint now names the real pip flag (``--upgrade``); the
    previous text suggested ``--update``, which is not a pip option.
    """
    major_minor = tuple(int(part) for part in boto3.__version__.split('.')[:2])
    if major_minor < (1, 4):
        print_("boto3>=1.4 recommended, install it with pip install --upgrade boto3")


def install_binary(dst='/usr/local/bin/getbox'):
    """Copy the running script (``sys.argv[0]``) to ``dst`` with mode 0755.

    Returns ``dst`` on success. An IOError with errno 13 or 1 prints a hint
    to run as root and exits with status 1; any other IOError propagates.
    """
    script_path = sys.argv[0]
    try:
        shutil.copyfile(script_path, dst)
        os.chmod(dst, 0o755)
    except IOError as err:
        if err.errno not in (13, 1):
            raise
        print_('Cannot write to %s, try running this as root' % dst, style='error')
        sys.exit(1)
    return dst


def install_comp():
    """Install the bash completion script into brew's bash_completion.d.

    The target directory is ``$(brew --prefix)/etc/bash_completion.d/`` as
    expanded by bash. When it does not exist a message is printed and nothing
    is installed. A permission error prints a hint to run as root and exits
    with status 0 (unchanged from before); other I/O errors propagate.

    Fixes over the previous version: stderr of the helper command goes to
    subprocess.DEVNULL instead of a leaked, read-only file handle; the
    directory is decoded to text so messages no longer show ``b'...'``;
    EACCES is handled like EPERM; and the error handler no longer depends on
    a possibly unbound ``dst``.
    """
    dst = None
    try:
        output = subprocess.check_output(['bash',
                                          '-c',
                                          'echo $(brew --prefix)/etc/bash_completion.d/'],
                                         stderr=subprocess.DEVNULL)
        dst = output.strip().decode()

        if not os.path.exists(dst):
            print_('Failed to install bash completion, %s does not exist' % dst, style='info')
        else:
            with open(os.path.join(dst, 'getbox'), 'w') as f:
                f.write(COMPLETION_SCRIPT)
    except IOError as e:
        if dst is not None and e.errno in (errno.EPERM, errno.EACCES):
            print_('Cannot write to %s, try running this as root' % dst, style='error')
            sys.exit(0)
        else:
            raise


def print_usage(f=sys.stdout):
    """Print the getbox command overview to ``f``.

    Args:
        f: file-like object that receives every line of the usage text
           (header, command table and the trailing notes).
    """
    def out(msg, **kwargs):
        # Route every line to the requested stream; print_ would otherwise
        # fall back to its own default for all but the header.
        print_(msg, file=f, **kwargs)

    def cmd(name, descr, example):
        # One table row: bold command name, description, example invocation.
        out(name.ljust(15), end='', style='bold')
        out(descr.ljust(62), end='')
        out(example or '')

    out("\n          getbox: easy sandbox manager for ec2/spot. Usage: getbox <cmd> [parameters]")
    out(" " * (86) + "(Example)")

    # Rows with an empty name continue the command listed above them.
    rows = [
        ("get", "Get an ec2 instance and ssh into it.", "getbox get r3.8xlarge"),
        ("", "...optionally, pass EBS volume id to attach.", "getbox get vol-321"),
        ("", "...or both", "getbox get r3.8xlarge vol-321"),
        ("", "...you can also name it", "getbox get r3.8xlarge --name mytestbox"),
        ("", "...or specify EBS volume size", "getbox get r3.8xlarge 200GB"),
        ("list", "List your instances and volumes", "getbox list"),
        ("kill", "Kill an instance or (unattached) volume", "getbox kill 123"),
        ("", "", "getbox kill vol-456"),
        ("ssh", "Log into the instance", "getbox ssh 123"),
        ("keep", "Disable delete-on-termination on an attached EBS volume", "getbox keep vol-123"),
        ("unkeep", "Enable delete-on-termination on an attached EBS volume", "getbox unkeep vol-123"),
        ("rename", "(Re)name an instance", "getbox rename i-123 mytestbox"),
        ("install", "Install getbox to /usr/local/bin and add bash autocomplete", "getbox install"),
    ]
    for name, descr, example in rows:
        cmd(name, descr, example)

    out('')
    out("Commands that take instance or volume id can take incomplete id as well.")
    out('')
    out("ssh", end='', style='bold')
    out("and ", end='')
    out("get", end='', style='bold')
    out("commands can be used to log into the instance in a bash compatible shell, try: ssh $(getbox get)")
    out("ssh", end='', style='bold')
    out("command will just log you into the instance if used from tty: getbox ssh")
    out("...but if it detects that its output was redirected, it'll just return login strings so you can do: rsync -av . $(getbox ssh):")
    out('')


def get_owner():
    """Return the current user's login name as reported by ``getpass.getuser``."""
    username = getpass.getuser()
    return username


def get_tags():
    """Return a fresh dict of CONFIG_EXTRA_TAGS plus the owner tag.

    The owner tag is stored under CONFIG_OWNER_TAG_NAME and set to the value
    of get_owner(); CONFIG_EXTRA_TAGS itself is not modified.
    """
    tags = dict(CONFIG_EXTRA_TAGS.items())
    tags[CONFIG_OWNER_TAG_NAME] = get_owner()
    return tags


def check_creds():
    """Exit with status 1 unless AWS credentials and a usable username exist.

    Two checks are made: the default boto3 session must yield credentials
    (when it exposes ``get_credentials``), and the current user must not be
    one of the generic accounts listed below.
    """
    session = boto3._get_default_session()

    if hasattr(session, 'get_credentials') and not session.get_credentials():
        print_('Failed to get AWS credentials', style='error')
        sys.exit(1)

    owner = get_owner()
    generic_accounts = ('root', 'ec2-user', 'nobody')
    if owner in generic_accounts:
        print_('Since we use your system username to tag instances, you can\'t run this as %s' % owner)
        sys.exit(1)


def is_ebs_optimized_type(s):
    """Return False when ``s`` starts with a listed family prefix, else True."""
    excluded_prefixes = (
        'cc2', 'cg1', 't2', 'g2',
        'r3', 'i2', 'hi1', 'hs1', 'm3',
        'c3', 'm1', 'c1', 'm2', 't1',
        'cr1', 'f1',
    )
    # str.startswith accepts a tuple and matches if any prefix matches.
    return not s.startswith(excluded_prefixes)


def valid_instance_type(s):
    """Return True when ``s`` looks like ``<family>.<anything>`` for a known family."""
    known_families = (
        'a1',
        'm1', 'm2', 'm3', 'm4', 'm5', 'm5d',
        'r3', 'r4', 'r5', 'r5d', 'r5a',
        'i2', 'i3',
        'hs1',
        'hi1',
        'x1', 'x1e',
        'g2', 'g3', 'g3s',
        'p2', 'p3', 'p3dn',
        't1', 't2', 't3',
        'c3', 'c4', 'c5', 'c5d', 'c1', 'cc2', 'cg1', 'cr1', 'c5n',
        'f1',
        'd2',
        'z1d',
    )
    # The trailing dot keeps e.g. "m5" from matching "m5d.large" by accident.
    return any(s.startswith(family + '.') for family in known_families)


def parse_size(s):
    """Parse a volume size such as ``200GB`` or ``2T`` into a number of GB.

    Args:
        s: string made of digits followed by ``G``, ``GB``, ``T`` or ``TB``.

    Returns:
        The size as an int (terabytes count as 1000 GB each), or None when
        ``s`` is not in that form.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    m = re.match(r'^(\d+)(GB?|TB?)$', s)
    if m is None:
        return None

    number, unit = m.groups()
    gb_per_unit = 1000 if unit.startswith('T') else 1
    return int(number) * gb_per_unit


def main():
    """Parse the command line and dispatch to the requested getbox command.

    With no command, or an unrecognised one, the usage text is printed.
    Commands given the wrong number of positional arguments print a short
    usage line and exit with status 1.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--mirror', help='mirror mode', action='store_true', default=None)
    parser.add_argument('--root', help='root volume id', type=str)
    parser.add_argument('--name', help='name', type=str)
    parser.add_argument('--only', help='list objects', type=str, default='')
    parser.add_argument('--names-only', help='list only object names', action='store_true', default=False)
    parser.add_argument('cmd', type=str, nargs='?')
    parser.add_argument('args', nargs='*')
    parser.print_usage = print_usage
    # parse_known_args: flags argparse does not know are ignored, not rejected.
    p = parser.parse_known_args()[0]

    if p.cmd is None:
        print_usage()
    elif p.cmd == 'install':
        install_check()
        bin_path = install_binary()
        install_comp()
        print_('Installed at %s' % bin_path, style='success')
    elif p.cmd == 'get':
        check_creds()
        type_ = None
        vol_prefix = None
        volume_size = None
        keep = False

        # Positional arguments may come in any order; each one is classified
        # by its shape.
        for a in p.args:
            size = parse_size(a)
            if a.startswith('vol-'):
                vol_prefix = a
            elif valid_instance_type(a):
                type_ = a
            elif size:
                volume_size = size
            elif a == 'keep':
                keep = True
            elif '=' not in a:
                # key=value arguments are config overrides read straight from
                # sys.argv by get_argv_config_dict, so they are skipped here.
                print_('Unknown option: %s' % a, style='error')
                sys.exit(1)

        get_inst(type_, get_tags(), vol_prefix,
                 volume_size=volume_size,
                 mirror=p.mirror, name=p.name,
                 keep=keep)
    elif p.cmd in ('name', 'rename'):
        # Same credential/username gate as every other AWS-touching command.
        check_creds()
        if len(p.args) != 2:
            print_('Usage: rename <id-or-name> <name>', style='error')
            sys.exit(1)
        rename_instance(get_tags(), p.args[0], p.args[1])
    elif p.cmd == 'ssh':
        check_creds()
        if len(p.args) != 1:
            print_('Usage: ssh <instance-id>', style='error')
            sys.exit(1)
        ssh(get_tags(), p.args[0])
    elif p.cmd == 'kill':
        check_creds()
        if len(p.args) != 1:
            print_('Usage: kill <instance-or-volume-id>', style='error')
            sys.exit(1)
        target = p.args[0]
        # Ids starting with "vol-" are volumes; anything else is an instance.
        if target.startswith('vol-'):
            kill_volume(get_tags(), target)
        else:
            kill_instance(get_tags(), target)
    elif p.cmd == 'keep':
        check_creds()
        if len(p.args) < 1:
            print_('Usage: keep <ebs-id>', style='error')
            sys.exit(1)
        keep_ebs(get_tags(), list(p.args), True)
    elif p.cmd in ('unkeep', 'dontkeep'):
        check_creds()
        if len(p.args) < 1:
            print_('Usage: unkeep <ebs-id>', style='error')
            sys.exit(1)
        keep_ebs(get_tags(), list(p.args), False)
    elif p.cmd in ('list', 'ls'):
        check_creds()
        # --only takes a comma-separated subset of "instances,volumes".
        if p.only:
            obj_types = p.only.split(',')
        else:
            obj_types = ('instances', 'volumes')
        if 'instances' in obj_types:
            list_instances(get_tags(), names_only=p.names_only)
        if 'volumes' in obj_types:
            list_volumes(get_tags(), names_only=p.names_only)
    else:
        print_usage()


if __name__ == '__main__':
    # Script entry point: a botocore ClientError is reported as a one-line
    # message and turned into exit status 1.
    try:
        main()
    except botocore.exceptions.ClientError as err:
        print_(str(err))
        raise SystemExit(1)