#!python

import subprocess, sys, os, argparse, boto3
from objects.exceptions import CommandError

#sys.tracebacklimit=0

def get_aws_session(profile_name, region_name='eu-west-1'):
	"""Return a boto3 Session for region_name, using the named profile when one is given.

	Args:
		profile_name: AWS config/credentials profile to use. A falsy value (None or '')
			omits the profile so boto3 resolves credentials on its own.
		region_name: AWS region for clients created from the session (default 'eu-west-1',
			the region this tool has always used).
	"""
	if profile_name:
		return boto3.Session(profile_name=profile_name, region_name=region_name)
	return boto3.Session(region_name=region_name)

def get_aws_subaccount_credentials(args, conn_client):
	"""Return a boto3 client of type conn_client for the sub-account owning the environment.

	Despite its name, this returns a service client rather than raw credentials. The AWS
	profile is derived from the first two dash-separated parts of args.environment_name,
	e.g. 'myapp-production-web' -> profile 'myapp-production'.

	Args:
		args: parsed CLI arguments providing ``environment_name``.
		conn_client: boto3 service name, e.g. 'elasticbeanstalk', 's3' or 'ec2'.

	Raises:
		ValueError: if environment_name has fewer than two dash-separated parts.
	"""
	name_parts = args.environment_name.split("-")
	if len(name_parts) < 2:
		raise ValueError(
			"Cannot derive an AWS profile from environment name '" + args.environment_name
			+ "'; expected <application>-<stage>-..."
		)
	desired_account_name = "-".join(name_parts[:2])
	# Build the client the caller asked for; previously this was always 'sts'.
	return get_aws_session(profile_name=desired_account_name).client(conn_client)

def get_instances_list(args):
	"""Return the EC2 instance ids of the Elastic Beanstalk environment args.environment_name.

	The client comes from args.profile when set, otherwise from the sub-account profile
	derived from the environment name. If the describe call fails, the error is printed
	to stderr and the process exits with status 1.
	"""
	if args.profile:
		conn_eb = get_aws_session(args.profile).client('elasticbeanstalk')
	else:
		conn_eb = get_aws_subaccount_credentials(args, 'elasticbeanstalk')

	try:
		resources = conn_eb.describe_environment_resources(EnvironmentName=args.environment_name)
	except Exception as e:
		# Top-level CLI boundary: report the AWS error and stop with a failure status
		# (the bare site-provided exit() used before reported success).
		print(e, file=sys.stderr)
		sys.exit(1)

	return [instance['Id'] for instance in resources['EnvironmentResources']['Instances']]

def get_input(output, default):
	"""Prompt with "<output>: " and return the stripped reply, or default when the reply is blank."""
	reply = input(output + ': ').strip()
	if reply:
		return reply
	return default


def prompt_for_instance_in_list(instances_list, default=1):
	"""Print a numbered menu of instances_list and return the 0-based index the user picks.

	Entries are numbered from 1. Blank input selects ``default``. Non-numeric or
	out-of-range input re-prompts until a valid number is entered.

	Args:
		instances_list: non-empty sequence of items to choose from.
		default: 1-based menu number used when the user just presses Enter.

	Raises:
		ValueError: if instances_list is empty (there is no valid choice, so the
			prompt would otherwise loop forever).
	"""
	if not instances_list:
		raise ValueError('No instances to choose from.')

	for position, instance_id in enumerate(instances_list, start=1):
		print(str(position) + ')', instance_id)

	count = len(instances_list)
	while True:
		try:
			choice = int(get_input('(default is ' + str(default) + ')', default))
		except ValueError:  # non-numeric input
			choice = None
		if choice is not None and 0 < choice <= count:
			return choice - 1
		print('Sorry, that is not a valid choice. Please choose a number between 1 and ' + str(count) + '.')

def get_ssh_key_from_s3(args, instance):
	"""Download the instance's private SSH key from S3 to ~/.ssh/<KeyName> and make it owner-read-only.

	The key is read from bucket "ssh-key-<KeyName>", object key "<KeyName>", where
	KeyName is taken from the EC2 instance description.

	Args:
		args: parsed CLI arguments; ``args.profile`` selects the AWS profile, otherwise
			the sub-account profile derived from ``args.environment_name`` is used.
		instance: EC2 instance description dict containing 'KeyName'.
	"""
	key_name = instance['KeyName']
	print("INFO: Downloading s3://ssh-key-" + key_name + "/" + key_name)

	if args.profile:
		conn_s3 = get_aws_session(args.profile).client('s3')
	else:
		conn_s3 = get_aws_subaccount_credentials(args, 's3')

	ssh_dir = os.path.join(os.path.expanduser("~"), ".ssh")
	# download_file does not create missing parent directories; keep ~/.ssh private.
	os.makedirs(ssh_dir, mode=0o700, exist_ok=True)
	key_path = os.path.join(ssh_dir, key_name)

	conn_s3.download_file("ssh-key-" + key_name, key_name, key_path)
	# ssh rejects private keys with permissive modes; 0o400 = read-only for the owner.
	os.chmod(key_path, 0o400)

def ssh_into_instance(args, instance_id, custom_ssh=None, command=None):
	"""SSH as ec2-user into an EC2 instance using its private address.

	If the private key named by the instance's KeyName is not present at
	~/.ssh/<KeyName>, it is downloaded from S3 first.

	Args:
		args: parsed CLI arguments; ``args.profile`` selects the AWS profile, otherwise
			the sub-account profile derived from ``args.environment_name`` is used.
		instance_id: EC2 instance id to connect to.
		custom_ssh: optional ssh command line (split shell-style) to use instead of the
			default ``ssh -o IdentitiesOnly=yes -i <keyfile>``.
		command: optional remote command, passed to ssh as a single argument so the
			remote shell sees it unchanged.

	Raises:
		KeyError: if the instance has neither a PrivateIpAddress nor a PrivateDnsName.
		CommandError: if ssh exits with a non-zero status.
		OSError: if the ssh executable cannot be started (propagated unchanged).
	"""
	import shlex

	if args.profile:
		conn_ec2 = get_aws_session(args.profile).client('ec2')
	else:
		conn_ec2 = get_aws_subaccount_credentials(args, 'ec2')

	instance = conn_ec2.describe_instances(InstanceIds=[instance_id])['Reservations'][0]['Instances'][0]

	key_path = os.path.join(os.path.expanduser("~"), ".ssh", instance['KeyName'])

	# Download the ssh key from s3 in case it is not present locally.
	if not os.path.isfile(key_path):
		get_ssh_key_from_s3(args, instance)

	# Prefer the private IP; fall back to the private DNS name when the IP is absent.
	# (The previous fallback tested for 'PrivateIpAddress' after it was known to be
	# missing, so it could never trigger.)
	host = instance.get('PrivateIpAddress') or instance.get('PrivateDnsName')
	if not host:
		raise KeyError('Instance ' + instance_id + ' has neither PrivateIpAddress nor PrivateDnsName')

	user = 'ec2-user'

	if custom_ssh:
		ssh_command = shlex.split(custom_ssh)
	else:
		# IdentitiesOnly keeps ssh from offering agent keys instead of the downloaded one.
		ssh_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', key_path]
	ssh_command.append(user + '@' + host)
	if command:
		ssh_command.append(command)

	print('INFO: Running ' + ' '.join(ssh_command))
	returncode = subprocess.call(ssh_command)
	if returncode != 0:
		raise CommandError('An error occurred while running: ' + ssh_command[0] + '.')

def main(args):
	"""Let the user pick an instance of args.environment_name and open an SSH session to it.

	Exits with status 1 if the environment has no instances.
	"""
	instances_list = get_instances_list(args)

	# Without this check the selection prompt could never accept an answer.
	if not instances_list:
		print('ERROR: No instances found in environment ' + args.environment_name + '.', file=sys.stderr)
		sys.exit(1)

	# Prompt the user to choose an instance
	instance_chosen = prompt_for_instance_in_list(instances_list)

	# Open the ssh connection using the internal address
	ssh_into_instance(args, instances_list[instance_chosen])


### Init ###

# "ssh" (or no arguments, which makes argparse print usage) is handled by this wrapper;
# every other invocation is forwarded verbatim to the official `eb` CLI.
if 'ssh' in sys.argv[1:] or len(sys.argv) == 1:

	parser = argparse.ArgumentParser(add_help=False)

	parser.add_argument('ssh')

	# --profile is required, so the helpers' sub-account fallback (used when
	# args.profile is falsy) is only reached if an empty --profile is passed.
	parser.add_argument('--profile', help='--profile <application-[production|staging]>', required=True)
	parser.add_argument('environment_name', help='<application-[production|staging]-[web|activejobs|cronjobs]>')

	args = parser.parse_args()

	main(args)

else:
	print('Forwarding command to the official awsebcli..')

	aws_eb_cli = ['eb'] + sys.argv[1:]

	# Propagate eb's exit status so scripts calling this wrapper can detect failures.
	sys.exit(subprocess.call(aws_eb_cli))




