#!/usr/bin/env python

## IMPORTS ##
import curses
import time
import sys
import os
from sqlalchemy import *

import sqlcrush
from sqlcrush import database
from sqlcrush import user_input

## GLOBALS ##
x = 1
last_x = 0
term_size_change = False
option_window_open = False

## FUNCTIONS ##
#initialize the curses window and return scr
def init_scr():
    """Start curses and return the configured standard screen.

    Input echo is disabled, cbreak and then half-delay mode (reads time out
    after 1/10 s) are enabled, the cursor is hidden, and keypad translation
    is turned on so special keys such as arrows arrive as single key codes.
    The screen is cleared before it is returned.
    """
    stdscr = curses.initscr()

    # input handling; half-delay must be set after cbreak, which would
    # otherwise cancel it
    curses.noecho()
    curses.cbreak()
    curses.curs_set(0)      # hide the cursor
    curses.halfdelay(1)

    stdscr.keypad(True)     # decode arrow/function keys
    stdscr.clear()
    return stdscr

# use scr to terminate the curses window and revert back to the normal terminal
def term_scr(scr):
    """Undo init_scr(): restore line-buffered, echoing input and leave curses.

    Args:
        scr: the screen returned by init_scr().
    """
    # restore the input settings first ...
    curses.nocbreak()
    scr.keypad(False)
    curses.echo()
    # ... then hand the terminal back to the shell
    curses.endwin()

# returns the (rows, columns) size of scr
def get_scr_dim(scr):
    """Return the size of *scr* as a ``(rows, columns)`` tuple."""
    dimensions = scr.getmaxyx()
    return dimensions

#returns True if there has been a change in the window size, otherwise False
def check_term_size_change(scr, scr_dim):
    """Tell whether *scr* no longer has the size recorded in *scr_dim*.

    Args:
        scr: the curses screen to inspect.
        scr_dim: the (rows, columns) tuple previously read from scr.getmaxyx().

    Returns:
        True if the terminal has been resized since scr_dim was taken,
        otherwise False.
    """
    return scr.getmaxyx() != scr_dim

def open_top_bar(scr_dim):
    """Return a new 4-row window spanning the full terminal width at row 0.

    Args:
        scr_dim: (rows, columns) of the whole terminal.
    """
    bar_height = 4
    terminal_width = scr_dim[1]
    return curses.newwin(bar_height, terminal_width, 0, 0)

def open_front_main(scr_dim):
    """Create the front (help) window shown when no database is open.

    The window fills the terminal below the 4-row top bar. Terminals wider
    than 64 columns get the long help text with command-line examples;
    narrower ones get a compact version. The closing note is only added in
    the wide layout when the terminal is taller than 25 rows.

    Args:
        scr_dim: (rows, columns) of the whole terminal, as from getmaxyx().

    Returns:
        The new bordered curses window (not yet refreshed).
    """
    scr_front_main = curses.newwin(scr_dim[0]-4, scr_dim[1], 4, 0)

    # key bindings shared by both layouts, drawn on consecutive rows
    key_help = [
        "[h] Toggle help window",
        "[Arrows] Move around database",
        "[<>] Move between headers for each table",
        "[delete] Move back to table select",
        "[f] Find in database",
        "[n] New database entry or new execution",
        "[u] Update database entry",
        "[d] Delete database entry, cell or execution",
        "[k] Query mode",
        "[q] Quit",
        "[s] Save database",
    ]

    if scr_dim[1] > 64:
        help_rows = [
            (2, "HELP:"),
            (3, "Make sure to open the database from the command line"),
            (4, "e.g. sqlcrush -t sqlite -d demo.sqlite3"),
            (5, "e.g. sqlcrush -t postgres -d demo -u johnsmith -h localhost"),
            # own row so it does not overwrite the postgres example; the
            # command-line password flag is -pd (-p sets the port)
            (6, "e.g. sqlcrush -t mysql -d dev -u john -pd pass -h devc -s /tmp/mysql.sock"),
        ]
        help_rows += [(7 + i, text) for i, text in enumerate(key_help)]
        if scr_dim[0] > 25:
            help_rows += [
                (19, "Thank you for using SQLcrush. This was made to allow SQL"),
                (20, "database manipulation to be done right in the console."),
            ]
    else:
        help_rows = [
            (2, "HELP:"),
            (3, "Open from shell"),
            (4, "e.g. sqlcrush demo.sqlite3"),
        ]
        help_rows += [(6 + i, text) for i, text in enumerate(key_help)]

    for row, text in help_rows:
        scr_front_main.addstr(row, 2, text)

    scr_front_main.border(0)

    return scr_front_main

def open_show_left(scr_dim):
    """Return the bordered 16-column left sidebar.

    It starts just below the 4-row top bar and leaves 3 rows free at the
    bottom for the bottom bar.

    Args:
        scr_dim: (rows, columns) of the whole terminal.
    """
    top_bar_rows, bottom_bar_rows, sidebar_cols = 4, 3, 16
    sidebar_rows = scr_dim[0] - top_bar_rows - bottom_bar_rows
    sidebar = curses.newwin(sidebar_rows, sidebar_cols, top_bar_rows, 0)
    sidebar.border(0)
    return sidebar

def open_show_main(scr_dim):
    """Return the bordered main pane to the right of the 16-column sidebar.

    It spans from below the 4-row top bar to above the 3-row bottom bar and
    takes all remaining width.

    Args:
        scr_dim: (rows, columns) of the whole terminal.
    """
    top_bar_rows, bottom_bar_rows, sidebar_cols = 4, 3, 16
    pane_rows = scr_dim[0] - top_bar_rows - bottom_bar_rows
    pane_cols = scr_dim[1] - sidebar_cols
    main_pane = curses.newwin(pane_rows, pane_cols, top_bar_rows, sidebar_cols)
    main_pane.border(0)
    return main_pane

def open_query_main(scr_dim):
    """Return the bordered full-width query pane.

    It sits between the 4-row top bar and the 3-row bottom bar.

    Args:
        scr_dim: (rows, columns) of the whole terminal.
    """
    top_bar_rows, bottom_bar_rows = 4, 3
    pane_rows = scr_dim[0] - top_bar_rows - bottom_bar_rows
    query_pane = curses.newwin(pane_rows, scr_dim[1], top_bar_rows, 0)
    query_pane.border(0)
    return query_pane

def open_bottom_bar(scr_dim):
    """Return a new 3-row, full-width window anchored to the bottom of the terminal.

    Args:
        scr_dim: (rows, columns) of the whole terminal.
    """
    bar_height = 3
    first_row = scr_dim[0] - bar_height
    return curses.newwin(bar_height, scr_dim[1], first_row, 0)

# refreshes each of the main windows
def refresh_windows(current_screen, scr_top, scr_front_main, scr_show_left, scr_show_main, scr_bottom, scr_query_main):
    """Refresh the windows that make up the active screen.

    The top bar is always refreshed first. Screen 1 adds the front/help
    window; screen 2 adds the sidebar, main pane and bottom bar; screen 3
    adds the query pane and bottom bar. Any other value refreshes only the
    top bar.
    """
    windows_per_screen = {
        1: (scr_front_main,),
        2: (scr_show_left, scr_show_main, scr_bottom),
        3: (scr_query_main, scr_bottom),
    }

    scr_top.refresh()
    for window in windows_per_screen.get(current_screen, ()):
        window.refresh()

# sets up .sqlcrush folder
def create_environment():
    """Make sure ~/.sqlcrush/saved_databases exists.

    Creates the ~/.sqlcrush directory if needed. If the saved_databases file
    is missing, it is copied from the template next to the installed sqlcrush
    package, or failing that from ./sqlcrush/saved_databases in the current
    working directory. If neither copy succeeds, a file containing only the
    explanatory header comments is written instead.
    """
    import shutil

    root_path = os.path.expanduser("~")
    config_dir = os.path.join(root_path, ".sqlcrush")
    saved_path = os.path.join(config_dir, "saved_databases")

    if not os.path.exists(config_dir):
        os.makedirs(config_dir)

    if os.path.isfile(saved_path):
        return

    # Candidate templates, in order of preference. shutil is used instead of
    # os.system("cp ..."): os.system never raises on failure (so a fallback
    # in an except clause would never run) and a shell command built from
    # unquoted paths breaks on spaces or shell metacharacters.
    templates = []
    package_file = getattr(sqlcrush, "__file__", None)
    if package_file:
        templates.append(os.path.join(os.path.dirname(package_file), "saved_databases"))
    templates.append(os.path.join(os.getcwd(), "sqlcrush", "saved_databases"))

    for template in templates:
        try:
            shutil.copyfile(template, saved_path)
            return
        except (IOError, OSError):
            continue  # try the next template, then fall back to the default file

    auto_content = "# SQLcrush saved databases\n# Edit manually or save opened database via the app\n# Format:\n# db_short_name database_type:///username:password@host/dbname current_working_directory (for sqlite3)\n# e.g. dev postgresql://johnsmith:test123@localhost/dev_db\n# e.g. devtest sqlite:///devtest.db /home/johnsmith/dbfiles/test1"
    with open(saved_path, "a") as f:
        f.write(auto_content)

## WORKFLOW ##
# Startup: make sure ~/.sqlcrush/saved_databases exists, call the intro
# routine from sqlcrush.database, then put the terminal into curses mode.
create_environment()
database.print_intro()          # NOTE(review): defined in sqlcrush.database; runs before curses starts

scr = init_scr()
scr_dim = get_scr_dim(scr)      # (rows, columns) of the terminal; every window is laid out from it

# declarations
# Global UI and connection state used by the rest of the script. A value of 0
# is used throughout as "not set" (e.g. no database given on the command line).
# cursor_main: position in the table list and header tabs.
#   cursor_main[0] + cursor_main[1] is the 1-based index of the highlighted
#   table in shown_tables, cursor_main[1] is the scroll offset (entries before
#   it are skipped when drawing), and cursor_main[2] selects the header tab
#   (index into header_list). cursor_main[3] is not read in the visible part
#   of this file.
cursor_main = [0, 0, 0, 0]
# cursor_sub: position inside the open table.
#   cursor_sub[0] + cursor_sub[1] selects the highlighted row,
#   cursor_sub[1] is the vertical scroll offset, cursor_sub[2] is the
#   highlighted column (0 highlights the whole row), and cursor_sub[3] is
#   the horizontal scroll offset in columns.
cursor_sub = [0, 0, 0, 0]
open_window = 0                 # 0 = table list focused, 1 = table contents focused (toggled by key code 9)
header_list = ["Struct", "Browse", "Execute"]   # tabs drawn across the top of the main pane
help_screen = 0                 # NOTE(review): not read in the visible part of this file
find_list = []                  # entry positions matched by a find; when non-empty, only these entries are drawn
shown_tables = []               # table names reflected from the open database
table_executions = {}           # table name -> list of executions recorded for that table
# connection settings, filled from command-line flags (0 = not given)
dbname = 0
user = 0
password = 0
host = 0
port = 0
socket = 0
database_type = 0
saved_database = 0              # short name of an entry in ~/.sqlcrush/saved_databases (-o)
current_real_database = 0
database_dir = 0                # working directory stored with a saved entry (sqlite), else 0
query_state = 0                 # 0 = favourite-query list, 1 = entering a new query
show_query = 0                  # 1 = Browse tab shows the result of a user query instead of a table
updatedData = 1                 # if 1, gets new database information

# Based on what the user inputs
if len(sys.argv) <= 1:
    # no arguments: start on the front (help) screen with no database
    current_database = 0
    current_screen = 1
else:
    arguments = len(sys.argv)
    # -t aliases -> dialect names used in connection URLs
    database_type_aliases = {
        "sqlite3": "sqlite",
        "postgres": "postgresql",
        "mariadb": "mysql",
        "maria": "mysql",
    }
    # Each flag takes the next argument as its value. The final argument has
    # no successor, so it is never treated as a flag: a trailing flag with no
    # value is ignored instead of raising IndexError. init_scr() has already
    # been called, so an uncaught error here would exit without restoring the
    # terminal.
    for n in range(arguments - 1):
        arg_flag = sys.argv[n]
        arg_value = sys.argv[n + 1]
        if arg_flag == "-d":
            dbname = arg_value
            current_real_database = dbname
            current_database = dbname
        elif arg_flag == "-u":
            user = arg_value
        elif arg_flag == "-pd":
            password = arg_value
        elif arg_flag == "-h":
            host = arg_value
        elif arg_flag == "-p":
            port = arg_value
        elif arg_flag == "-o":
            saved_database = arg_value
        elif arg_flag == "-s":
            socket = arg_value
            # socket paths are made absolute (also safe for an empty value)
            if not socket.startswith("/"):
                socket = "/" + socket
        elif arg_flag == "-t":
            database_type = arg_value.lower()
            database_type = database_type_aliases.get(database_type, database_type)

# attempts to open the database
n = 0
while n < 3:
    try:
        open_database = database.connect_database(n, current_real_database, dbname, user, host, password, port, socket, database_type, saved_database)
        if open_database == 0:
            # a 0 return is treated as "no database": show the front screen
            current_screen = 1
            current_database = 0
            n = 3
        else:
            current_screen = 2
            n = 3
            current_database = dbname
            if saved_database != 0:
                # -o was given: recover the type, name and (for sqlite) working
                # directory of the saved connection from ~/.sqlcrush/saved_databases
                root_path = os.path.expanduser("~")
                with open(root_path + "/.sqlcrush/saved_databases", "r") as f:
                    saved_dbs = f.readlines()
                length_save_name = len(saved_database)

                # Entry format: short_name connection_url [working_directory].
                # split() (rather than split(" ")) also drops the trailing newline
                # left by readlines(). If several entries match, the last one wins.
                saved_entry = None
                for line in saved_dbs:
                    fields = line.split()
                    if len(fields) >= 2 and fields[0] == saved_database:
                        saved_entry = fields
                if saved_entry is None:
                    # handled like any other failure below (retry, then front screen)
                    raise LookupError("no saved database named " + str(saved_database))
                sql_connect = saved_entry[1]

                database_type = sql_connect.split("//")[0][:-1]   # "postgresql://..." -> "postgresql"
                current_database = sql_connect.split("/")[-1]
                current_real_database = current_database
                # the optional third field of the entry is the working directory
                if len(saved_entry) >= 3:
                    database_dir = saved_entry[2]
                else:
                    database_dir = 0
    except Exception:
        # tries to open database 3 times, then falls back to the front screen
        n = n + 1
        if n == 3:
            current_screen = 1
            current_database = 0
#main loop
while x != ord("q"):        # quit on [q]
    try:
        # check to see if there is a change in terminal size
        term_size_change = check_term_size_change(scr, scr_dim)
        if term_size_change == True:
            term_scr(scr)
            scr = init_scr()
            scr_dim = get_scr_dim(scr)
            term_size_change == False
        scr_dim = get_scr_dim(scr)

        scr_top = open_top_bar(scr_dim)
        scr_front_main = open_front_main(scr_dim)
        scr_show_left = open_show_left(scr_dim)
        scr_show_main = open_show_main(scr_dim)
        scr_bottom = open_bottom_bar(scr_dim)
        scr_query_main = open_query_main(scr_dim)

        scr.refresh()

        scr_top.addstr(0, 0, "SQLcrush v0.1.4 - by coffeeandscripts")

        if current_database == 0 or current_database == "0":
            scr_top.addstr(1, 0, "No open database")
        else:
            scr_top.addstr(1, 0, str(database_type) + " - " + str(current_real_database))
            
            # for each of the tables, adds up the number of executions
            execution_length = 0
            for table in table_executions:
                for execution in table_executions[table]:
                    execution_length = execution_length + 1

            scr_top.addstr(2, 0, str(execution_length) + " changes made")

            if current_database != 0 and current_database != 1:

                # tries to open the database again 3 times max
                n = 0
                if updatedData == 1:
                    while n < 3:
                        try:
                            open_database = database.connect_database(n, current_real_database, dbname, user, host, password, port, socket, database_type, saved_database)
                            n = 3
                            updatedData = 0
                        except:
                            n = n + 1
                            if n == 3:
                                #if there is a failure, it exits quickly after closing ncurses
                                term_scr(scr)
                                print("Critical failure")
                                time.sleep(1)
                                sys.exit()

                # use [enter] to switch windows
                if x == 9:          # [enter]
                    if open_window == 0:
                        open_window = 1
                    else:
                        open_window = 0
                scr_show_left.addstr(0, 2, str(current_real_database)[0:12], curses.A_REVERSE)

                # when window 1 is open, lists the headers
                n = 0
                for header in header_list:
                    if open_window == 1:
                        if cursor_main[2] == n:
                            scr_show_main.addstr(0, 2+n*8, str(header), curses.A_REVERSE)
                        else:
                            scr_show_main.addstr(0, 2+n*8, str(header))
                    else:
                        scr_show_main.addstr(0, 2+n*8, str(header))
                    n = n + 1
                
                # get the table MetaData from SQLalchemy
                all_tables = MetaData()
                all_tables.reflect(open_database)

                # lists the tables
                n = 2
                p = 0
                shown_tables = []
                for table in all_tables.tables.values():
                    shown_tables.append(table.name)
                    table_print = table.name
                    if table_print not in table_executions:
                        table_executions[table_print] = []
                    if n - 2 >= scr_dim[0] - 10:
                        continue
                    if cursor_main[1] >= p + 1:
                        p = p + 1
                        continue
                    if len(str(table_print)) >= 15:
                        if cursor_main[0] + cursor_main[1] == p + 1:
                            scr_show_left.addstr(n, 1, table_print[0:11] + "...", curses.A_REVERSE)
                        else:
                            scr_show_left.addstr(n, 1, table_print[0:11] + "...")
                    else:
                        if cursor_main[0] + cursor_main[1] == p + 1:
                            scr_show_left.addstr(n, 1, table_print, curses.A_REVERSE)
                        else:
                            scr_show_left.addstr(n, 1, table_print)
                    n = n + 1
                    p = p + 1

                # lists the columns for the selected table
                if open_window == 1:
                    # set the open table based on cursor_main
                    if show_query != 1:
                        table_inspector = inspect(open_database)
                        open_table = table_inspector.get_columns(shown_tables[cursor_main[0] + cursor_main[1] - 1])
                        n = 0
                        pk = 0
                        columns = []
                        for c in open_table:
                            column = []
                            column.append(n)
                            try:
                                column.append(c['name'])
                            except:
                                column.append("-")
                            try:
                                column.append(c['type'])
                            except:
                                column.append(0)
                            try:
                                column.append(c['nullable'])
                            except:
                                column.append(0)
                            column.append(c['default'])
                            # sometimes primary_key doesn't come up and so the first column will be assigned as pk
                            try:
                                column.append(c['primary_key'])
                            except:
                                if pk == 0:
                                    column.append(1)
                                    pk = 1
                                else:
                                    column.append(0)
                            # currently not used but could be relevant in future
                            try:
                                column.append(c['foreign_key'])
                            except:
                                column.append(0)
                            columns.append(column)
                            n = n + 1

                        # organise the presentation for the headers of the columns
                        if cursor_main[2] == 0:
                            if scr_dim[1] > 12 + 6:
                                scr_show_main.addstr(1, 2, "ID:")
                            if scr_dim[1] > 12 + 30:
                                scr_show_main.addstr(1, 6, "Name:")
                            if scr_dim[1] > 12 + 42:
                                scr_show_main.addstr(1, 30, "Type:")
                            if scr_dim[1] > 12 + 52:
                                scr_show_main.addstr(1, 42, "NotNull:")
                            if scr_dim[1] > 12 + 64:
                                scr_show_main.addstr(1, 52, "Default:")
                            n = 0
                            p = 0

                            # printing of the columns based on scr_dim and cursor_sub
                            for column in columns:
                                if n >= scr_dim[0] - 10:
                                    continue
                                if cursor_sub[1] >= p + 1:
                                    p = p + 1
                                    continue
                                # sets the printing length based on width of columns
                                id_print = str(column[0]) + " "
                                name_print = str(column[1]) + " "
                                if len(name_print) >= 24:
                                    name_print = name_print[0:20] + ".."

                                type_print = str(column[2]) + " "
                                if len(type_print) >= 12:
                                    type_print = type_print[0:8] + ".."

                                notnull_print = str(column[3]) + " "
                                if len(notnull_print) >= 10:
                                    notnull_print = notnull_print[0:6] + ".."

                                default_print = str(column[4]) + " "
                                if len(default_print) >= 11:
                                    default_print = default_print[0:7] + ".."

                                while len(id_print) < 3:
                                    id_print = " " + id_print
                                while len(name_print) < 24:
                                    name_print = name_print + " "
                                while len(type_print) < 12:
                                    type_print = type_print + " "
                                while len(notnull_print) < 10:
                                    notnull_print = notnull_print + " "
                                while len(default_print) < 12:
                                    default_print = default_print + " "

                                if cursor_sub[0] + cursor_sub[1] == p + 1:
                                    if scr_dim[1] > 12 + 6:
                                        scr_show_main.addstr(2+n, 3, id_print, curses.A_REVERSE)
                                    if scr_dim[1] > 12 + 30:
                                        scr_show_main.addstr(2+n, 6, name_print, curses.A_REVERSE)
                                    if scr_dim[1] > 12 + 42:
                                        scr_show_main.addstr(2+n, 30, type_print, curses.A_REVERSE)
                                    if scr_dim[1] > 12 + 52:
                                        scr_show_main.addstr(2+n, 42, notnull_print, curses.A_REVERSE)
                                    if scr_dim[1] > 12 + 64:
                                        scr_show_main.addstr(2+n, 52, default_print, curses.A_REVERSE)
                                else:
                                    if column[5] == 1 or column[5] == True:
                                        if scr_dim[1] > 12 + 6:
                                            scr_show_main.addstr(2+n, 3, id_print, curses.A_BLINK)
                                        if scr_dim[1] > 12 + 30:
                                            scr_show_main.addstr(2+n, 6, name_print, curses.A_BLINK)
                                        if scr_dim[1] > 12 + 42:
                                            scr_show_main.addstr(2+n, 30, type_print, curses.A_BLINK)
                                        if scr_dim[1] > 12 + 52:
                                            scr_show_main.addstr(2+n, 42, notnull_print, curses.A_BLINK)
                                        if scr_dim[1] > 12 + 64:
                                            scr_show_main.addstr(2+n, 52, default_print, curses.A_BLINK)
                                    else:
                                        if scr_dim[1] > 12 + 6:
                                            scr_show_main.addstr(2+n, 3, id_print)
                                        if scr_dim[1] > 12 + 30:
                                            scr_show_main.addstr(2+n, 6, name_print)
                                        if scr_dim[1] > 12 + 42:
                                            scr_show_main.addstr(2+n, 30, type_print)
                                        if scr_dim[1] > 12 + 52:
                                            scr_show_main.addstr(2+n, 42, notnull_print)
                                        if scr_dim[1] > 12 + 64:
                                            scr_show_main.addstr(2+n, 52, default_print)
                                n = n + 1
                                p = p + 1

                    # printing of the table in the browser
                    if cursor_main[2] == 1:
                        # get a large dictionary where key is the selected table and list of lists
                        if show_query != 1:
                            current_table = database.get_table(shown_tables[cursor_main[0] + cursor_main[1] - 1], open_database, database_dir)
                        else:
                            with open_database.connect() as conn:
                                current_query_table = conn.execute(new_user_query).fetchall()
                                columns_temp = conn.execute(new_user_query).keys()
                                current_table = {"query":current_query_table}
                            columns = []
                            counter = 0
                            pk = 1
                            for column in columns_temp:
                                columns.append([counter, column, 0, 0, 0, pk, 0])
                                counter = counter + 1
                                pk = 0
                        n = 0
                        m = 0
                        # printing the column titles
                        for column in columns:
                            # checks cursor_sub for the horizontal position
                            if cursor_sub[3] >= m + 1:
                                m = m + 1
                                continue
                            if 2+12*n < scr_dim[1] - 16 - 12:
                                if len(str(column[1])) >= 12:
                                    short_printing = str(column[1])[0:9] + ".."
                                    if cursor_sub[0] + cursor_sub[1] == 1:
                                        if cursor_sub[2] == n+1 or cursor_sub[2] == 0:
                                            scr_show_main.addstr(1, 2+12*n, short_printing, curses.A_REVERSE)
                                        else:
                                            scr_show_main.addstr(1, 2+12*n, short_printing)
                                    else:
                                        scr_show_main.addstr(1, 2+12*n, short_printing)
                                else:
                                    full_printing = str(column[1])
                                    while len(full_printing) < 11:
                                        full_printing = " " + full_printing
                                    if cursor_sub[0] + cursor_sub[1] == 1:
                                        if cursor_sub[2] == n+1 or cursor_sub[2] == 0:
                                            scr_show_main.addstr(1, 2+12*n, full_printing, curses.A_REVERSE)
                                        else:
                                            scr_show_main.addstr(1, 2+12*n, full_printing)
                                    else:
                                        scr_show_main.addstr(1, 2+12*n, full_printing)
                                n = n + 1
                        n = 0
                        p = 0
                        find_counter = 0
                        entry_list = [0]
                        # printing the table itself
                        if show_query != 1:
                            table_name = shown_tables[cursor_main[0] + cursor_main[1] - 1]
                        elif show_query == 1:
                            table_name = "query"
                        for entry in current_table[table_name]:
                            m = 0
                            q = 0
                            entry_list.append(entry[m])
                            # check if a find query has been made and if not empty, skips entry if not in list
                            if find_list != []:
                                if int(p) not in find_list:
                                    p = p + 1
                                    continue
                                else:
                                    find_counter = find_counter + 1
                            # limits number shown based on scr_dim
                            if n >= scr_dim[0] - 10:
                                continue
                            # limits what is shown based on cursor_sub vertically
                            if cursor_sub[1] >= p + 1:
                                p = p + 1
                                continue
                            for column in columns:
                                # limits what is shown based on cursor_sub horizontally
                                if cursor_sub[3] >= q + 1:
                                    q = q + 1
                                    continue
                                if 2+12*m < scr_dim[1] - 16 - 12:
                                    if len(str(entry[q])) >= 12:
                                        short_printing = str(entry[q])[0:9].replace('\n', ' ') + ".."
                                        full_printing = str(entry[q]).replace('\n', ' ')
                                        if (cursor_sub[0] + cursor_sub[1] == p + 2 and find_counter == 0) or cursor_sub[0] + cursor_sub[1] == find_counter + 1:
                                            if cursor_sub[2] == m+1 or cursor_sub[2] == 0:
                                                scr_show_main.addstr(n+2, 2+12*m, short_printing, curses.A_REVERSE)
                                                if cursor_sub[2] != 0:
                                                    if len(full_printing) > scr_dim[1] - 10:
                                                        scr_bottom.addstr(0, 1, full_printing[0:scr_dim[1]-2] + "...")
                                                    else:
                                                        scr_bottom.addstr(0, 1, full_printing)
                                            else:
                                                scr_show_main.addstr(n+2, 2+12*m, short_printing)
                                        elif cursor_sub[0] + cursor_sub[1] == 1:
                                            if cursor_sub[2] == m+1:
                                                scr_show_main.addstr(n+2, 2+12*m, short_printing, curses.A_REVERSE)
                                            else:
                                                scr_show_main.addstr(n+2, 2+12*m, short_printing)
                                        else:
                                            scr_show_main.addstr(n+2, 2+12*m, str(short_printing))
                                    else:
                                        full_printing = str(entry[q]).replace('\n', ' ')
                                        while len(full_printing) < 11:
                                            full_printing = " " + full_printing
                                        if (cursor_sub[0] + cursor_sub[1] == p + 2 and find_counter == 0) or cursor_sub[0] + cursor_sub[1] == find_counter + 1:
                                            if cursor_sub[2] == m+1 or cursor_sub[2] == 0:
                                                scr_show_main.addstr(n+2, 2+12*m, full_printing, curses.A_REVERSE)
                                                if cursor_sub[2] != 0:
                                                    scr_bottom.addstr(0, 1, str(entry[q]).replace('\n', ' '))
                                            else:
                                                scr_show_main.addstr(n+2, 2+12*m, full_printing)
                                        elif cursor_sub[0] + cursor_sub[1] == 1:
                                            if cursor_sub[2] == m+1:
                                                scr_show_main.addstr(n+2, 2+12*m, full_printing, curses.A_REVERSE)
                                            else:
                                                scr_show_main.addstr(n+2, 2+12*m, full_printing)
                                        else:
                                            scr_show_main.addstr(n+2, 2+12*m, full_printing)
                                    m = m + 1
                                    q = q + 1
                            n = n + 1
                            p = p + 1
                        columns.append("Blank")     # added to make room for counting
                    # showing the executions on the specific table
                    elif cursor_main[2] == 2:
                        scr_show_main.addstr(1, 1, "New SQL execution:")
                        executions_list = table_executions[str(shown_tables[cursor_main[0] + cursor_main[1] - 1])]
                        n = 0
                        p = 0
                        # present the executions that have been made in reverse order
                        for execution in reversed(executions_list):
                            if len(execution) > scr_dim[1] - 18:
                                execution_print = execution[0:scr_dim[1]-21] + ".."
                            else:
                                execution_print = execution
                            if n >= scr_dim[0] - 10:
                                continue
                            if cursor_sub[1] >= p + 1:
                                p = p + 1
                                continue
                            if cursor_sub[0] + cursor_sub[1] == p + 1:
                                scr_show_main.addstr(n+2, 1, str(execution_print), curses.A_REVERSE)
                                scr_bottom.addstr(0, 1, str(execution)[0:scr_dim[1] - 2])
                            else:
                                scr_show_main.addstr(n+2, 1, str(execution_print))
                            n = n + 1
                            p = p + 1
                # when a table hasn't been selected
                elif open_window == 0:
                    scr_show_main.addstr(2, 3, "Press ENTER to view/edit")
        
        # sets the open_list that is passed to some functions
        if open_window == 0 and current_screen == 2:
            open_list = shown_tables
        elif open_window == 1 and current_screen == 2 and show_query != 1:
            open_list = columns
        else:
            pass


        # query editor ????? SET CURSOR_SUB TO 0000
        if current_screen == 3:
            if query_state == 0:
                scr_query_main.addstr(1, 1, "Favourite Queries: [enter] Execute [n] New", curses.A_REVERSE)
                fav_queries = database.favourite_queries()

                # lists the tables
                n = 2
                p = 0
                for query in fav_queries:
                    if n - 2 >= scr_dim[0] - 10:
                        continue
                    if cursor_main[1] >= p + 1:
                        p = p + 1
                        continue
                    if len(query) >= scr_dim[1]-2:
                        if cursor_main[0] + cursor_main[1] == p + 1:
                            scr_query_main.addstr(n, 1, query[0:scr_dim[1]-6] + "...", curses.A_REVERSE)
                        else:
                            scr_query_main.addstr(n, 1, query[0:scr_dim[1]-6] + "...")
                    else:
                        if cursor_main[0] + cursor_main[1] == p + 1:
                            scr_query_main.addstr(n, 1, query, curses.A_REVERSE)
                        else:
                            scr_query_main.addstr(n, 1, query)
                    n = n + 1
                    p = p + 1
            elif query_state == 1:
                new_user_query = database.new_user_query(scr_dim, scr_query_main, open_database)
                if new_user_query != "000":
                    try:
                        with open_database.connect() as conn:
                            current_query_table = conn.execute(new_user_query)
                        show_query = 1
                        query_state = 0
                        current_screen = 2
                        database.save_query(new_user_query)
                        cursor_main = [0, 0, 1, 0]
                        cursor_sub = [0, 0, 0, 0]
                        open_window = 1
                    except:
                        query_state = 0
                else:
                    query_state = 0

        # prints the help on the bottom of the screen based on scr_dim
        if scr_dim[1] > 90:
            scr_bottom.addstr(2, 2, "[h] Help [Arrows/<>] Move [delete] Back [f] Find [n] New [u] Update [d] Delete [q] Exit", curses.A_REVERSE)
        elif scr_dim[1] > 40:
            scr_bottom.addstr(2, 2, "[h] Help [Arrows] Move [k] Queries [q] Exit", curses.A_REVERSE)
        else:
            scr_bottom.addstr(2, 2, "[h] Help [q] Exit", curses.A_REVERSE)

        refresh_windows(current_screen, scr_top, scr_front_main, scr_show_left, scr_show_main, scr_bottom, scr_query_main)

        # get the last key
        x = scr.getch()

        # following only occure if the screen shown is table view
        if current_screen != 1:
            # enter
            if x == 10 and current_screen == 2:
                if open_window == 0 and cursor_main[0] > 0:
                    open_window = 1
                    cursor_sub = [0, 0, 0, 0]
            elif x == 10 and current_screen == 3:
                new_user_query = database.run_user_query(cursor_main, open_database)
                if new_user_query != "000":
                    try:
                        with open_database.connect() as conn:
                            current_query_table = conn.execute(new_user_query)
                        show_query = 1
                        query_state = 0
                        current_screen = 2
                        cursor_main = [0, 0, 1, 0]
                        cursor_sub = [0, 0, 0, 0]
                        open_window = 1
                    except:
                        query_state = 0
                else:
                    query_state = 0

            # backspace
            if x == 263 or x == 127 or x == 27:
                if open_window == 1:
                    open_window = 0
                    show_query = 0
            if x == 102 and show_query != 1:        # f
                if open_window == 1 and cursor_main[2] == 1:
                    if len(find_list) == 0:
                        find_list = database.find_database_entry(cursor_main, cursor_sub, columns, shown_tables, current_table, scr_bottom, scr_dim)
                        cursor_sub = [0, 0, 0, 0]
                    else:
                        find_list = []
                        cursor_sub = [0, 0, 0, 0]
            if x == 100 and show_query != 1 and current_screen == 2:        # d
                if open_window == 1 and cursor_main[2] == 1 and cursor_sub[1] + cursor_sub[0] > 1 and cursor_sub[2] == 0:
                    current_table = database.get_table(shown_tables[cursor_main[0] + cursor_main[1] - 1], open_database, database_dir)
                    table_executions = database.delete_database_entry(cursor_main, cursor_sub, columns, shown_tables, current_table, open_database, scr_bottom, table_executions, current_real_database, current_database, find_list)

                    if find_list != []:
                        find_list.remove(find_list[cursor_sub[0] + cursor_sub[1] - 2])

                    if cursor_sub[1] == 0 or cursor_sub[1] == 1:
                        cursor_sub[0] = cursor_sub[0] - 1
                    else:
                        cursor_sub[0] = cursor_sub[0] - 1
                        cursor_sub[1] = cursor_sub[1] - 1
                    if find_list != []:
                        n = 0
                        for find in find_list:
                            find_list[n] = find - 1
                            if find_list[n] == -1:
                                find_list.remove(find_list[n])
                            n = n + 1
                updatedData = 1
            elif x == 100 and current_screen == 3:
                database.delete_fav_query(cursor_main)
                cursor_main[0] = cursor_main[0] - 1

                if open_window == 1 and cursor_main[2] == 1 and cursor_sub[1] + cursor_sub[0] > 1 and cursor_sub[2] + cursor_sub[3] > 0:
                    table_executions = database.delete_database_cell(cursor_main, cursor_sub, columns, shown_tables, current_table, open_database, scr_bottom, table_executions, find_list)
                if open_window == 1 and cursor_main[2] == 2 and cursor_sub[1] + cursor_sub[0] > 0 and database_type == "SQLite3":
                    table_executions = database.delete_execution(cursor_main, cursor_sub, shown_tables, current_table, table_executions, current_real_database, current_database, open_database)
                    if cursor_sub[0] > 1:
                        cursor_sub[0] = cursor_sub[0] - 1
                    elif cursor_sub[1] > 0:
                        cursor_sub[1] = cursor_sub[1] - 1
                    else:
                        cursor_sub[0] = cursor_sub[0] - 1
                updatedData = 1
            if x == 117 and show_query != 1:        # u
                if open_window == 1 and cursor_main[2] == 1 and cursor_sub[1] + cursor_sub[0] > 1 and cursor_sub[2] + cursor_sub[3] > 0:
                    table_executions = database.update_database_cell(cursor_main, cursor_sub, columns, shown_tables, current_table, open_database, scr_bottom, scr_dim, table_executions, find_list)
                    updatedData = 1

            if x == 110 and show_query != 1:        # n
                if open_window == 1 and cursor_main[2] == 2:
                    scr_show_main.addstr(1, 1, "New SQL execution:", curses.A_REVERSE)
                    scr_show_main.refresh()
                    table_executions = database.new_execution(cursor_main, cursor_sub, table_executions, scr_dim, open_database, scr_show_main, shown_tables, scr_bottom)
                    updatedData = 1
                if open_window == 1 and cursor_main[2] == 1:
                    table_executions = database.new_entry(cursor_main, cursor_sub, table_executions, scr_dim, open_database, scr_bottom, shown_tables, columns)
                    updatedData = 1
                if current_screen == 3:
                    query_state = 1
            if x == 261 and (cursor_sub != [0, 0, 0, 0] or show_query != 1):        # right
                if open_window == 1:
                    if cursor_sub[0] + cursor_sub[1] == 0:
                        cursor_main = user_input.cursor_right(cursor_main, header_list, scr_dim)
                        cursor_sub = [0, 0, 0, 0]
                    else:
                        cursor_sub = user_input.cursor_right(cursor_sub, columns, scr_dim)
                else:
                    if cursor_main[0] > 0:
                        open_window = 1
                        cursor_sub == [0, 0, 0, 0]
            if x == 62 and cursor_main != [0, 0, 1, 0]:     # >
                if open_window == 1:
                    cursor_main = user_input.cursor_right(cursor_main, header_list, scr_dim)
                    cursor_sub = [0, 0, 0, 0]
            if x == 60 and cursor_main != [0, 0, 1, 0]:     # <
                if open_window == 1:
                    cursor_main = user_input.cursor_left(cursor_main, header_list, scr_dim)
                    cursor_sub = [0, 0, 0, 0]
            elif x == 260 and (cursor_sub != [0, 0, 0, 0] or show_query != 1):     # left
                if open_window == 1:
                    if cursor_sub[0] + cursor_sub[1] == 0:
                        cursor_main = user_input.cursor_left(cursor_main, header_list, scr_dim)
                        cursor_sub = [0, 0, 0, 0]
                    else:
                        cursor_sub = user_input.cursor_left(cursor_sub, columns, scr_dim)
                else:
                    pass
            elif x == 258:      # down
                if open_window == 0 and current_screen == 2:
                    cursor_main = user_input.cursor_down(cursor_main, open_list, scr_dim, cursor_sub)
                    cursor_sub = [0, 0, 0, 0]
                elif open_window == 0 and current_screen == 3:
                    cursor_main = user_input.cursor_down(cursor_main, fav_queries, scr_dim, cursor_sub)
                    cursor_sub = [0, 0, 0, 0]
                elif open_window == 1 and cursor_main[2] == 1:
                    if find_list == []:
                        cursor_sub = user_input.cursor_down(cursor_sub, entry_list, scr_dim, cursor_main)
                    else:
                        find_list.append(-1)
                        cursor_sub = user_input.cursor_down(cursor_sub, find_list, scr_dim, cursor_main)
                        find_list.remove(-1)
                elif open_window == 1 and cursor_main[2] == 2:
                    cursor_sub = user_input.cursor_down(cursor_sub, executions_list, scr_dim, cursor_main)
                else:
                    cursor_sub = user_input.cursor_down(cursor_sub, open_list, scr_dim, cursor_main)
            elif x == 259:      # up
                if open_window == 0:
                    cursor_main = user_input.cursor_up(cursor_main, open_list, scr_dim, cursor_sub)
                    cursor_sub = [0, 0, 0, 0]
                elif open_window == 1 and cursor_main[2] == 1:
                    cursor_sub = user_input.cursor_up(cursor_sub, entry_list, scr_dim, cursor_main)
                elif open_window == 1 and cursor_main[2] == 2:
                    cursor_sub = user_input.cursor_up(cursor_sub, entry_list, scr_dim, cursor_main)
                else:
                    cursor_sub = user_input.cursor_up(cursor_sub, open_list, scr_dim, cursor_main)
            elif x == 115:         # s
                if saved_database == 0 and current_database != 0:
                    database.save_database_to_file(dbname, user, host, password, port, database_type, scr_dim, scr_bottom)
            elif x == 104:      # h
                if current_database != 0 and current_database != 1:
                    if current_screen == 1:
                        current_screen = 2
                    elif current_screen == 3:
                        current_screen = 1
                    elif current_screen == 2:
                        current_screen = 1
            elif x == 107:      # k
                if current_database != 0 and current_database != 1:
                    if current_screen == 3:
                        current_screen = 2
                        show_query = 0
                    elif current_screen == 2:
                        show_query = 0
                        current_screen = 3
                        cursor_main = [0, 0, 1, 0]
                        cursor_sub = [0, 0, 0, 0]
                        open_window = 0
                    elif current_screen == 3 and show_query == 1:
                        show_query = 0
                        current_screen = 2
                    elif query_state == 1 and show_query == 0:
                        current_screen = 2
                        query_state = 0

        elif x == 104:      # h
            if current_screen == 1 and current_database != 0 and current_database != 1:
                current_screen = 2
            elif current_screen == 2:
                current_screen = 1
    except:
        term_scr(scr)
        print("CRITICAL FAILURE...")
        time.sleep(2)
        break
    # terminate the curses window and restore the terminal
# Final cleanup once the main loop has exited: hand the terminal back from
# curses (term_scr undoes cbreak/keypad/noecho and calls endwin).
# This is best-effort. The "CRITICAL FAILURE" path inside the loop already
# calls term_scr(scr) before breaking, so this second teardown may raise
# a curses error; that is expected and ignored.
try:
    term_scr(scr)
except Exception:
    # Deliberately ignore teardown failures. This is deliberately not a bare
    # `except:`, so that KeyboardInterrupt / SystemExit raised during
    # shutdown are not swallowed.
    pass
