#!/bin/bash

# Helper script that connects to a remote managed instance through SSM and starts SSH port forwarding.
# On every run it generates a new SSH key pair at ~/.ssh/sagemaker-ssh-gw, uploads the public key
# to S3, and runs a remote SSM command that copies the key from S3 onto the instance.

# Syntax:
# sm-connect-ssh-proxy <instance_id> <s3_ssh_authorized_keys_path> <extra_ssh_args>

set -e

# Both positional arguments are mandatory. Without this check a missing
# argument makes `shift` fail under `set -e`, and the script dies silently.
if [[ -z "${1:-}" || -z "${2:-}" ]]; then
  echo "Usage: ${0##*/} <instance_id> <s3_ssh_authorized_keys_path> <extra_ssh_args>" >&2
  exit 1
fi

# $1 - ID of the managed instance to connect to
# $2 - S3 location the public key is uploaded to
INSTANCE_ID="$1"
SSH_AUTHORIZED_KEYS_PATH="$2"
shift 2

# Everything after the first two arguments is joined into one string of extra
# SSH arguments; it is word-split again when it is placed on the ssh command line.
PORT_FWD_ARGS=$*

echo "$(date -Iseconds) sm-connect-ssh-proxy: Connecting to: $INSTANCE_ID"
echo "$(date -Iseconds) sm-connect-ssh-proxy: Extra args: $PORT_FWD_ARGS"

# Query SSM for the ping status of the first instance matching INSTANCE_ID.
ping_status=$(
  aws ssm describe-instance-information \
    --filters Key=InstanceIds,Values="$INSTANCE_ID" \
    --query 'InstanceInformationList[0].PingStatus' \
    --output text
)

echo "Instance status: $ping_status"

# Any status other than "Online" aborts the script with exit code 1.
[[ "$ping_status" == "Online" ]] || {
  echo "Error: Instance is offline."
  exit 1
}

# TODO: make it possible to override the default (also helps avoid race conditions)
SSH_KEY=~/.ssh/sagemaker-ssh-gw

echo "Generating $SSH_KEY keypair with ECDSA and uploading public key to $SSH_AUTHORIZED_KEYS_PATH"

# Create a passphrase-less ECDSA key pair. "yes" is piped to ssh-keygen so that
# a confirmation prompt for replacing an existing key is answered non-interactively.
printf '%s\n' 'yes' | ssh-keygen -q -t ecdsa -N '' -f "${SSH_KEY}" >/dev/null

# Upload the public half to the S3 location given on the command line.
aws s3 cp "${SSH_KEY}.pub" "${SSH_AUTHORIZED_KEYS_PATH}"

# Resolve the region the AWS CLI is configured to use. `aws configure list`
# prints one row per setting; only the row whose first column is exactly
# "region" is read, so another row that merely contains the word "region"
# (for example a profile name) cannot end up in the value.
CURRENT_REGION=$(aws configure list | awk '$1 == "region" {print $2}')

# An unset region is reported as "<not set>", of which awk sees "<not".
# Fail here with a clear message instead of passing a bogus --region later.
if [[ -z "$CURRENT_REGION" || "$CURRENT_REGION" == "<not" ]]; then
  echo "Error: AWS region is not configured. Set AWS_DEFAULT_REGION or run 'aws configure'." >&2
  exit 1
fi
echo "Will use AWS Region: $CURRENT_REGION"

AWS_CLI_VERSION=$(aws --version)
echo "AWS CLI version (should be v2): $AWS_CLI_VERSION"

# TODO: consider moving to start-session with AWS-StartNonInteractiveCommand

# Run a shell script on the instance through SSM that downloads every key found
# under SSH_AUTHORIZED_KEYS_PATH into /etc/ssh/authorized_keys.d/ and
# concatenates them into /etc/ssh/authorized_keys.
# The command ID is extracted by the CLI itself (--query / --output text), so
# no external JSON parsing (python, grep, sed) is needed.
echo "Running SSM commands at region ${CURRENT_REGION} to copy public key to ${INSTANCE_ID}"
command_id=$(aws ssm send-command \
    --region "${CURRENT_REGION}" \
    --instance-ids "${INSTANCE_ID}" \
    --document-name "AWS-RunShellScript" \
    --comment "Copy public key for SSH helper" \
    --timeout-seconds 30 \
    --parameters "commands=[
        'mkdir -p /etc/ssh/authorized_keys.d/',
        'aws s3 cp --recursive \"${SSH_AUTHORIZED_KEYS_PATH}\" /etc/ssh/authorized_keys.d/',
        'ls -la /etc/ssh/authorized_keys.d/',
        'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys',
        'ls -la /etc/ssh/authorized_keys'
      ]" \
    --no-cli-pager --no-paginate \
    --query 'Command.CommandId' \
    --output text)
echo "Got command ID: $command_id"

if [[ -z "$command_id" || "$command_id" == "None" ]]; then
  echo "Error: SSM did not return a command ID" >&2
  exit 2
fi

#######################################
# Print a single field of the command invocation as plain text.
# Globals:   INSTANCE_ID (read), command_id (read)
# Arguments: $1 - JMESPath expression selecting the field
# Outputs:   the selected value on stdout
# Returns:   exit status of the AWS CLI call
#######################################
get_invocation_field() {
  aws ssm get-command-invocation \
      --instance-id "${INSTANCE_ID}" \
      --command-id "${command_id}" \
      --no-cli-pager --no-paginate \
      --query "$1" \
      --output text
}

# Switch to unicode for AWS CLI to properly parse output
export LC_CTYPE=en_US.UTF-8

# Wait a little bit to prevent strange InvocationDoesNotExist error
sleep 5

# Poll up to 15 times, one second apart, until the command leaves the
# Pending/InProgress states.
command_status="Pending"
for ((attempt = 1; attempt <= 15; attempt++)); do
  # A failed lookup (such as the InvocationDoesNotExist error mentioned above)
  # is retried on the next iteration instead of aborting the whole script.
  if ! command_status=$(get_invocation_field 'Status'); then
    echo "Command status: not available yet, retrying"
    command_status="Pending"
    sleep 1
    continue
  fi

  echo "Command status: $command_status"
  if [[ "$command_status" != "Pending" && "$command_status" != "InProgress" ]]; then
    # Fetch stdout/stderr only once the command has finished; a failure to
    # fetch them must not hide the final status.
    output_content=$(get_invocation_field 'StandardOutputContent') || output_content=""
    error_content=$(get_invocation_field 'StandardErrorContent') || error_content=""
    echo "Command output: $output_content"
    # Text output renders a missing value as "None"; treat it like empty.
    if [[ -n "$error_content" && "$error_content" != "None" ]]; then
      echo "Command error: $error_content"
    fi
    break
  fi
  sleep 1
done

# Anything but "Success" (including still pending after the last attempt) is fatal.
if [[ "$command_status" != "Success" ]]; then
  echo "Error: Command didn't finish successfully in time"
  exit 2
fi

echo "$(date -Iseconds) sm-connect-ssh-proxy: Starting SSH over SSM proxy"

# SSH is tunnelled through an SSM session (AWS-StartSSHSession) rather than
# using SSM's AWS-StartPortForwardingSession, because forwarding is needed in
# both directions (ssh -L and -R). For example, forwarding the PyCharm license
# server needs -R, while SSM only forwards ports from the server (like -L).
# %p is left in place for ssh to substitute in the ProxyCommand.
proxy_command="aws ssm start-session"
proxy_command+=" --reason 'Local user started SageMaker SSH Helper'"
proxy_command+=" --region '${CURRENT_REGION}'"
proxy_command+=" --target '${INSTANCE_ID}'"
proxy_command+=" --document-name AWS-StartSSHSession"
proxy_command+=" --parameters portNumber=%p"

# Fixed ssh options: IPv4 only, log in as root with the generated key and no
# password fallback, keep-alives every 15s, and no host key checking or recording.
ssh_options=(
  -4
  -o User=root
  -o IdentityFile="${SSH_KEY}"
  -o IdentitiesOnly=yes
  -o ProxyCommand="$proxy_command"
  -o ServerAliveInterval=15
  -o ServerAliveCountMax=3
  -o PasswordAuthentication=no
  -o StrictHostKeyChecking=no
  -o UserKnownHostsFile=/dev/null
)

# PORT_FWD_ARGS is intentionally unquoted so the extra arguments are split into words.
# shellcheck disable=SC2086
ssh "${ssh_options[@]}" $PORT_FWD_ARGS "$INSTANCE_ID"
