#!/usr/bin/env python3
import os

import pymysql
from prettytable import PrettyTable


def connect_db(db):
    """Open a pymysql connection described by the settings mapping *db*.

    *db* must provide the keys 'host', 'port', 'user', 'password' and
    'database'; a missing key raises KeyError before any connection
    attempt is made.
    """
    required_keys = ('host', 'port', 'user', 'password', 'database')
    connect_kwargs = {key: db[key] for key in required_keys}
    return pymysql.connect(**connect_kwargs)


def getTableList(conn):
    """Return the names of all tables visible on *conn*.

    Runs ``SHOW TABLES;`` and collects the first field of every row.
    The cursor is always closed, even when the query raises.
    """
    cursor = conn.cursor()
    try:
        cursor.execute("SHOW TABLES;")
        return [row[0] for row in cursor]
    finally:
        # Previously this call sat after the return and never executed.
        cursor.close()


def getDatabaseSchema(conn):
    """Build a mapping of table name -> table schema for every table on *conn*.

    The table names come from getTableList and each value is whatever
    getTableSchema returns for that table.
    """
    return {name: getTableSchema(name, conn) for name in getTableList(conn)}


def getTableSchema(table_name, conn):
    """Describe *table_name* and return a mapping of column name -> row.

    Each value is the full row returned by ``DESC`` for that column; the
    first field of the row is used as the key.

    *table_name* is treated as a single, unqualified identifier: it is
    wrapped in backticks (embedded backticks are doubled) so that names
    which are reserved words or contain special characters, e.g.
    ``order`` or ``my-table``, can be described. The cursor is always
    closed, even when the query raises.
    """
    quoted_name = '`' + table_name.replace('`', '``') + '`'
    cursor = conn.cursor()
    try:
        cursor.execute('desc ' + quoted_name)
        columns = cursor.fetchall()
    finally:
        cursor.close()

    return {column[0]: column for column in columns}


class TableDiff:
    """Column-by-column comparison of one table as seen in two schemas.

    ``table1`` and ``table2`` are mappings of column name -> 6-field row
    ``(name, type, nullable, key, default, extra)``; a column missing
    from one side is simply absent from that mapping.
    """

    def __init__(self, tableName, table1, table2):
        self.tableName = tableName
        self.table1 = table1
        self.table2 = table2

    # NOTE: the method name is misspelled ("defalut"); it is kept as-is so
    # existing callers keep working.
    def get_defalut_clause(self, default_value):
        """Render the DEFAULT clause for a column's default value.

        Returns "" when there is no default (None), ``DEFAULT ''`` for an
        empty-string default, and ``DEFAULT <value>`` otherwise.
        """
        if default_value is None:
            return ""
        # Compare by value: identity checks against a string literal are
        # implementation-dependent and warn on modern Python.
        if default_value == "":
            return "DEFAULT ''"
        # format() also copes with non-str defaults instead of raising.
        return "DEFAULT {}".format(default_value)

    def column_info(self, col):
        """Return a one-line description of a column row.

        Returns the string "None" when *col* is None (the column does not
        exist on that side). Only type, nullability, default and extra
        take part in the description; the name and key fields are ignored.
        """
        if col is None:
            return "None"
        # Unpacking also enforces that the row has exactly six fields.
        _name, col_type, col_nullable, _key, col_default, extra = col
        nullable_clause = " " if col_nullable == "YES" else " NOT NULL "
        return '{} {} {} {}'.format(
            col_type,
            nullable_clause,
            self.get_defalut_clause(col_default),
            extra)

    def printDiff(self):
        """Print a summary line and, if any columns differ, a diff table.

        Columns are compared through their column_info descriptions, in
        sorted name order; differing descriptions are shown highlighted.
        """
        diffTable = PrettyTable(["col_name", "d1", "d2"])
        col_names = sorted(set(self.table1) | set(self.table2))
        diffcount = 0
        for column_name in col_names:
            col1 = self.column_info(self.table1.get(column_name))
            col2 = self.column_info(self.table2.get(column_name))
            if col1 != col2:
                diffcount += 1
                diffTable.add_row(
                    [column_name, highlight(col1), highlight(col2)])
        print('========================= {} has {} different columns ==========================='.format(
            self.tableName, diffcount))
        if diffcount > 0:
            print(diffTable)


def highlight(s):
    """Wrap the string *s* in ANSI escape codes so it renders in red."""
    red_on, reset = '\033[31m', '\033[0m'
    return red_on + s + reset


def diffSchema(schema1, schema2):
    """Print the differences between two database schemas.

    Both arguments map table name -> table schema. Tables present on
    both sides are compared column by column via TableDiff.printDiff as
    they are visited (in sorted name order). Tables missing from either
    side are collected into a summary table printed at the end, where the
    d1/d2 cells are True when the table is missing from that schema.
    """
    missing_report = PrettyTable(["table_name", "d1", "d2"])
    all_names = sorted(set(schema1) | set(schema2))

    for name in all_names:
        left = schema1.get(name)
        right = schema2.get(name)
        left_missing = left is None
        right_missing = right is None

        if left_missing or right_missing:
            missing_report.add_row([name, left_missing, right_missing])
            continue

        TableDiff(name, left, right).printDiff()

    print(missing_report)


# SECURITY NOTE(review): database passwords are committed in this file.
# They can be overridden through the SCHEMA_DIFF_DB1_PASSWORD /
# SCHEMA_DIFF_DB2_PASSWORD environment variables; the literals remain
# only as a backward-compatible fallback and should be rotated and
# removed from source control.
db1 = {
    'host': 'mysql-prod.blacklake.tech',
    'port': 3306,
    'database': 'manufacture',
    'user': 'manufacture_user',
    'password': os.environ.get('SCHEMA_DIFF_DB1_PASSWORD',
                               'sd__wnf2JIUEHNsfwF3_s2f'),
}

db2 = {
    'host': '10.1.30.10',
    'port': 3306,
    'user': 'manufacture_dev_user',
    'password': os.environ.get('SCHEMA_DIFF_DB2_PASSWORD',
                               'sdfiws_Sdf2sdfj8fsk'),
    'database': 'manufacture_dev',
}


def main():
    """Load both schemas, close the connections, then print the diff."""
    conn1 = connect_db(db1)
    try:
        conn2 = connect_db(db2)
        try:
            schema1 = getDatabaseSchema(conn1)
            schema2 = getDatabaseSchema(conn2)
        finally:
            conn2.close()
    finally:
        conn1.close()

    diffSchema(schema1, schema2)


# Only connect when run as a script, so importing this module (e.g. to
# reuse TableDiff) does not open database connections.
if __name__ == "__main__":
    main()
