#!/usr/bin/env python3
#
#  Encrypt/decrypt a file given a password
#
#  Written by Sean Reifschneider, 2023-04

import os
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from typing import IO
import argparse
from getpass import getpass

# Plaintext bytes read per chunk when encrypting; every chunk is written as
# its own Fernet token.
encrypt_blocksize = 40 * 1024

# Bytes read per chunk when decrypting: the size of the token produced from
# one full plaintext chunk.  A token is version (1) + timestamp (8) + IV (16)
# + PKCS7-padded ciphertext + HMAC (32), base64 encoded (4 chars per 3 bytes,
# rounded up).
decrypt_blocksize = 4 * (
    (1 + 8 + 16 + 16 * (encrypt_blocksize // 16 + 1) + 32 + 2) // 3
)

# Length of the 16-byte salt once base85 encoded (5 characters per 4 bytes).
salt_file_size = 16 // 4 * 5

# Marker written right after the encoded salt in non-raw files.
magic = "#UF1#"


def read_fernet_header(fp: IO[bytes], raw: bool = False) -> bytes:
    """Read the file header from *fp* and return the key-derivation salt.

    In raw mode the header is just the 16 salt bytes.  Otherwise it is
    ``salt_file_size`` bytes of base85-encoded salt immediately followed by
    the ``magic`` marker.

    Raises:
        ValueError: if the header is truncated, the magic marker is not where
            it is expected, or the salt is not valid base85.
    """
    if raw:
        salt = fp.read(16)
        # A short read means the file ended inside the salt; a partial salt
        # must not be used to derive a key.
        if len(salt) != 16:
            raise ValueError("File is too short to contain a fernet salt")
        return salt

    marker = magic.encode("ascii")
    header = fp.read(salt_file_size + len(marker))
    encoded_salt = header[:salt_file_size]
    found_marker = header[salt_file_size:]
    # Check the marker at its fixed offset (not relative to the end of what
    # was read) so a truncated header cannot pass with a shortened salt.
    if len(encoded_salt) != salt_file_size or found_marker != marker:
        raise ValueError("This does not look like a fernet file")

    try:
        return base64.b85decode(encoded_salt)
    except ValueError as err:
        raise ValueError("This does not look like a fernet file") from err


def _derive_fernet(password: str, salt: bytes) -> "Fernet":
    "Build the Fernet cipher keyed by PBKDF2-HMAC-SHA256 of password and salt"
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=960000,
    )
    # UTF-8 is a superset of ASCII, so ASCII passwords derive the same key as
    # before, while non-ASCII passwords no longer raise UnicodeEncodeError.
    return Fernet(base64.urlsafe_b64encode(kdf.derive(password.encode("utf-8"))))


def fernet_encrypt(
    input_file: str, output_file: str, password: str, raw: bool = False
) -> None:
    """Encrypt *input_file* into *output_file* using *password*.

    A fresh random 16-byte salt is written first: as raw bytes when *raw* is
    true, otherwise base85 encoded and followed by the ``magic`` marker.  The
    input is then read in ``encrypt_blocksize`` chunks and each chunk is
    written as one Fernet token.

    Args:
        input_file: Path of the plaintext file to read.
        output_file: Path of the encrypted file to create or overwrite.
        password: Password the key is derived from.
        raw: Write the bare salt instead of the encoded salt plus marker.
    """
    salt = os.urandom(16)
    cipher = _derive_fernet(password, salt)

    with open(input_file, "rb") as infp, open(output_file, "wb") as outfp:
        if raw:
            outfp.write(salt)
        else:
            outfp.write(base64.b85encode(salt))
            outfp.write(magic.encode("ascii"))

        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: infp.read(encrypt_blocksize), b""):
            outfp.write(cipher.encrypt(chunk))


def fernet_decrypt(
    input_file: str, output_file: str, password: str, raw: bool = False
) -> None:
    """Decrypt *input_file* into *output_file* using *password*.

    The salt is read from the header (see ``read_fernet_header``), the key is
    derived from it, and the rest of the file is read in
    ``decrypt_blocksize`` pieces, each decrypted as one Fernet token.

    Args:
        input_file: Path of the encrypted file to read.
        output_file: Path of the plaintext file to create or overwrite.
        password: Password the key is derived from.
        raw: The file starts with a bare salt rather than encoded salt plus
            marker.

    Raises:
        ValueError: from ``read_fernet_header`` when the header is invalid.
        Errors raised by ``Fernet.decrypt`` (for example on a wrong password
        or corrupted data) propagate; the output file may then be left
        empty or partially written.
    """
    with open(input_file, "rb") as infp:
        # Validate the header and derive the key before touching the output
        # path, so an invalid input file does not truncate an existing output.
        salt = read_fernet_header(infp, raw)
        cipher = _derive_fernet(password, salt)

        with open(output_file, "wb") as outfp:
            for token in iter(lambda: infp.read(decrypt_blocksize), b""):
                outfp.write(cipher.decrypt(token))


def parse_arguments(argv=None):
    """Parse the command line and resolve the password to use.

    Args:
        argv: Argument strings to parse.  ``None`` (the default) makes
            argparse read the process command line.

    Returns:
        The parsed ``argparse.Namespace`` with ``mode``, ``raw``,
        ``input_file``, ``output_file`` and ``password``.  ``password`` comes
        from ``-p/--password`` if given, else the ``FERNET_PASSWORD``
        environment variable, else an interactive prompt.
    """
    parser = argparse.ArgumentParser(
        description="Encrypt or decrypt a file based on a password.",
        epilog="Fernet is an encryption scheme that uses existing tools (AES, "
        "PKCS7, HMAC, SHA256) to implement 'best practices' for encrypting a "
        "file with a password.  Its primary benefit is that it is easily "
        "available for Python programs, simple, and secure.  For more "
        "information see: https://github.com/linsomniac/fernetcrypt",
    )
    parser.add_argument(
        "mode",
        choices=["encrypt", "decrypt"],
        help="Mode of operation: encrypt or decrypt",
    )
    parser.add_argument(
        "-p",
        "--password",
        help="Password for encryption/decryption (optional).  Can also be specified in the "
        "'FERNET_PASSWORD' environment variable.  Otherwise, it will be read "
        "from the terminal.  Note that a password given on the command line "
        "may be visible to other users in the process list.",
    )
    parser.add_argument(
        "-r",
        "--raw",
        action="store_true",
        default=False,
        help="Use 'raw' Fernet encrypted format rather than the default.",
    )
    parser.add_argument("input_file", help="Input file to be encrypted/decrypted")
    parser.add_argument("output_file", help="Output file after encryption/decryption")

    args = parser.parse_args(argv)

    # Password precedence: command line, then environment, then prompt.
    password = args.password
    if not password:
        password = os.environ.get("FERNET_PASSWORD")
    if not password:
        password = getpass("Please enter the password: ")
    args.password = password

    return args


def main():
    "Command-line entry point: parse the arguments and run the chosen mode"
    options = parse_arguments()

    # argparse restricts mode to "encrypt" or "decrypt", so anything that is
    # not "encrypt" is a decrypt request.
    operation = fernet_encrypt if options.mode == "encrypt" else fernet_decrypt
    operation(
        options.input_file, options.output_file, options.password, options.raw
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
