#!/usr/bin/python3
import json
from redbaron import RedBaron
import traceback
import sys
from collections.abc import Iterable
import os

# Module-level flag, initialised to False; no code visible in this file reads it.
ABI = False

# Registry of variable names known to be ScratchVars. Keys are scope names
# ('root' for module level, otherwise the name of a def); each value is a dict
# mapping a variable name to True.
scratches_added = {"root":{}}

def iterable(obj):
  """Return True when `obj` is an instance of `collections.abc.Iterable`."""
  is_iterable = isinstance(obj, Iterable)
  return is_iterable
    
def catchdump(x):
  """Return `x.dumps()`, or None when the call fails.

  Used as a best-effort mapper: objects without a usable `dumps()` yield None
  so the caller can filter them out afterwards.
  """
  try:
    return x.dumps()
  except Exception:
    # Only ordinary errors are swallowed; KeyboardInterrupt and SystemExit
    # still propagate so the tool can be interrupted.
    return None

def hasMultiStmt(node):
  """Return True when more than one statement-like child is found under `node`.

  For every element of `node`, its direct (non-recursive) ifelseblock,
  atomtrailers, return and while children are collected; the result is True
  when the combined count exceeds one.

  Returns False (rather than falling through to an implicit None) when `node`
  cannot be iterated or an element does not support `find_all`.
  """
  found = []
  try:
    for nd in node:
      nodes = nd.find_all(['ifelseblock','atomtrailers', 'return', 'while'], recursive=False)
      nodes = nodes.filter(lambda x: x is not None)
      found.extend(nodes)
  except Exception:
    # Best effort: anything that is not a walkable node collection is treated
    # as "not multiple statements".
    return False
  return len(found) > 1

def getScratch(red):
  """Return the dumped 'assign' nodes found under `red`, one per line.

  Returns '' when `red` has no `find_all` method or when nothing is found.
  """
  if not hasattr(red, 'find_all'):
    return ''
  s = red.find_all('assign')
  # NOTE(review): the ScratchVar filter only runs when find_all returns a real
  # `list` instance - confirm whether the node-list type returned by the
  # parser library satisfies this check; if not, every assign is returned.
  if isinstance(s, list):
    s = s.filter(lambda x: 'ScratchVar' in x)
  if s:
    return '\n'.join(s.map(lambda x: x.dumps()))
  return ''

def removeblanks(ml):
  """Return `ml` with every empty or whitespace-only line removed."""
  kept = [line for line in ml.splitlines() if line.strip() != '']
  return '\n'.join(kept)

def hasScratch(red, varname):
  """Report whether `varname` is registered in `scratches_added` for the scope of `red`.

  The scope is the enclosing 'def' returned by `red.parent_find('def')`; the
  result of that lookup is stored on the node as `red.parentdef` and reused on
  later calls. A name registered under the 'root' key matches from any scope.
  """
  parentdef = None
  if hasattr(red, 'parentdef'):
    # Reuse the lookup stored on this node by an earlier call.
    parentdef = red.parentdef
  else:
    parentdef = red.parent_find('def')
    red.parentdef = parentdef
  if parentdef is None:
    # NOTE(review): this fallback is a plain dict, but the next statement reads
    # `parentdef.name` as an attribute, which a dict does not provide - so this
    # path looks like it raises AttributeError instead of using 'root'.
    # Callers that wrap this call in try/except would hide that; confirm the
    # intended behaviour for nodes outside any def before changing it.
    parentdef = { 'name': 'root'}
    
  if parentdef.name in scratches_added and varname in scratches_added[parentdef.name]:
    return True
  else:
    # Not registered in the enclosing def: fall back to the module-level scope.
    if varname in scratches_added['root']:
      return True
    else:
      return False

def add_root_scratches(red):
  """Register top-level statements that mention ScratchVar under the 'root' scope.

  Every non-def node whose parent is the tree root is inspected; when its
  dumped source contains 'ScratchVar', `str(node.name)` is recorded in
  `scratches_added['root']`.
  """
  top_level = red.find_all('g:*', lambda v: v.root is v.parent and v.type != 'def')
  for stmt in top_level:
    if 'ScratchVar' not in stmt.dumps():
      continue
    scratches_added['root'][str(stmt.name)] = True
    
def findmap(red, which, how, def_=None):
  """Call `how` on every node under `red` that matches the `which` query.

  Args:
    red: object providing `find_all`.
    which: query passed unchanged to `red.find_all`.
    how: one-argument callable applied to each match via `.map`.
    def_: accepted for backward compatibility and ignored. The original guard
      was `True or def_ is None`, so the branch that forwarded `def_` to `how`
      could never execute; that unreachable branch has been removed.
  """
  red.find_all(which).map(how)

def decs(d):
  """Swap a decorator that dumps as `bytes` for `@Subroutine(TealType.bytes)`.

  Decorators with any other value are left alone. The last decorator of the
  decorated node is popped and the replacement is appended.
  """
  if d.value.dumps() != 'bytes':
    return
  d.parent.decorators.pop()
  d.parent.decorators.append('@Subroutine(TealType.bytes)')

def all(red, x = ''):
  """Run the node rewrites over `red`, then fold several statements into one expression.

  NOTE: this function shadows the builtin `all` for the rest of the module.

  Phase 1 (only when `red` has a `type` attribute): a def other than 'app' or
  'sig' that has no decorators gets a `@Subroutine(...)` decorator, then the
  per-node-type handlers are applied to `red` through `findmap`.
  Phase 2: for each element of `red`, the dumps of its direct ifelseblock /
  atomtrailers / return / while children are collected. More than one collected
  string is joined into `Seq([...])` and assigned to `red.value`; exactly one
  may be turned into a single `return ...` body.

  `x` is accepted but never used here. Every exception (BaseException) is
  written to stderr and swallowed, so the function always returns None.
  """
  try:
    hasMulti = hasMultiStmt(red)
    if hasattr(red,'type'):
      if red.type == 'def' and red.name != 'app' and red.name != 'sig':
        if len(red.decorators) == 0:
          # Undecorated helper: the return type is guessed from whether the
          # dumped source mentions a return at all.
          if 'return' in red.dumps() or 'Return' in red.dumps():
            red.decorators.append("@Subroutine(uint64)")
          else:
            red.decorators.append("@Subroutine(TealType.none)")
      # defnode is assigned but not read again in this function.
      defnode = None
      if red.type == 'def': defnode = red
      stat('assign..')
      findmap(red, 'assign', assigns, red)      
      stat('bool..')
      findmap(red, 'boolean_operator', bools)
      stat('binary_operator..')
      findmap(red, 'binary_operator', concats)
      if hasMulti or red.type == 'while':
        stat('return..')
        # NOTE(review): the query here is 'return..' (with trailing dots),
        # unlike the 'return' query used elsewhere in this file; it looks like
        # the stat() label was pasted in - confirm whether it ever matches.
        findmap(red, 'return..', returns)
      findmap(red, 'int', ints)
      findmap(red, 'string', strings)
      stat('ifelseblock..')
      findmap(red, 'ifelseblock', ifs)
      # whiles() calls all() on the while node itself, so nested whiles are
      # not mapped again from inside that call.
      if red.type != 'while':      
        findmap(red, 'while', whiles)
      findmap(red, 'call', calls)
      findmap(red, 'name', names)
    strs_ = []
    #hasMulti = hasMultiStmt(red)
    scratch = getScratch(red)
    if hasattr(red,'type') and red.type == 'name':
      return
    try:    
      # Objects with a `title` attribute (e.g. plain strings) and binary
      # operators are not walked.
      if not hasattr(red, 'title') and not (red.type == 'binary_operator') and hasattr(red, '__iter__'):
        for nd in red:
          if not hasattr(nd, 'find_all'): continue
          nodes = nd.find_all(['ifelseblock','atomtrailers', 'return', 'while'], recursive=False)
          if hasMulti:
            findmap(nodes, 'return', returns)
          
          # catchdump yields None for nodes that cannot be dumped; drop those.
          strs = nodes.map(catchdump)
          strs = strs.filter(lambda x: x != None)
          strs_.extend(strs)
    except BaseException as err:
      sys.stderr.write('looperr')
      sys.stderr.write(f"Unexpected {err=}, {type(err)=}")  
      sys.stderr.write(traceback.format_exc())  
      pass
    if len(strs_) > 1:
      strs_ = filter(lambda x: x != '', strs_)
      try:
        # NOTE(review): strs_ is a filter object at this point, while
        # removeblanks calls .splitlines() on its argument, so this call
        # presumably always fails and is ignored - confirm the intent.
        strs_ = removeblanks(strs_)
      except:
        pass
      strlist = ',\n\t'.join(strs_)      
      retstr = ''
      if red.type == 'def':
        # Only a def gets the scratch declarations and a leading `return`.
        retstr = f"{scratch}\nreturn "
      red.value = f"{retstr} Seq([\n\t{strlist} ])\n"    
      red.value = removeblanks(red.value.dumps())
    elif len(strs_) == 1:
      #if True or red.find('ifelseblock') or red.find('while') or red.find('assert'):
      if not ('return' in red.value.dumps()) and ( red.find('ifelseblock') or red.find('while') or red.type == 'assert' or red.find('assign')  ):      
        # The two .replace() calls collapse a doubled return keyword.
        red.value =f"{scratch}\nreturn {strs_[0]}\n".replace('return return', 'return').replace('return Return','return')
  except BaseException as err:
    sys.stderr.write(f"Unexpected {err=}, {type(err)=}")  
    sys.stderr.write(traceback.format_exc())
    pass  

def bools(boolOp):
  """Rewrite a boolean operator node as `Op( first, second )`.

  Both operands are processed with `all` first. The call name is the
  operator's `value` passed through `.title()`. Any failure while building or
  applying the replacement is silently ignored.
  """
  all(boolOp.first)
  all(boolOp.second)
  try:
    call_name = boolOp.value.title()
    rewritten = f"{call_name}( {boolOp.first}, {boolOp.second} )"
    boolOp.replace(rewritten)
  except:
    pass

def unitaries(un):
  """Rewrite a unitary operator node as `Op( target )`.

  The call name is the operator's `value` passed through `.title()`; the
  operand text comes from `un.target.dumps()`. Any failure while building or
  applying the replacement is silently ignored.
  """
  # NOTE(review): this passes the operator (`un.value`) to all(), whereas the
  # operand is `un.target` - confirm which one was meant to be processed.
  all(un.value)
  try:
    call_name = un.value.title()
    operand = un.target.dumps()
    un.replace(f"{call_name}( {operand} )")
  except:
    pass
  
def concats(ct):
  """Rewrite a binary operator as `Concat(first,second)` when it joins byte values.

  Both operands are processed with `all` first. The node is replaced when
  either operand is a string node, and again when the leading element of
  either operand is named 'Bytes' or 'Concat'. Errors raised by that second
  check are silently ignored.
  """
  all(ct.first)
  all(ct.second)
  joins_string = ct.first.type == 'string' or ct.second.type == 'string'
  if joins_string:
    ct.replace(f"Concat({ct.first.dumps()},{ct.second.dumps()})")
  try:
    # The four comparisons are kept in this exact order: an operand without
    # the expected structure raises, and the short-circuit decides whether
    # that operand is ever inspected.
    joins_bytes = (
      ct.first.value[0].value == 'Bytes' or ct.second.value[0].value == 'Bytes' or
      ct.first.value[0].value == 'Concat' or ct.second.value[0].value == 'Concat'
    )
    if joins_bytes:
      ct.replace(f"Concat({ct.first.dumps()},{ct.second.dumps()})")
  except:
    pass

def assigns(asn, parentdef_ = None):
  """Rewrite `name = expr` inside a def as `name.store(expr)`.

  On the first store to a name that is not yet registered (see `hasScratch`),
  a `name = ScratchVar(<type>)` line is inserted at the top of the enclosing
  def and the name is recorded in `scratches_added`. The type is
  `TealType.bytes` when the dumped value contains a quote character, otherwise
  `TealType.uint64`.

  The assignment is left untouched when:
    * its value mentions ScratchVar, StringArray or DynamicArray,
    * it is not inside a def (and no `parentdef_` was supplied),
    * its value is a boolean / comparison / boolean_operator node, or a node
      whose `value` is 'And', 'Or' or 'Not'.

  Args:
    asn: the assignment node.
    parentdef_: optional enclosing def; looked up with `parent_find('def')`
      when omitted.
  """
  dumps = asn.value.dumps()
  if 'ScratchVar' in dumps or 'StringArray' in dumps or 'DynamicArray' in dumps:
    return
  parentdef = asn.parent_find('def') if parentdef_ is None else parentdef_
  if parentdef is None:
    return

  parval = asn.value
  if hasattr(parval, 'filtered'):
    parval = parval.filtered()[0]
  if hasattr(parval, 'type'):
    # Exact match on the node type. The previous `type in 'boolean_operator'`
    # was a substring test on a string and matched any fragment of that word.
    if parval.type in ('boolean', 'comparison', 'boolean_operator'):
      return
    if hasattr(parval, 'value') and parval.value in ['And', 'Or', 'Not']:
      return

  target = asn.target.value
  if not hasScratch(asn, target):
    is_str = "'" in dumps or '"' in dumps
    typ = 'TealType.bytes' if is_str else 'TealType.uint64'
    # Declare first, register second: a failed insert must not leave the name
    # marked as declared.
    parentdef.insert(0, f"{target} = ScratchVar({typ})\n")
    scratches_added.setdefault(parentdef.name, {})[target] = True

  asn.replace(f"{target}.store({asn.value.dumps()})")

def calls(nd):
  """Prefix inner-transaction builder calls with `InnerTxnBuilder.`.

  Only 'call' nodes are considered. When the first element of the call's
  parent is named Begin, SetField, SetFields or Submit, that element is
  replaced by `InnerTxnBuilder.<its dumped source>`. Any error is silently
  ignored.
  """
  try:
    if nd.type != 'call':
      return
    head = nd.parent[0]
    if head.value in ['Begin','SetField','SetFields','Submit']:
      head.replace(f"InnerTxnBuilder.{head.dumps()}")
  except:
    pass

def prints(p):
  """Replace a print node with `Log` followed by the dumped print value."""
  logged = p.value.dumps()
  p.replace(f"Log{logged}")

def names(nd):
  """Rewrite a single name node for PyTeal.

  In order:
    * `False` / `True` become `Int(0)` / `Int(1)`;
    * a name already followed by `.load()` or `.store(` in its parent is
      left alone;
    * a name registered as a scratch variable (see `hasScratch`), or the name
      `size`, gets `.load()` appended;
    * otherwise, a known transaction/global field name whose parent starts
      with Txn, Gtxn or Global gets `()` appended to the parent expression.

  Errors raised in the last two steps are silently ignored.
  """
  # NOTE(review): the rest of this file queries assignment nodes with the
  # identifier 'assign'; confirm that 'assignment' is a value `type` can
  # actually take, otherwise this guard never fires.
  if nd.parent.type == 'assignment' and nd.parent.target == nd:
    return
    
  f = nd.find('name', lambda n: n.value=='False')
  if f:
    f.replace('Int(0)')
    return
  t = nd.find('name', lambda n: n.value=='True')
  if t:
    t.replace('Int(1)')
    return
    
  # Skip names that an earlier pass already turned into load/store calls.
  if nd.dumps()+'.load()' in nd.parent.dumps():
    return
  if nd.dumps()+'.store(' in nd.parent.dumps():
    return

  varname = nd.dumps()

  try:
    if hasScratch(nd, varname) or varname == 'size':
      nd.replace(f"{nd.dumps()}.load()")
    else:
      try:
        # Field names that are turned into calls when accessed on
        # Txn / Gtxn / Global.
        if nd.value in ['fee','sender','first_valid','last_valid','receiver','note','lease', 'round', 
                              'amount', 'close_remainder_to', 'vote_pk', 'type', 'type_enum', 'xfer_asset', 'asset_amount',
                               'asset_sender', 'asset_receiver', 'asset_close_to', 'group_index', 'tx_id', 'application_id',
                               'on_completion', 'rekey_to', 'config_asset', 'config_asset_total', 'config_asset_decimals',
                               'config_asset_default_frozen', 'config_asset_unit_name','config_asset_name', 'config_asset_url',
                               'config_asset_manager', 'created_asset_id', 'created_application_id', 'current_application_address',
                               'group_size', 'application_args', 'latest_timestamp']:

          if nd.parent.value[0].value in ['Txn', 'Gtxn', 'Global']:
            try:
              # Leave the expression alone when the third element of the
              # grandparent is a subscript.
              if nd.parent.parent.value[2].type == 'getitem':
                return
            except:
              pass
            # Append the call parentheses only once.
            if not (f"{varname}()" in nd.parent.dumps()):
              nd.parent.replace(nd.parent.dumps()+'()')            
      except:
        pass  
  except:
    # Any failure above (including one raised by hasScratch) leaves the name
    # unchanged.
    pass

def whiles(w):
  """Replace a while node with `While( test).Do( body )` after processing it with `all`."""
  all(w)
  condition = w.test.dumps()
  body = w.value.dumps()
  w.replace(f"While( {condition}).Do(\n{body} )")
  
def ifs(if_):
  """Rewrite an if/else block as `If( test, then[, else] )`.

  The first branch, and the second when present, are processed with `all`.
  An else part is emitted only when the block has exactly two branches; in
  that case the returns inside the block are rewritten first. Errors while
  processing the second branch or while assigning the new value are silently
  ignored.
  """
  all(if_.value[0], 'if')
  try:
    all(if_.value[1], 'if')
  except:
    pass

  else_ = ''
  if len(if_.value) == 2:
    findmap(if_.value,'return',returns)
    else_ = ', ' + if_.value[1].value.dumps()

  try:
    condition = if_.value[0].test.dumps()
    then_body = if_.value[0].value.dumps()
    if_.value = f"If( {condition}, {then_body}{else_} )\n\n"
  except:
    pass

def returns(ret):
  """Replace a return node with `Return( value )`.

  The returned value is processed with `all` first; any literal 'return '
  text inside its dumped source is stripped.
  """
  all(ret.value)
  inner = ret.value.dumps().replace('return ','')
  ret.replace(f"Return( {inner} )")

def fixreturns(d):
  """Turn a leading `Return(...)` in `d` back into a plain `return ...`.

  When the first element of `d` starts with something named 'Return', that
  element is replaced by `return ` plus the dump of its second part. Any
  error (for example an empty `d`) is silently ignored.
  """
  try:
    first = d[0]
    if first[0].value == 'Return':
      d[0] = f"return {first[1].dumps()}"
  except:
    pass

def addreturns(d):
  """Prefix the first element of `d` with `return ` unless it already starts with one.

  The check looks at the `value` of the first part of the first element. Any
  error (for example an empty `d`) is silently ignored.
  """
  try:
    first = d[0]
    if first[0].value != 'return':
      d[0] = f"return {first.dumps()}"
  except:
    pass

def ints(i):
  """Wrap an integer literal node as `Int(<value>)`.

  The literal is left alone when:
    * its parent is a def argument or a subscript,
    * the first element of its great-grandparent is named Int, Arg,
      application_args or Gtxn,
    * the second element of its grandparent is named application_args or Gtxn.

  An unexpected error during these checks is reported on stderr and the
  literal is left unchanged.
  """
  try:
    parent = i.parent
    if parent.type in ('def_argument', 'getitem'):
      return

    outer = parent.parent.parent
    if hasattr(outer, 'value') and hasattr(outer.value, 'pop'):
      parname = outer.value[0].value
      if parname in ('Int', 'Arg', 'application_args', 'Gtxn'):
        return

    try:
      middle = parent.parent
      if hasattr(middle, 'value') and hasattr(middle.value, 'pop'):
        if middle.value[1].value in ['application_args', 'Gtxn']:
          return
    except:
      # A grandparent with fewer than two elements is simply not a match.
      pass
  except BaseException as err:
    sys.stderr.write(' int outer exception ' + i.dumps())
    sys.stderr.write(f"Unexpected {err=}, {type(err)=}")  
    sys.stderr.write(traceback.format_exc())
    return

  i.replace(f"Int({i.value})")

def strings(i):
  """Wrap a string literal node as `Bytes(<literal>)`.

  The literal is left alone when its parent is a def argument, or when the
  first element of its great-grandparent is named Bytes or Addr. Errors while
  inspecting the ancestors are ignored and the literal is wrapped anyway.
  """
  try:
    holder = i.parent
    if holder.type == 'def_argument':
      return
    parval = holder.parent.parent.value[0].value
    if parval in ('Bytes', 'Addr'):
      return
  except:
    pass
  i.replace(f"Bytes({i.dumps()})")

# Module-level placeholder initialised to 0; no code visible in this file
# reads this global.
progroot = 0

def abimethod(d):
  """Mark a def that has a return annotation as an ABI method.

  The first existing decorator is removed and `@staticmethod` followed by
  `@ABIMethod` are inserted at the front. Defs without a return annotation
  are left untouched.
  """
  if not hasattr(d, 'return_annotation') or d.return_annotation is None:
    return
  decorators = d.decorators
  del decorators[0]
  decorators.insert(0,'@staticmethod')
  decorators.insert(1,'@ABIMethod')

def app_ops(d):
  """Re-decorate an application lifecycle handler.

  For a def named delete, create, update, optIn, closeOut or clearState, the
  first existing decorator is removed and `@staticmethod` followed by
  `@Subroutine(TealType.uint64)` are inserted at the front. Other defs are
  left untouched.
  """
  lifecycle_names = ['delete', 'create', 'update', 'optIn', 'closeOut', 'clearState']
  if d.name not in lifecycle_names:
    return
  decorators = d.decorators
  del decorators[0]
  decorators.insert(0,'@staticmethod')
  decorators.insert(1, '@Subroutine(TealType.uint64)')

def isSigDef(nd):
  """Return True when `nd` is a def named 'sig'; False otherwise or on any error."""
  try:
    matches = nd.type == 'def' and nd.name=='sig'
  except:
    return False
  return True if matches else False

def stat(s):
  """Emit a progress dot on stderr and flush it.

  The label `s` is accepted but not printed.
  """
  err = sys.stderr
  err.write(".")
  err.flush()


# Def names that the dump helpers treat as application lifecycle operations
# (emitted separately from ordinary helper defs).
opslist = ['delete', 'create', 'update', 'optIn', 'closeOut', 'clearState']

def dumpglobals(red):
  """Return the source of every top-level non-def node, one node per line."""
  def is_toplevel_statement(v):
    return v.root is v.parent and v.type != 'def'

  gs = red.find_all('g:*', is_toplevel_statement)
  dumped = gs.map(lambda n: n.dumps())
  return "\n".join(dumped)

def dump_nonabi_defs(red):
  """Return the source of every def that has no return annotation and is not in `opslist`."""
  def is_plain_helper(v):
    return v.return_annotation is None and not (v.name in opslist)

  gs = red.find_all('def', is_plain_helper)
  dumped = gs.map(lambda n: n.dumps())
  return "\n".join(dumped)
  
def dump_abi_defs(red):
  """Return the source of every def that has a return annotation, one def per entry."""
  def has_return_annotation(v):
    return not (v.return_annotation is None)

  gs = red.find_all('def', has_return_annotation)
  dumped = gs.map(lambda n: n.dumps())
  return "\n".join(dumped)

def dump_app_ops(d):
  """Return the source of every def whose name is listed in `opslist`."""
  def is_lifecycle_op(v):
    return v.name in opslist

  gs = d.find_all('def', is_lifecycle_op)
  dumped = gs.map(lambda n: n.dumps())
  return "\n".join(dumped)

def abiapp(red):
  """Print the converted program as an ABI application module on stdout.

  Output order: fixed import preamble, top-level statements, plain helper
  defs, the `class ABIApp(DefaultApprove):` header, the ABI defs and lifecycle
  defs (dumped after `red.increase_indentation(2)` has been called), and
  finally a fixed `__main__` footer.
  """
  preamble = """
import json
from typing import Tuple

from pyteal import *

from pytealutils.applications import ABIMethod, DefaultApprove
from pytealutils import abi

globals().update(TealType.__members__)

"""
  footer = """
if __name__ == "__main__":
  app = ABIApp()

  with open("interface.json", "w") as f:
    f.write(json.dumps(app.get_interface().dictify(), indent=4))

  print(compileTeal(app.handler(), mode=Mode.Application, version=5, assembleConstants=True))
  """

  print(preamble)

  print(dumpglobals(red))
  print(dump_nonabi_defs(red))

  print("class ABIApp(DefaultApprove):")
  # Everything dumped from here on is emitted after this indentation change.
  red.increase_indentation(2)
  print(dump_abi_defs(red))

  print(dump_app_ops(red))

  print(footer)

def usesabi(red):
  """Return True when at least one def under `red` carries a return annotation."""
  return any(
    hasattr(fn, 'return_annotation') and fn.return_annotation is not None
    for fn in red.find_all('def')
  )

def simpleReplaces(source):
  """Expand the shorthand helper calls in `source` to their PyTeal spellings.

  Rewrites, applied in this order:
    len(  -> Len(
    lput( -> App.localPut(0,
    lget( -> App.localGet(0,
    gput( -> App.globalPut(
    gget( -> App.globalGet(

  Each helper name must start at a word boundary, so identifiers that merely
  end in one of the names (e.g. `strlen(` or `calc_lput(`) are left intact.
  A plain substring replace used to rewrite those as well. The match is still
  textual, so occurrences inside string literals or after a dot are rewritten
  exactly as before.
  """
  import re  # local import: the module does not import `re` at top level

  rewrites = (
    ('len', 'Len('),
    ('lput', 'App.localPut(0,'),
    ('lget', 'App.localGet(0,'),
    ('gput', 'App.globalPut('),
    ('gget', 'App.globalGet('),
  )
  for helper, expansion in rewrites:
    source = re.sub(r'\b' + helper + r'\(', expansion, source)
  return source

def convert(fname):
  """Translate the Python file `fname` to PyTeal and print the result on stdout.

  Progress dots and the version banner go to stderr. The program is emitted
  in Signature mode with entry point `sig` when a def named 'sig' is found at
  the top level, otherwise in Application mode with entry point `app`. When
  any def carries a return annotation the ABI application layout is printed
  instead (see `abiapp`).

  Args:
    fname: path of the source file to convert.
  """
  sys.stderr.write("genpyteal version 2.2.0\n")

  mode = 'Application'
  entry = 'app'
  # The context manager closes the input file as soon as it has been read;
  # previously the handle was never closed.
  with open(fname, "r") as infile:
    source = infile.read()
  source = simpleReplaces(source)
  red = RedBaron(source)
  stat('Parsed file..')

  isabi = usesabi(red)

  sig = red.filter(isSigDef)
  if sig:
    mode = 'Signature'
    entry = 'sig'
    stat('LogicSig')
  else:
    stat('Application')

  add_root_scratches(red)

  # The order of the passes below matters: later passes work on the text
  # produced by earlier ones.
  findmap(red, 'binary_operator', concats)

  stat('prints..')
  findmap(red, 'print', prints)
  findmap(red, 'name', names)
  stat('unitary_operator..')
  findmap(red, 'unitary_operator', unitaries)
  stat('return..')
  findmap(red, 'return', returns)
  stat('def..')
  findmap(red, 'def', all)
  stat('assign..')
  findmap(red, 'assign', assigns)
  stat('def..')
  findmap(red, 'def', fixreturns)
  stat('def..')
  findmap(red, 'def', addreturns)
  stat('decorator..')
  findmap(red, 'decorator', decs)
  stat('strings..')
  findmap(red, 'string', strings)
  sys.stderr.write("\n")

  if isabi:
    findmap(red, 'def', abimethod)
    findmap(red, 'def', app_ops)
    abiapp(red)
    return

  print("from pyteal import *\n")
  print("globals().update(TealType.__members__)\n")

  print(red.dumps())
  print(f"if __name__ == \"__main__\":\n    print(compileTeal({entry}(), mode=Mode.{mode}, version=5))")

# Script entry point: runs unconditionally at module load (there is no
# `__main__` guard) and takes the input file path from the first command-line
# argument.
convert(sys.argv[1])
