#!/usr/bin/env python3
#
#  Encrypt/decrypt a file given a password
#
#  Written by Sean Reifschneider, 2023-04

import os
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from typing import IO
import argparse
from getpass import getpass
import sys
import typer
from typing import Optional
import tempfile
import subprocess
from hashlib import sha1


# Typer application object.  The `encrypt`, `decrypt` and `edit` functions in
# this file register themselves on it through `@app.command()`.
app = typer.Typer(
    help="Encrypt or decrypt a file based on a password.",
    epilog="Fernet is an encryption that uses existing tools (AES, PKCS7, HMAC, "
    "SHA256) to implement a 'best practices' for encrypting a file with a "
    "password.  Its primary benefit is that it is easily available for Python "
    "programs, simple, and secure.  See for more information: "
    "https://github.com/linsomniac/fernetcrypt",
)

# Number of plaintext bytes read from the input per iteration of the encrypt
# loop; each such chunk is encrypted into its own Fernet token.
encrypt_blocksize = 40960
# Number of bytes read per iteration of the decrypt loop.  It has to match the
# size of the token produced for one full `encrypt_blocksize` chunk so that
# every read ends on a token boundary.  Per the Fernet token layout that is
# 1 (version) + 8 (timestamp) + 16 (IV) + 40976 (PKCS7-padded ciphertext)
# + 32 (HMAC) = 41033 bytes, which base64-encodes to 54712 bytes.
# NOTE(review): changing `encrypt_blocksize` requires recomputing this value.
decrypt_blocksize = 54712
# Length of the salt as stored in the non-raw header: the 16 random salt bytes
# are base85-encoded, which yields 20 bytes.
salt_file_size = 20
# Marker written directly after the encoded salt in the non-raw header and
# checked by `read_fernet_header`.
magic = "#UF1#"


def read_fernet_header(fp: IO[bytes], raw: bool = False) -> bytes:
    """Read a fernet header from the current position of `fp`, return the salt.

    In raw mode the header is just the 16 salt bytes.  Otherwise it is the
    base85-encoded salt (`salt_file_size` bytes) immediately followed by the
    `magic` marker.  On success `fp` is left positioned at the first byte
    after the header.

    Args:
        fp: Binary file object positioned at the start of the header.
        raw: True when the file uses the raw format (bare salt, no marker).

    Returns:
        The decoded salt bytes.

    Raises:
        ValueError: If the header is truncated, the marker is missing, or the
            encoded salt is not valid base85.
    """
    if raw:
        salt = fp.read(16)
        # A short read means the file is truncated; using a partial salt
        # would only fail later with a much less helpful error.
        if len(salt) != 16:
            raise ValueError("This does not look like a fernet file")
        return salt

    marker = magic.encode("ascii")
    header = fp.read(salt_file_size + len(marker))
    # Check the length explicitly: a file holding nothing but the marker would
    # otherwise pass the suffix comparison and yield a bogus, short salt.
    # Comparing bytes (rather than decoding the file data) keeps arbitrary
    # binary input on the same error path.
    if len(header) != salt_file_size + len(marker):
        raise ValueError("This does not look like a fernet file")
    if header[salt_file_size:] != marker:
        raise ValueError("This does not look like a fernet file")

    return base64.b85decode(header[:salt_file_size])


def _fernet_from_password(password: str, salt: bytes) -> Fernet:
    """Derive a Fernet instance from `password` and `salt` (PBKDF2-HMAC-SHA256)."""
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=960000,
    )
    # UTF-8 yields the same bytes as ASCII for ASCII-only passwords, so files
    # written with such passwords still decrypt, while non-ASCII passwords no
    # longer fail with UnicodeEncodeError.
    key = base64.urlsafe_b64encode(kdf.derive(password.encode("utf-8")))
    return Fernet(key)


def fernet_encrypt(
    input_file: str, output_file: str, password: str, raw: bool = False
) -> None:
    """Encrypt `input_file` into `output_file` using a key derived from `password`.

    The output starts with a header carrying a fresh random 16-byte salt (the
    bare salt in raw mode, otherwise the base85-encoded salt followed by
    `magic`), followed by one Fernet token per `encrypt_blocksize` bytes of
    input.

    Args:
        input_file: Path of the plaintext file to read.
        output_file: Path of the encrypted file to create or overwrite.
        password: Password the key is derived from.
        raw: Write the raw header format instead of the default one.
    """
    salt = os.urandom(16)
    f = _fernet_from_password(password, salt)

    with open(input_file, "rb") as infp, open(output_file, "wb") as outfp:
        if raw:
            outfp.write(salt)
        else:
            outfp.write(base64.b85encode(salt))
            outfp.write(magic.encode("ascii"))

        # Each chunk becomes an independent token; read() returns b"" at EOF.
        for chunk in iter(lambda: infp.read(encrypt_blocksize), b""):
            outfp.write(f.encrypt(chunk))


def fernet_decrypt(
    input_file: str, output_file: str, password: str, raw: bool = False
) -> None:
    """Decrypt `input_file` into `output_file` using a key derived from `password`.

    Args:
        input_file: Path of the encrypted file to read.
        output_file: Path of the plaintext file to create or overwrite.
        password: Password the key is derived from.
        raw: Expect the raw header format instead of the default one.

    Raises:
        ValueError: If the header is not a valid fernet header (raised by
            `read_fernet_header`).

    Errors raised by `Fernet.decrypt` (for example on a wrong password)
    propagate to the caller unchanged.
    """
    with open(input_file, "rb") as infp:
        salt = read_fernet_header(infp, raw)
        f = _fernet_from_password(password, salt)

        # Decrypt the first token before opening the output, so a bad header
        # or a wrong password does not create or truncate `output_file`.
        first_token = infp.read(decrypt_blocksize)
        first_plain = f.decrypt(first_token) if first_token else b""

        with open(output_file, "wb") as outfp:
            outfp.write(first_plain)
            # `decrypt_blocksize` is the size of a full token, so every read
            # returns exactly one token (the last one may be shorter).
            for token in iter(lambda: infp.read(decrypt_blocksize), b""):
                outfp.write(f.decrypt(token))


def parse_arguments():
    """Parse the command line and resolve the password to use.

    The password comes from `--password` if given, otherwise from the
    `FERNET_PASSWORD` environment variable, otherwise it is prompted for on
    the terminal.

    Returns:
        The `argparse.Namespace` with `mode`, `password`, `raw`, `input_file`
        and `output_file` attributes; `output_file` is None when omitted.
    """
    parser = argparse.ArgumentParser(
        description="Encrypt or decrypt a file based on a password.",
        epilog="Fernet is an encryption that uses existing tools (AES, PKCS7, HMAC, "
        "SHA256) to implement a 'best practices' for encrypting a file with a "
        "password.  Its primary benefit is that it is easily available for Python "
        "programs, simple, and secure.  See for more information: "
        "https://github.com/linsomniac/fernetcrypt",
    )
    parser.add_argument(
        "mode",
        choices=["encrypt", "decrypt", "edit"],
        help="Mode of operation: encrypt, decrypt or edit",
    )
    parser.add_argument(
        "-p",
        "--password",
        help="Password for encryption/decryption (optional).  Can also be specified in the "
        "'FERNET_PASSWORD' environment variable.  Otherwise, it will be read "
        "from the terminal.",
    )
    parser.add_argument(
        "-r",
        "--raw",
        action="store_true",
        default=False,
        help="Use 'raw' Fernet encrypted format rather than the default.",
    )
    parser.add_argument("input_file", help="Input file to be encrypted/decrypted")
    parser.add_argument(
        "output_file",
        # argparse rejects `required=` on positionals with a TypeError;
        # nargs="?" is the way to make a positional optional (None if absent).
        nargs="?",
        help="Output file after encryption/decryption",
    )

    args = parser.parse_args()

    # First non-empty source wins; getpass() is only reached (and only
    # prompts) when neither the option nor the environment supplies one.
    args.password = (
        args.password
        or os.environ.get("FERNET_PASSWORD")
        or getpass("Please enter the password: ")
    )

    return args


def main():
    """argparse-based entry point: parse the command line and run the mode.

    Exits with status 1 (message on stderr) when `output_file` is missing for
    a mode that needs it, or when the mode is not recognised.
    """
    args = parse_arguments()

    if args.mode == "edit":
        # "edit" works in place on `input_file`, so no `output_file` is needed.
        edit(password=args.password, raw=args.raw, filename=args.input_file)
        return

    if args.output_file is None:
        print(
            "ERROR: Must specify `output_file` unless mode is `edit`",
            file=sys.stderr,
        )
        sys.exit(1)

    if args.mode == "encrypt":
        fernet_encrypt(args.input_file, args.output_file, args.password, args.raw)
    elif args.mode == "decrypt":
        fernet_decrypt(args.input_file, args.output_file, args.password, args.raw)
    else:
        print(f"ERROR: Unknown argument '{args.mode}'", file=sys.stderr)
        sys.exit(1)


@app.command()
def encrypt(
    password: str = typer.Option(
        None,
        envvar="FERNET_PASSWORD",
        prompt="Enter password for encryption",
        hide_input=True,
        confirmation_prompt=False,
        help=(
            "Password for encryption.  Can also be specified in the "
            "'FERNET_PASSWORD' environment variable.  Otherwise, it will be read "
            "from the terminal."
        ),
    ),
    raw: bool = typer.Option(
        False,
        help="Use 'raw' Fernet encrypted format rather than the default.",
    ),
    input_file: str = typer.Argument(..., help="Input file to encrypt"),
    output_file: str = typer.Argument(None, help="Output file for the encrypted data"),
):
    """
    Encrypt a file.
    """
    # Without an explicit destination, write next to the input with a
    # ".fernet" suffix appended to its name.
    destination = input_file + ".fernet" if output_file is None else output_file
    fernet_encrypt(input_file, destination, password, raw)


@app.command()
def decrypt(
    password: str = typer.Option(
        None,
        prompt="Enter password for decryption",
        confirmation_prompt=False,
        hide_input=True,
        envvar="FERNET_PASSWORD",
        help="Password for decryption.  Can also be specified in the "
        "'FERNET_PASSWORD' environment variable.  Otherwise, it will be read "
        "from the terminal.",
    ),
    raw: bool = typer.Option(
        False, help="Use 'raw' Fernet encrypted format rather than the default."
    ),
    input_file: str = typer.Argument(..., help="Input file to decrypt"),
    output_file: str = typer.Argument(None, help="Output file for the plain-text data"),
):
    """
    Decrypt a file.
    """
    if output_file is None:
        # Strip only the extension of the file name itself.  Splitting the
        # whole path on "." goes wrong when the only dot is in a directory
        # part: "./secret" would give an empty name and "dir.d/file" would
        # give "dir".
        stem, extension = os.path.splitext(input_file)
        if not extension:
            print(
                "ERROR: Input file does not have a '.' in it to strip for default output file name"
            )
            sys.exit(1)
        output_file = stem

    fernet_decrypt(input_file, output_file, password, raw)


@app.command()
def edit(
    password: str = typer.Option(
        None,
        prompt="Enter password for editing",
        confirmation_prompt=False,
        hide_input=True,
        envvar="FERNET_PASSWORD",
        help="Password for decryption.  Can also be specified in the "
        "'FERNET_PASSWORD' environment variable.  Otherwise, it will be read "
        "from the terminal.",
    ),
    raw: bool = typer.Option(
        False, help="Use 'raw' Fernet encrypted format rather than the default."
    ),
    filename: str = typer.Argument(..., help="Encrypted file to edit"),
):
    """
    Edit an encrypted file in place.
    """

    def content_digest(path: str) -> bytes:
        # Used only to detect whether the editor changed the file; the
        # context manager closes the handle right after reading.
        with open(path, "rb") as fp:
            return sha1(fp.read()).digest()

    # Never echo the password: it is collected with hidden input, and printing
    # it would expose it on the terminal and in any captured output.
    print(f"Editing {filename}")

    # The decrypted content lives in a temporary file that is removed when the
    # `with` block exits, whether or not the edit succeeds.
    with tempfile.NamedTemporaryFile(mode="w+") as tmp:
        fernet_decrypt(filename, tmp.name, password, raw)

        before = content_digest(tmp.name)
        # Fall back to vi when EDITOR is unset or set to an empty string.
        editor = os.environ.get("EDITOR") or "vi"
        subprocess.run([editor, tmp.name])
        after = content_digest(tmp.name)

        if after == before:
            print("No changes detected, not updating.")
        else:
            fernet_encrypt(tmp.name, filename, password, raw)

# When executed as a script, hand control to the typer application, which
# dispatches to the commands registered with `@app.command()`.  The
# argparse-based `main()` is not called from here.
if __name__ == "__main__":
    app()
