#!/usr/bin/env python
from __future__ import print_function
import argparse
import getpass
import logging
import netifaces
import os
import serial
import serial.tools.list_ports
import sys
import time
from collections import defaultdict
from cloudmanager_micropython_esp8266 import FIRMWARE_FILE, determine_default_serial_port


# Flash size label. NOTE(review): nothing in the visible code reads this
# constant; the command-line entry point sets flash_size to 'detect' itself.
FLASH_SIZE = '32m'
# Instructions printed before flashing when --board is 'huzzah'.
FLASH_MODE_HUZZAH = """Put your device into flash mode, using the following steps
   * Press and hold the flash button on the top of the board
   * Press and release the reset button on the top of the board
   * Release the flash button on the top of the board

On completion of these steps the red led on the board should be lit.
"""
# Baud rates accepted by --baud; also appended to the list of rates tried
# when sending settings to the board.
BAUD_RATES = ['9600', '115200', '460800', '921600']

# Module logger; the esptool command lines and their output are logged at
# debug level.
logger = logging.getLogger('flash_esp_image')


class NoResponse(Exception):
    """Raised when the board never shows a ``>>>`` prompt while being configured."""


class FlashFailed(Exception):
    """Raised when the board reports ``Fatal exception`` instead of a prompt."""


def get_service_addresses():
    """Return the IPv4 addresses configured on this host's network interfaces.

    The ``docker0`` interface is ignored, as are address entries that carry
    a ``peer`` key and the loopback addresses. Addresses are returned in the
    order netifaces reports them.
    """
    ignored_interfaces = ['docker0']
    loopback_addresses = ['::1', '127.0.0.1']
    found = []
    for name in netifaces.interfaces():
        if name in ignored_interfaces:
            continue
        # Only AF_INET is queried: netifaces was returning odd values for
        # IPv6 addresses, so AF_INET6 lookups are disabled for now.
        for entry in netifaces.ifaddresses(name).get(netifaces.AF_INET, []):
            if 'peer' in entry.keys() or entry['addr'] in loopback_addresses:
                continue
            found.append(entry['addr'])
    return found


def _to_serial_bytes(text):
    """Return ``text`` as bytes; ``Serial.write`` rejects unicode strings."""
    if isinstance(text, bytes):
        return text
    return text.encode('utf-8')


def _echo_serial(data):
    """Print data read from the board without appending a newline."""
    if not isinstance(data, str):
        # Python 3: decode so the output is not shown as a b'...' repr.
        data = data.decode('utf-8', 'replace')
    print(data, end='')


def _repl_string_literal(value):
    """Render ``value`` as a single-quoted Python string literal for the REPL."""
    text = str(value)
    # Backslashes are escaped first so the escapes added afterwards are not
    # doubled; CR/LF are escaped because a raw one would end the REPL line.
    for raw, escaped in (('\\', '\\\\'), ("'", "\\'"), ('\r', '\\r'), ('\n', '\\n')):
        text = text.replace(raw, escaped)
    return "'{0}'".format(text)


def _send_command(ser, command):
    """Send one line to the REPL and echo the next line the board returns."""
    ser.write(_to_serial_bytes(command + '\r'))
    _echo_serial(ser.readline())


def _wait_for_prompt(ser):
    """Send Ctrl-C/Ctrl-D and read lines until the ``>>>`` prompt shows up.

    Raises:
        FlashFailed: a line containing ``Fatal exception`` was read.
        NoResponse: more than 60 lines were read without seeing a prompt.
    """
    ser.write(b'\3\4\r')
    time.sleep(.4)
    line = ser.readline()
    logger.debug('%r', line)
    ser.write(b'\r')
    reads = 0
    while b'>>>' not in line:
        if b'Fatal exception' in line:
            raise FlashFailed('Flashing of board failed')
        line = ser.readline()
        _echo_serial(line)
        reads += 1
        if reads > 60:
            raise NoResponse('Board configuraton failed')

    # Skip any additional prompt lines and echo the first line that follows.
    line = ser.readline()
    while b'>>>' in line:
        line = ser.readline()
    _echo_serial(line)


def _configure_webrepl(ser, password):
    """Answer the interactive ``webrepl_setup`` dialogue on the board.

    NOTE(review): the answers are presumably "enable", the password twice,
    then "no" to a reboot question -- confirm against the firmware's
    webrepl_setup module.
    """
    _send_command(ser, 'import webrepl_setup')
    _send_command(ser, 'e')
    time.sleep(.5)
    _echo_serial(ser.readline())
    for _ in range(2):
        _send_command(ser, password)
        time.sleep(.5)
        _echo_serial(ser.readline())
    ser.write(b'n\r')
    time.sleep(.5)
    _echo_serial(ser.readline())


def _reset_and_wait(ser):
    """Reset the board and wait briefly for it to print a start-up marker."""
    time.sleep(1)
    print('Resetting board')
    _send_command(ser, '\3\4')
    line = ser.readline()
    _echo_serial(line)
    # readline() returns bytes, so the markers must be bytes as well.
    markers = (b'Cloud', b'WebREPL', b'MicroPython', b'>>>')
    polls = 0
    while not any(marker in line for marker in markers):
        time.sleep(.1)
        line = ser.readline()
        polls += 1
        if polls > 50:
            break
    _echo_serial(ser.read_all())


def send_settings_to_repl(setting_dict, baudrate=115200):
    """Configure a freshly flashed board by typing commands into its REPL.

    Opens the serial port, waits for the ``>>>`` prompt and, when a wifi
    SSID is supplied, stores the settings through ``bootconfig.config.set``,
    enables the requested services and finally resets the board. Everything
    the board sends back is echoed to stdout.

    Args:
        setting_dict: mapping with a required ``port`` entry and optional
            ``wifi_ssid``, ``wifi_password``, ``cloudmanager_server``,
            ``cloudmanager_port``, ``name``, ``services`` (a collection of
            service names) and ``webrepl_password`` entries.
        baudrate: serial speed; anything ``int()`` accepts.

    Raises:
        FlashFailed: the board printed ``Fatal exception`` before a prompt.
        NoResponse: no prompt appeared within the read budget.
        ValueError: ``webrepl`` is enabled but no ``webrepl_password`` is set.
    """
    ser = serial.Serial(setting_dict['port'])
    try:
        ser.baudrate = int(baudrate)
        ser.bytesize = serial.EIGHTBITS  # number of bits per byte
        ser.parity = serial.PARITY_NONE  # no parity check
        ser.stopbits = serial.STOPBITS_ONE  # number of stop bits
        ser.timeout = 1  # read timeout in seconds
        ser.xonxoff = False  # disable software flow control
        ser.rtscts = False  # disable hardware (RTS/CTS) flow control
        ser.dsrdtr = False  # disable hardware (DSR/DTR) flow control
        ser.writeTimeout = 2  # write timeout in seconds

        _wait_for_prompt(ser)

        # Without a wifi network there is nothing to store on the board.
        if not setting_dict.get('wifi_ssid'):
            return

        for command in ('import os', "os.mkdir('etc')", 'from bootconfig.config import set'):
            _send_command(ser, command)

        # (key in setting_dict, key stored on the device)
        device_settings = (
            ('wifi_ssid', 'wifi_ssid'),
            ('wifi_password', 'wifi_password'),
            ('cloudmanager_server', 'redis_server'),
            ('cloudmanager_port', 'redis_port'),
            ('name', 'name'),
        )
        for source_key, device_key in device_settings:
            value = setting_dict.get(source_key)
            if value:
                _send_command(ser, 'set({0}, {1})'.format(
                    _repl_string_literal(device_key), _repl_string_literal(value)))

        services = setting_dict.get('services') or ()
        if 'bootconfig' in services:
            _send_command(ser, 'import bootconfig.service')
            _send_command(ser, 'bootconfig.service.autostart()')
        if 'cloudclient' in services:
            _send_command(ser, 'import redis_cloudclient.service')
            _send_command(ser, 'redis_cloudclient.service.autostart()')
        if 'webrepl' in services:
            # Read the password from the settings rather than a module global
            # so the function also works when imported.
            password = setting_dict.get('webrepl_password')
            if password is None:
                raise ValueError('webrepl service enabled without a webrepl_password')
            _configure_webrepl(ser, password)
        if setting_dict.get('wifi_password'):
            _send_command(ser, 'from bootconfig.device import configure_device')
            _send_command(ser, 'configure_device()')

        _reset_and_wait(ser)
    finally:
        # Release the port so the caller can retry at another baud rate.
        ser.close()


def header(message):
    """Print ``message`` framed above and below by a row of 80 asterisks."""
    rule = '*' * 80
    for row in (rule, message, rule):
        print(row)


def prompt(message):
    """Show ``message`` and return the line the user types, on Python 2 or 3."""
    # Python 2 needs raw_input; the conditional keeps that name from being
    # looked up on Python 3, where it does not exist.
    reader = raw_input if sys.version_info.major == 2 else input
    return reader(message)


def erase_flash(port=None):
    """Erase the board's flash by running ``esptool.py erase_flash`` at 9600 baud.

    Args:
        port: serial port to pass to esptool. When omitted, the ``port``
            attribute of the module-level ``args`` namespace is used, which
            is what this function always did before.
    """
    if port is None:
        port = args.port
    command = 'esptool.py --port {port} --baud 9600 erase_flash'.format(port=port)
    logger.debug(command)
    pipe = os.popen(command)
    try:
        result = pipe.read()
    finally:
        # close() waits for esptool and returns its exit status (None on success).
        status = pipe.close()
    logger.debug(result)
    if status:
        print('Warning: esptool.py erase_flash exited with status %s' % status, file=sys.stderr)


if __name__ == '__main__':
    # Command-line entry point: parse the options, optionally erase the
    # flash, write the firmware with esptool.py, then push the settings to
    # the board over its serial REPL, retrying at several baud rates.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', default=determine_default_serial_port(), help='Serial port the esp8266 device is connected to')
    parser.add_argument('--baud', default='115200', choices=BAUD_RATES, help='Serial baud rate, choices: 9600, 115200 460800, 921600')
    parser.add_argument('--board', default='nodemcu', choices=['nodemcu', 'wemosd1', 'huzzah'], help='Type of board being flashed')
    parser.add_argument('--flash_mode', default='qio', choices=['qio', 'dio'], help='Flash mode')
    parser.add_argument('--skip_erase', dest='erase', default=True, action='store_false', help='Skip the flash erase step')
    parser.add_argument('--firmware_file', default=FIRMWARE_FILE, help='Firmware image file to flash')
    settings = parser.add_argument_group('Settings')
    settings.add_argument('--wifi_ssid', default=None, help='Connect to a specific wifi network')
    settings.add_argument('--wifi_password', default=None, help='WPA passphrase for wifi')
    settings.add_argument('--cloudmanager_server', default=None, help='Cloudmanager server to connect to')
    settings.add_argument('--cloudmanager_port', default=None, help='Cloudmanager port number')
    settings.add_argument('--name', default=None, help='The name of the board')
    settings.add_argument(
        '--enable_service', dest='services', nargs='?', action='append',
        choices=['bootconfig', 'cloudclient', 'webrepl'],
        help='Services to enable, default=[cloudclient]'
    )
    settings.add_argument(
        '--disable_service', dest='disable_services', nargs='?', action='append',
        choices=['bootconfig', 'cloudclient', 'webrepl'],
        help='Services to disable'
    )
    args = parser.parse_args()

    # esptool is always asked to detect the flash size itself.
    args.flash_size = 'detect'

    if args.board in ['huzzah']:
        print(FLASH_MODE_HUZZAH)
        # prompt() instead of input(): on Python 2 input() evaluates the
        # reply, so just pressing enter would raise.
        prompt('Press enter to continue')

    # Without explicit services, cloudclient is enabled when wifi is configured.
    if not args.services:
        args.services = []
        if args.wifi_ssid:
            args.services.append('cloudclient')

    if args.disable_services:
        # Filter rather than list.remove(): disabling a service that is not
        # enabled must not abort the run with a ValueError.
        args.services = [
            service for service in args.services
            if service not in args.disable_services
        ]

    if 'webrepl' in args.services:
        args.webrepl_password = getpass.getpass('WebRepl Password: ')

    if 'cloudclient' in args.services:
        if args.cloudmanager_server is None:
            # Default to the first address found on this host.
            local_addresses = get_service_addresses()
            if not local_addresses:
                parser.error('no local network address found, please pass --cloudmanager_server')
            args.cloudmanager_server = local_addresses[0]
        if args.cloudmanager_port is None:
            args.cloudmanager_port = 18266

    if args.erase:
        header('Erasing flash')
        erase_flash()
        print('Done\n')

    # One template shared by the initial flash and every retry.
    flash_command = (
        'esptool.py --port {port} --baud {baud} write_flash --verify '
        '--flash_size={flash_size} --flash_mode={flash_mode} 0 {firmware_file}'
    )

    header('Flashing board')
    command = flash_command.format(**vars(args))
    logger.debug(command)
    os.system(command)

    # Settings not given on the command line read as None.
    arguments = defaultdict(lambda: None)
    arguments.update(vars(args))
    attempt = 1
    for baudrate in [args.baud, args.baud, '115200', '115200', '9600', '9600'] + BAUD_RATES:
        try:
            header('Sending settings at %s baud' % baudrate)
            send_settings_to_repl(arguments, baudrate=baudrate)
            sys.exit(0)
        except NoResponse:
            # No prompt at this speed; fall through to the next baud rate.
            # prompt('### Press the reset button on your board and press [enter] ###')
            continue
        except FlashFailed:
            # The board crashed on boot: reflash with the other flash mode.
            attempt += 1
            print('\n### Flashing failed with flash mode set to %s' % arguments['flash_mode'], end='')
            arguments['flash_mode'] = 'qio' if arguments['flash_mode'] == 'dio' else 'dio'
            print(', retrying with flash mode set to %s ###\n' % arguments['flash_mode'])
            header('Flashing board, attempt #%d' % attempt)
            command = flash_command.format(**arguments)
            logger.debug(command)
            os.system(command)
        print('')

    # Every attempt failed: report it and exit non-zero instead of falling
    # off the end of the script with a success status.
    sys.exit('Unable to send settings to the board, giving up')
