#!/usr/bin/env python3
#
# Wrapper script to be deployed on machines whose network interfaces should be
# controllable via the RawNetworkInterfaceDriver. A /etc/labgrid/helpers.yaml
# can deny access to network interfaces. See below.
#
# This is intended to be used via sudo. For example, add via visudo:
# %developers ALL = NOPASSWD: /usr/sbin/labgrid-raw-interface

import argparse
import os
import sys

import yaml


def get_denylist():
    """Return the interface names this wrapper must refuse to operate on.

    The list is read from the ``denied-interfaces`` key of the
    ``raw-interface`` section in /etc/labgrid/helpers.yaml. Both the section
    and the key are optional; "lo" is always added to the result.

    Returns:
        A new list containing the configured names followed by "lo".

    Raises:
        Exception: if the configuration file is missing, unreadable, not
            valid YAML, does not hold a mapping at the top level, or if
            ``raw-interface`` / ``denied-interfaces`` have the wrong type.
    """
    denylist_file = "/etc/labgrid/helpers.yaml"
    config_error = f"No configuration file ({denylist_file}), inaccessable or invalid yaml"

    try:
        with open(denylist_file) as stream:
            data = yaml.load(stream, Loader=yaml.SafeLoader)
    except (OSError, AttributeError) as e:
        # OSError covers missing files, permission problems, directories, ...
        raise Exception(config_error) from e
    except yaml.YAMLError as e:
        # Malformed YAML must be reported as a configuration problem too.
        raise Exception(config_error) from e

    # An empty file parses to None and a scalar/list top level has no .get();
    # fail with the configuration error instead of a raw AttributeError.
    if not isinstance(data, dict):
        raise Exception(config_error)

    section = data.get("raw-interface", {})
    if not isinstance(section, dict):
        raise Exception("raw-interface section is not a mapping, please check your configuration")

    denylist = section.get("denied-interfaces", [])

    if not isinstance(denylist, list):
        raise Exception("No explicit denied-interfaces or not a list, please check your configuration")

    # Build a new list rather than appending to the parsed YAML data.
    return [*denylist, "lo"]


def main(program, ifname, count):
    """Validate the request, then replace this process with the chosen program.

    Args:
        program: name of the program to run, either "tcpreplay" or "tcpdump".
        ifname: name of the network interface to operate on.
        count: number of frames passed to tcpdump via ``-c``; when falsy
            (None or 0) no ``-c`` option is added. Unused for tcpreplay.

    Raises:
        ValueError: if the interface name is empty, contains "/" or
            whitespace, is too long, is on the denylist, or if ``program``
            is not one of the supported programs.
        RuntimeError: if the program's binary cannot be found.
        Exception: whatever get_denylist() raises for a bad configuration.

    On success this function does not return: os.execvp() replaces the
    current process.
    """
    # Linux IFNAMSIZ is 16 bytes *including* the terminating NUL, so a valid
    # interface name has at most 15 characters.
    max_ifname_len = 15

    if not ifname:
        raise ValueError("Empty interface name.")
    if any((c == "/" or c.isspace()) for c in ifname):
        raise ValueError(f"Interface name '{ifname}' contains invalid characters.")
    if len(ifname) > max_ifname_len:
        raise ValueError(f"Interface name '{ifname}' is too long.")

    denylist = get_denylist()

    if ifname in denylist:
        raise ValueError(f"Interface name '{ifname}' is denied in denylist.")

    programs = ["tcpreplay", "tcpdump"]
    if program not in programs:
        raise ValueError(f"Invalid program {program} called with wrapper, valid programs are: {programs}")

    # The interface name is always attached to its option with "=", so it can
    # never be interpreted as a separate option by the called program.
    if program == "tcpreplay":
        # "-" as the file argument: the capture to replay comes from stdin.
        args = [program, f"--intf1={ifname}", "-"]
    else:
        # "-w -": the raw capture is written to stdout.
        args = [program, "-n", f"--interface={ifname}", "-w", "-"]
        if count:
            args += ["-c", str(count)]

    try:
        os.execvp(args[0], args)
    except FileNotFoundError as e:
        raise RuntimeError(f"Missing {program} binary") from e


if __name__ == "__main__":
    # Command line: [-d] program interface [count]
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-d',
        '--debug',
        action='store_true',
        default=False,
        help="enable debug mode"
    )
    parser.add_argument('program', type=str, help='program to run, either tcpreplay or tcpdump')
    parser.add_argument('interface', type=str, help='interface name')
    parser.add_argument('count', nargs="?", type=int, default=None, help='amount of frames to capture while recording')
    args = parser.parse_args()
    try:
        # On success main() does not return (the process image is replaced).
        main(args.program, args.interface, args.count)
    except Exception as e:  # pylint: disable=broad-except
        # Top-level boundary: report every failure as a single ERROR line and
        # only show the traceback when debugging was requested.
        if args.debug:
            import traceback
            traceback.print_exc(file=sys.stderr)
        print(f"ERROR: {e}", file=sys.stderr)
        # Use sys.exit(): the bare exit() helper is injected by the site
        # module and is not guaranteed to exist (e.g. with "python -S").
        sys.exit(1)
