#!/usr/bin/env python

from __future__ import print_function
import subprocess as sp
import sys

# --- Transcript line markers (checked against each whitespace-stripped line) ---

# A line starting with this marker is skipped, as is an empty line.
comment_prefix = '#'
# A line starting with this marker selects the executable used to run commands.
# It also starts with `comment_prefix`, so the parser tests for it first.
shebang_prefix = '#!'
# A line starting with this marker introduces a new command to run.
cmd_prefix = '>>>'
# A line starting with this marker sets the current command's expected
# return code.
returncode_prefix = '$?='
# A line starting with this marker is an expected STDERR line.
stderr_prefix = '!'

# --- Report formatting ---

# Number of '-' characters in the separator printed between test reports.
sep_length = 60

class Tester():
    """Parse a shell transcript file and check commands against it.

    The transcript is read line by line; every line is whitespace-stripped
    and then classified by its prefix:

    * ``shebang_prefix``    -- sets the executable used to run commands
    * ``comment_prefix``    -- ignored (as are empty lines)
    * ``cmd_prefix``        -- starts a new command
    * ``returncode_prefix`` -- expected return code of the current command
    * ``stderr_prefix``     -- one expected STDERR line of the current command
    * anything else         -- one expected STDOUT line of the current command

    The four lists ``cmd``, ``stdout``, ``stderr`` and ``returncode`` are kept
    in lockstep: index ``i`` of each describes the ``i``-th command.
    """

    def __init__(self, file):
        """Create a tester and immediately parse *file*.

        Args:
            file: Path of the transcript file to parse.
        """
        self.file = file
        self.executable = '/bin/sh'  # default until a shebang line is parsed
        self.cmd = []
        self.stdout = []
        self.stderr = []
        self.returncode = []
        self.parse_file(file)

    @staticmethod
    def _strip_prefix(line, prefix):
        """Return *line* without a leading *prefix* and surrounding whitespace.

        ``str.strip(prefix)`` must not be used for this: it treats its
        argument as a set of characters and strips them from both ends, which
        corrupts payloads that end in one of the prefix characters.
        """
        if line.startswith(prefix):
            line = line[len(prefix):]
        return line.strip()

    def add_cmd(self, cmd):
        """Register a new command with empty expectations and return code 0."""
        self.cmd.append(cmd)
        self.stdout.append([])
        self.stderr.append([])
        self.returncode.append(0)

    def add_stdout(self, line):
        """Append an expected STDOUT line to the most recent command.

        Raises:
            IndexError: If no command has been added yet.
        """
        self.stdout[-1].append(line)

    def add_stderr(self, line):
        """Append an expected STDERR line to the most recent command.

        Raises:
            IndexError: If no command has been added yet.
        """
        self.stderr[-1].append(line)

    def add_returncode(self, retc):
        """Set the expected return code of the most recent command.

        Raises:
            IndexError: If no command has been added yet.
        """
        self.returncode[-1] = retc

    def parse_file(self, file):
        """Read the transcript at *file* and fill the command lists.

        Args:
            file: Path of the transcript file.

        Raises:
            ValueError: If a return-code line does not hold an integer.
            IndexError: If an expectation line precedes the first command.
        """
        # Open the path that was passed in, not a module-level global, so the
        # class also works when imported instead of run as a script.
        with open(file) as handle:
            for raw in handle:
                line = raw.strip()

                # The shebang test must come before the comment test because
                # the shebang prefix itself starts with the comment prefix.
                if line.startswith(shebang_prefix):
                    self.executable = self._strip_prefix(line, shebang_prefix)
                elif not line or line.startswith(comment_prefix):
                    continue
                elif line.startswith(cmd_prefix):
                    self.add_cmd(self._strip_prefix(line, cmd_prefix))
                elif line.startswith(returncode_prefix):
                    code = int(self._strip_prefix(line, returncode_prefix))
                    self.add_returncode(code)
                elif line.startswith(stderr_prefix):
                    self.add_stderr(self._strip_prefix(line, stderr_prefix))
                else:
                    self.add_stdout(line)

    def run(self, cmd):
        """Run *cmd* through the configured executable and capture its result.

        Args:
            cmd: Command string, executed with ``shell=True``.

        Returns:
            A tuple ``(out, err, ret)``: the non-blank STDOUT lines, the
            non-blank STDERR lines (both whitespace-stripped) and the
            process return code.
        """
        p = sp.Popen(cmd, shell=True, executable=self.executable,
                     stdout=sp.PIPE, stderr=sp.PIPE)
        outb, errb = p.communicate()
        # Blank lines are dropped and the rest stripped, mirroring how
        # expected lines are normalised while parsing.
        out = [l.strip() for l in outb.decode().splitlines() if l.strip()]
        err = [l.strip() for l in errb.decode().splitlines() if l.strip()]
        return (out, err, p.returncode)

    def test(self, verbose=False, stop_on_fail=False):
        """Run every parsed command and compare it with its expectations.

        Args:
            verbose: Also print a report for each passing command.
            stop_on_fail: Return right after reporting the first failure.

        Returns:
            A tuple ``(passed, failed)`` with the number of commands in each
            category that were run before returning.
        """
        passed = 0
        failed = 0
        print('doctest-cli: {}'.format(self.file))
        print('-'*sep_length)
        for i, cmd in enumerate(self.cmd):
            exp_out = self.stdout[i]
            exp_err = self.stderr[i]
            exp_ret = self.returncode[i]

            out, err, ret = self.run(cmd)

            if out == exp_out and ret == exp_ret and err == exp_err:
                passed += 1
                if verbose:
                    print('Test passed:\n ', cmd)
                    print('With STDERR:\n', '\n '.join(err))
                    print('With STDOUT:\n', '\n '.join(out))
                    print('With RETURN:\n', ret)
                    print('-'*sep_length)
            else:
                failed += 1
                print('Error while running:\n ', cmd)

                # Report only the first mismatch, in the order:
                # return code, then STDERR, then STDOUT.
                if ret != exp_ret:
                    print('With STDERR:\n ', '\n '.join(err))
                    print('Expected RETURN:', exp_ret)
                    print('Got:', ret)
                elif err != exp_err:
                    print('Expected STDERR:\n ', '\n '.join(exp_err))
                    print('Got:\n ', '\n '.join(err))
                else:
                    print('With STDERR:\n ', '\n '.join(err))
                    print('Expected STDOUT:\n ', '\n '.join(exp_out))
                    print('Got:\n ', '\n '.join(out))
                if stop_on_fail:
                    return (passed, failed)
                print('-'*sep_length)

        return (passed, failed)

if __name__ == '__main__':
    # Command-line entry point: parse the options, run every command from the
    # given transcript file and print a summary of the results.
    from argparse import ArgumentParser

    parser = ArgumentParser('doctest-cli')
    parser.add_argument('file', help='input file')
    parser.add_argument('-s', '--stop-on-fail', action='store_true',
                        help='Stop after first failed test')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Show all tests')
    args = parser.parse_args()

    t = Tester(args.file)
    passed, failed = t.test(args.verbose, args.stop_on_fail)
    # Compare by value: `is` tests object identity, which only happens to work
    # for small ints on CPython and triggers a SyntaxWarning on Python 3.8+.
    if failed == 0:
        print('All {} tests passed'.format(passed))
    else:
        print('Tests ran: {}\nTests passed: {}\nTests failed: {}'.format(
            passed + failed, passed, failed))
