#!/usr/local/bin/python3

#   threefish
#   Copyright 2008, 2009 Hagen Fürstenau <hagen@zhuliguan.net>
#
#   Demonstrates Threefish encryption with PySkein.
#
#   This program is free software: you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation, either version 3 of the License, or
#   (at your option) any later version.
#
#   This program is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import os
from random import randrange
from getpass import getpass

from skein import threefish, skein512
BLOCK_SIZE = 32


def get_args(argv):
    """Check and process command line arguments.

    'argv' is the complete argument vector (program name first, then the
    mode and the file name), typically sys.argv.

    Returns (encrypt flag, input file name, output file name).

    On any invalid input (wrong argument count, unknown mode, encrypted
    file name not ending in '.3f', missing input file, or already existing
    output file) a message is printed and the program exits with status 1.
    """

    if len(argv) != 3 or argv[1] not in ("e", "d", "encrypt", "decrypt"):
        # 'argv' may be empty, so fall back to the interpreter's own value
        # for the program name shown in the usage line
        prog = argv[0] if argv else sys.argv[0]
        print("Usage : {0} <mode> <file>".format(os.path.basename(prog)))
        print("<mode>: 'encrypt' (='e') or 'decrypt' (='d')")
        sys.exit(1)

    # all accepted mode strings starting with "e" mean encryption
    encrypt = (argv[1][0] == "e")
    infn = argv[2]
    if encrypt:
        outfn = infn+".3f"
    else:
        if not infn.endswith(".3f"):
            print("Name of encrypted file has to end with '.3f'")
            sys.exit(1)
        # strip the '.3f' suffix to obtain the plaintext file name
        outfn = infn[:-3]
    if not os.path.isfile(infn):
        print("'{0}' does not exist".format(infn))
        sys.exit(1)
    if os.path.isfile(outfn):
        print("'{0}' exists, refusing to overwrite".format(outfn))
        sys.exit(1)
    return (encrypt, infn, outfn)


def encrypt_file(inf, outf, key):
    """Encrypt file 'inf' to file 'outf' using 'key'.

    'inf' must be a binary file object supporting readinto(), 'outf' a
    binary file object supporting write().

    The output consists of a random 16-byte tweak followed by the encrypted
    blocks of BLOCK_SIZE bytes each.  The last block is always a padded one
    (even if the input length is a multiple of BLOCK_SIZE); the number of
    real data bytes it contains is stored in the low bits of its final byte.
    """

    # generate random tweak value and write it at the beginning of the file;
    # it is drawn from the operating system's CSPRNG because the generator
    # of the 'random' module is not meant for cryptographic use
    tweak = os.urandom(16)
    outf.write(tweak)

    # encrypt all complete blocks
    cipher = threefish(key, tweak)
    block = bytearray(BLOCK_SIZE)
    while True:
        i = inf.readinto(block)
        if i < BLOCK_SIZE:
            break
        c = cipher.encrypt_block(block)
        outf.write(c)
        cipher.tweak = c[:16]  # cipher block chaining with first 16 bytes

    # pad and encrypt the last incomplete (possibly empty) block:
    # fill everything after the 'i' data bytes with random bytes ...
    block[i:] = os.urandom(BLOCK_SIZE-i)
    # ... and record 'i' in the low bits of the final byte (i < BLOCK_SIZE,
    # so that byte is always padding and never holds input data)
    block[-1] = block[-1] & ~(BLOCK_SIZE-1) | i
    outf.write(cipher.encrypt_block(block))


class DecryptError(Exception):
    """Signals that encrypted input ended before a complete tweak or block."""

def decrypt_file(inf, outf, key):
    """Decrypt file 'inf' to file 'outf' using 'key'.

    'inf' is expected to start with a 16-byte tweak followed by one or more
    blocks of BLOCK_SIZE bytes.  Raises DecryptError if the tweak is
    incomplete, if there is not even one full block, or if the data ends
    with a partial block.
    """

    # the tweak value is stored at the very beginning of the file
    tweak = inf.read(16)
    if len(tweak) < 16:
        raise DecryptError

    cipher = threefish(key, tweak)

    # Two buffers are used so that we always read one block ahead: the last
    # block needs unpadding and must not be written out in full.
    current = bytearray(BLOCK_SIZE)
    lookahead = bytearray(BLOCK_SIZE)
    if inf.readinto(current) < BLOCK_SIZE:
        raise DecryptError

    # decrypt every block that is followed by another complete block
    got = inf.readinto(lookahead)
    while got >= BLOCK_SIZE:
        outf.write(cipher.decrypt_block(current))
        # cipher block chaining with first 16 bytes of the ciphertext
        cipher.tweak = current[:16]
        current, lookahead = lookahead, current
        got = inf.readinto(lookahead)

    # leftover bytes that do not form a whole block mean truncated input
    if got != 0:
        raise DecryptError

    # decrypt the final block; the low bits of its last byte hold the
    # number of real data bytes in it
    plain = cipher.decrypt_block(current)
    datalen = plain[-1] & (BLOCK_SIZE - 1)
    outf.write(plain[:datalen])


if __name__ == "__main__":
    # Script entry point: parse arguments, derive a key from a password
    # and run the requested operation on the given file.
    encrypt, infn, outfn = get_args(sys.argv)

    # the key is the Skein-512 hash of the UTF-8 encoded password,
    # with the digest length set to BLOCK_SIZE bytes
    password = getpass().encode("utf-8")
    hasher = skein512(password, digest_bits=8 * BLOCK_SIZE)
    key = hasher.digest()

    with open(infn, "rb") as inf, open(outfn, "wb") as outf:
        if encrypt:
            encrypt_file(inf, outf, key)
        else:
            try:
                decrypt_file(inf, outf, key)
            except DecryptError:
                # NOTE: whatever was already written to the output file
                # is left in place when decryption fails
                print("unexpected end of encrypted data", file=sys.stderr)
                sys.exit(1)
