#!/usr/bin/env python3
# vim: filetype=python

"""
Sender pipeline:
Unix socket/sound source
  --- samples stream ---> [Sample Reader]
  ---  sample_queue  ---> [Packetizer]
---> Uni/Multicast UDP

Receiver pipeline:
UDP packets
  ---> [Receiver.datagram_received]
  --- chunk_list ---> [Receiver.pump_audio]
---> Unix socket/sound sink
"""
from time import time
import struct
import zlib
import socket
from datetime import datetime, timedelta

import asyncio
from functools import partial

from collections import deque

# Program version as a (major, minor, patch) tuple; joined with dots for display.
VERSION = (0, 3, 1)

class TimeMachine:
    """
    Handle fast millisecond precision timestamps.

    A `timemark' marks a time in the future - a certain number of milliseconds
    ahead of the current time. It is the millisecond offset within the current
    UTC minute (0..59999) packed as a big-endian uint16, so it can only
    describe moments less than one minute ahead.
    """

    # Milliseconds in a minute - the value at which a timemark wraps around.
    WRAP_MSEC = 60000

    def get_timemark(self, latency, now=None):
        """
        Get a timemark `latency' ms in future.

        latency -- number of milliseconds to add to the current time.
        now     -- naive UTC datetime to use as the current time; defaults to
                   `datetime.utcnow()'.

        Returns the mark as 2 bytes (big-endian uint16).
        """
        if now is None:
            now = datetime.utcnow()
        stamp = (now.second * 1000) + (now.microsecond // 1000)
        # Stamp ranges from 0 to 59999 - fits in uint16_t. Wrapping has to be
        # modulo 60000; modulo 59999 shifts every wrapped mark by 1 ms.
        stamp = (stamp + latency) % self.WRAP_MSEC
        return struct.pack('>H', stamp)

    def to_absolute_timestamp(self, mark, now=None):
        """
        Convert a 2-byte timemark into an absolute timestamp (float seconds).

        mark -- bytes produced by `get_timemark'.
        now  -- naive UTC datetime to use as the current time; defaults to
                `datetime.utcnow()'.

        A mark whose second is lower than the current second is assumed to
        belong to the next minute, otherwise to the current one.

        Raises ValueError if the mark does not encode a valid offset.

        NOTE: the result is `.timestamp()' of a naive UTC datetime, which
        Python interprets as local time. It is therefore only comparable with
        values obtained the same way (`datetime.utcnow().timestamp()').
        """
        if now is None:
            now = datetime.utcnow()

        value = struct.unpack('>H', mark)[0]
        if value >= self.WRAP_MSEC:
            raise ValueError("timemark out of range: %d" % value)
        second, millisecond = divmod(value, 1000)

        if now.second > second:
            # next minute
            now = now.replace(second=0, microsecond=0) + timedelta(minutes=1)
        # else: this minute

        absolute_mark = now.replace(second=second,
                                    microsecond=millisecond * 1000)
        return absolute_mark.timestamp()


class SampleReader(asyncio.Protocol):
    """
    Read samples from a stream, chunk them and put into a queue.

    Incoming bytes are cut into `chunk_size' pieces. Long runs of all-zero
    chunks (silence) are detected and not queued at all.
    """

    # Number of empty chunks before silence is detected.
    SILENCE_TRESHOLD = 30

    HEADER_SIZE = 4

    # Queue length above which a slow-consumer warning is printed.
    QUEUE_WARNING_SIZE = 30

    def __init__(self):
        super().__init__()
        self.sample_queue = asyncio.Queue()

        # Consecutive chunks that looked silent (first and last byte zero)
        self.silence_detect = 0
        # True while chunks are being suppressed because of silence
        self.silent = False

        # Bytes received but not yet cut into chunks
        self.buffer = bytearray()

    def set_chunk_size(self, payload_size, sample_size):
        """
        Calculate optimal chunk size.

        payload_size -- maximum UDP payload size in bytes (header included).
        sample_size  -- size in bytes of one sample of all channels.

        Raises ValueError if no whole sample fits into the payload.
        """
        # 1420 is max payload for UDP over 1500 MTU ethernet
        # 80 - max IP header (60) + UDP header.
        # 4 - our header / timestamp
        # NOTE: 60 bytes is pessimistically large IP header. Could be as
        #       small as 20 bytes.
        if sample_size <= 0:
            raise ValueError("sample size must be positive")

        # Remove our header from the max payload size
        chunk_size = payload_size - self.HEADER_SIZE

        # To fit always the same amount of both channels (to not swap them in
        # case of a packet drop) ensure the amount of space is divisible by
        # sample_size
        chunk_size -= chunk_size % sample_size

        # A non-positive chunk size would make data_received() loop forever.
        if chunk_size <= 0:
            raise ValueError("payload size %d is too small for a %d byte "
                             "sample" % (payload_size, sample_size))

        self.chunk_size = chunk_size

        # Required for MTU detection
        self.sample_size = sample_size

    def connection_made(self, transport):
        "Initialize stream buffer"
        self.buffer = bytearray()

    def _suppress_as_silence(self, chunk):
        "Update silence detection state; return True if chunk must be skipped"
        if self.silent:
            if not any(chunk):
                # Still silence
                return True
            # Detected the end of current silence
            self.silent = False
            self.silence_detect = 0
            print("Silence - end")

        # Heuristic detection of silence start
        if chunk[0] == 0 and chunk[-1] == 0:
            self.silence_detect += 1
        else:
            self.silence_detect = 0

        # Silence too long - stop transmission
        if self.silence_detect > self.SILENCE_TRESHOLD:
            if any(chunk):  # Accurate check
                self.silence_detect = 0
            else:
                print("Silence - start")
                self.silent = True
                return True
        return False

    def data_received(self, data):
        "Cut received data into chunks and push non-silent ones into queue"
        # A bytearray can drop consumed bytes from the front without copying
        # the whole remainder into a new object on each chunk.
        self.buffer += data
        size = self.chunk_size

        while len(self.buffer) >= size:
            chunk = bytes(self.buffer[:size])
            del self.buffer[:size]

            if self._suppress_as_silence(chunk):
                continue

            self.sample_queue.put_nowait(chunk)

        # Warning - might happen on slow UDP output sink
        queued = self.sample_queue.qsize()
        if queued > self.QUEUE_WARNING_SIZE:
            s = "WARNING: Samples in queue: %d - slow UDP transmission!"
            s = s % queued
            print(s)

    def connection_lost(self, exc):
        "Stop the event loop when the sample source goes away"
        print("The pulse was lost. I should go.")
        # The loop is running at this point, so it can only be stopped here;
        # closing a running loop raises RuntimeError.
        loop = asyncio.get_event_loop()
        loop.stop()

    def decrement_chunk_size(self):
        """
        Decrement chunk size and flush chunks currently in queue.

        Returns the new payload size (chunk size plus header). The chunk size
        never goes below a single sample.
        """
        new_size = self.chunk_size - 1
        new_size -= new_size % self.sample_size
        self.chunk_size = max(new_size, self.sample_size)

        # Already queued chunks have the old (too big) size - drop them
        while not self.sample_queue.empty():
            self.sample_queue.get_nowait()
        return self.chunk_size + self.HEADER_SIZE

    def get_next_chunk(self):
        "Return an awaitable resolving to the next queued chunk"
        return self.sample_queue.get()


class Packetizer:
    """
    Read chunks from queue, add timestamp marks and send over UDP.

    An audio datagram is `2 byte flags + 2 byte timemark + payload'.
    A status datagram is `2 byte flags + struct 'dI' (timestamp, chunk_no)'.
    """

    HEADER_COMPRESSED_AUDIO = b'\x80\x00'
    HEADER_RAW_AUDIO = b'\x00\x00'
    HEADER_STATUS = b'\x40\x00'

    # A status datagram is sent every that many chunks - ~ 1 second
    STATUS_INTERVAL = 124

    # The statistics line is printed every that many chunks
    STATS_INTERVAL = 100

    # Socket option numbers, Linux values from <linux/in.h>
    IP_MTU_DISCOVER = 10
    IP_PMTUDISC_DO = 2

    def __init__(self, reader, time_machine, latency_msec, compress=False):
        """
        reader       -- source of chunks; needs `get_next_chunk()' and
                        `decrement_chunk_size()'.
        time_machine -- object providing `get_timemark(latency)'.
        latency_msec -- how far in the future the chunks are marked.
        compress     -- False to send raw audio or a zlib compression level.
        """
        self.reader = reader
        self.time_machine = time_machine
        self.latency_msec = latency_msec
        self.compress = compress

    def create_socket(self, channels, ttl, multicast_loop, broadcast):
        """
        Create a UDP multicast socket.

        channels       -- iterable of (address, port) destinations.
        ttl            -- multicast TTL.
        multicast_loop -- whether multicast packets are looped back to the
                          sending host.
        broadcast      -- whether to allow sending to broadcast addresses.
        """
        self.sock = socket.socket(socket.AF_INET,
                                  socket.SOCK_DGRAM,
                                  socket.IPPROTO_UDP)
        self.sock.setsockopt(socket.IPPROTO_IP,
                             socket.IP_MULTICAST_TTL,
                             ttl)

        # Looping is enabled by default in the kernel, so it has to be
        # switched off explicitly when not wanted.
        self.sock.setsockopt(socket.IPPROTO_IP,
                             socket.IP_MULTICAST_LOOP,
                             1 if multicast_loop else 0)

        if broadcast:
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

        self.destinations = [
            (address, port)
            for address, port in channels
        ]

        # Set DF flag on IP packet (Don't Fragment) - fragmenting would be bad
        # idea, it's way better to chunk the packets right.
        self.sock.setsockopt(socket.IPPROTO_IP,
                             self.IP_MTU_DISCOVER,
                             self.IP_PMTUDISC_DO)

    def _create_status_packet(self, chunk_no):
        "Format status packet"
        now = datetime.utcnow().timestamp()
        # The counter travels as uint32 - wrap it instead of failing to pack
        # once the sender has been running for a very long time.
        return self.HEADER_STATUS + struct.pack('dI', now,
                                                chunk_no & 0xFFFFFFFF)

    def _build_audio_datagram(self, chunk, mark):
        """
        Format audio datagram.

        Returns `(datagram, cancelled)'; `cancelled' is True when compression
        was attempted but did not make the chunk smaller.
        """
        if self.compress is False:
            return self.HEADER_RAW_AUDIO + mark + chunk, False

        compressed = zlib.compress(chunk, self.compress)
        if len(compressed) < len(chunk):
            # Go with compressed
            return self.HEADER_COMPRESSED_AUDIO + mark + compressed, False

        # Cancel - compressed might not fit to packet
        return self.HEADER_RAW_AUDIO + mark + chunk, True

    async def packetize(self):
        "Read pre-chunked samples from queue and send them over UDP"
        import errno

        start = time()
        # Number of sent packets
        stat_pkts = 0
        # Chunk number as seen by receivers
        chunk_no = 0
        bytes_sent = 0
        bytes_raw = 0
        cancelled_compressions = 0

        # Current speed measurement
        recent = 0
        recent_bytes = 0
        recent_start = start

        while True:
            # Block until samples are read by the reader.
            chunk = await self.reader.get_next_chunk()
            mark = self.time_machine.get_timemark(self.latency_msec)

            chunk_len = len(chunk)
            dgram, cancelled = self._build_audio_datagram(chunk, mark)
            if cancelled:
                cancelled_compressions += 1

            dgram_len = len(dgram)
            chunk_no += 1
            recent += 1
            for destination in self.destinations:
                try:
                    self.sock.sendto(dgram, destination)
                except OSError as e:
                    if e.errno == errno.EMSGSIZE:
                        s = "WARNING: UDP datagram size (%d) is too big for your network MTU"
                        s = s % dgram_len
                        print(s)
                        new_size = self.reader.decrement_chunk_size()
                        print("Trying MTU detection. New payload size is %d" % new_size)
                        break
                    # Best effort - report and try the next destination
                    print("WARNING: Sending to %s:%s failed: %s"
                          % (destination[0], destination[1], e))
                    continue
                bytes_sent += dgram_len
                recent_bytes += dgram_len
                # Raw size of what was sent: chunk + 4 bytes of header
                bytes_raw += chunk_len + 4
                stat_pkts += 1

            # Send small status datagram every STATUS_INTERVAL chunks.
            # It's used to determine if some frames were lost on the network
            # and therefore if output buffer resync is required.
            if chunk_no % self.STATUS_INTERVAL == 0:
                status = self._create_status_packet(chunk_no)
                for destination in self.destinations:
                    try:
                        self.sock.sendto(status, destination)
                    except OSError as e:
                        print("WARNING: Sending status to %s:%s failed: %s"
                              % (destination[0], destination[1], e))

            if recent >= self.STATS_INTERVAL:
                # Main status line
                now = time()
                # Guard against a clock too coarse to see any time passing
                took_total = max(now - start, 1e-9)
                took_recent = max(now - recent_start, 1e-9)
                s = ("STATE: dsts=%d total: pkts=%d kB=%d time=%d "
                     "kB/s: avg=%.3f cur=%.3f")
                s = s % (
                    len(self.destinations),
                    stat_pkts,
                    bytes_sent / 1024, took_total,
                    bytes_sent / took_total / 1024,
                    recent_bytes / took_recent / 1024,
                )
                if self.compress:
                    # Nothing sent successfully yet means there is no ratio
                    ratio = bytes_sent / bytes_raw if bytes_raw else 0.0
                    s += ' compress_ratio=%.3f cancelled=%d' % (
                        ratio, cancelled_compressions)
                print(s)

                recent_start = now
                recent_bytes = 0
                recent = 0


class Receiver(asyncio.DatagramProtocol):
    """
    Packet receiver

    - Receive packets
    - decode headers
    - store in chunk list.

    Entries of `chunk_list' are `(CMD_AUDIO, (timestamp, chunk))' or
    `(CMD_DROPS, number_of_dropped_chunks)'.
    """

    CMD_AUDIO = 1
    CMD_DROPS = 2

    # Layout of the status payload following the 2 byte header; has to match
    # what Packetizer._create_status_packet packs.
    STATUS_FORMAT = 'dI'

    # Audio datagram: 2 bytes of flags + 2 bytes of timemark
    AUDIO_HEADER_SIZE = 4

    # Upper bound for a decompressed chunk. A raw chunk has to fit into a UDP
    # datagram, so nothing legitimate can be bigger than this.
    MAX_CHUNK_SIZE = 65536

    # A sender chunk counter below this value is taken as a sender restart
    RESTART_TRESHOLD = 1500

    def __init__(self, time_machine, channel):
        """
        time_machine -- object providing `to_absolute_timestamp(mark)'.
        channel      -- (address, port) the receiver listens on.
        """
        # Store config
        self.channel = channel

        self.time_machine = time_machine

        # NOTE: On LAN an unsorted deque works for me. Might need
        # a packet ordering based on time mark eventually.
        self.chunk_list = deque()

        self.chunk_available = asyncio.Event()

        self.stat_network_latency = 0
        self.stat_network_drops = 0

        # Received audio chunk counter (since the last status packet)
        self.chunk_no = 0
        self.last_sender_chunk_no = None

        super().__init__()

    @staticmethod
    def _is_multicast(address):
        "Tell if address is a valid IPv4 multicast address (224.0.0.0/4)"
        import ipaddress
        try:
            return ipaddress.IPv4Address(address).is_multicast
        except ValueError:
            # Not a dotted-quad IPv4 address at all (e.g. a host name)
            return False

    def connection_made(self, transport):
        "Configure multicast"
        sock = transport.get_extra_info('socket')
        # NOTE(review): the socket is most likely bound already at this point,
        # so this option presumably comes too late to matter - confirm.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        group, port = self.channel

        # Received audio chunk counter
        self.chunk_no = 0
        self.last_sender_chunk_no = None

        # If not multicast - end
        if not self._is_multicast(group):
            print("Assuming unicast reception on %s:%d" % (group, port))
            return

        # Multicast - join group
        print("Joining multicast group", group)

        # struct ip_mreq: group address + local interface address, 4 bytes
        # each. A native 'L' would be 8 bytes (plus padding) on LP64.
        mreq = socket.inet_aton(group) + struct.pack('!I', socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)

    def _handle_status(self, data):
        "Update latency/drop statistics from a status packet"
        now = datetime.utcnow().timestamp()
        try:
            sender_timestamp, sender_chunk_no = struct.unpack_from(
                self.STATUS_FORMAT, data, 2)
        except struct.error:
            print("WARNING: Truncated status packet - dropping")
            return

        # Handle timestamp
        self.stat_network_latency = (now - sender_timestamp)

        # Handle dropped packets

        # If this is first status packet
        # or low sender_chunk_no indicates that sender was restarted
        if (self.last_sender_chunk_no is None
                or sender_chunk_no < self.RESTART_TRESHOLD):
            self.last_sender_chunk_no = sender_chunk_no
            self.chunk_no = 0
            return

        # How many chunks were transmitted since previous status packet?
        chunks_sent = sender_chunk_no - self.last_sender_chunk_no
        dropped = chunks_sent - self.chunk_no

        self.last_sender_chunk_no = sender_chunk_no
        self.chunk_no = 0

        self.stat_network_drops += dropped
        if dropped < 0:
            print("WARNING: More pkts received than sent! You are receiving multiple streams.")

        if dropped > 0:
            # Let the consumer of chunk_list know about the gap
            self.chunk_list.append((self.CMD_DROPS, dropped))
            self.chunk_available.set()

    def datagram_received(self, data, addr):
        "Handle incoming datagram - audio chunk, or status packet"

        header = data[:2]
        if header == Packetizer.HEADER_STATUS:
            # Status header!
            self._handle_status(data)
            return

        if header == Packetizer.HEADER_RAW_AUDIO:
            compressed = False
        elif header == Packetizer.HEADER_COMPRESSED_AUDIO:
            compressed = True
        else:
            print("Invalid header!")
            return

        if len(data) < self.AUDIO_HEADER_SIZE:
            # Too short to even carry the time mark
            print("Invalid header!")
            return

        mark = data[2:4]
        chunk = data[4:]

        if compressed:
            # Data comes from the network - bound the output size so a crafted
            # datagram cannot expand into an arbitrary amount of memory.
            decompressor = zlib.decompressobj()
            try:
                chunk = decompressor.decompress(chunk, self.MAX_CHUNK_SIZE)
            except zlib.error:
                print("WARNING: Invalid compressed data - dropping")
                return
            if not decompressor.eof:
                # Truncated stream or output bigger than allowed
                print("WARNING: Invalid compressed data - dropping")
                return

        mark = self.time_machine.to_absolute_timestamp(mark)
        item = (mark, chunk)

        # Count received audio-chunks
        self.chunk_no += 1

        self.chunk_list.append((self.CMD_AUDIO, item))
        self.chunk_available.set()

    def error_received(self, exc):
        "Report a socket error"
        print('Error received:', exc)

    def connection_lost(self, exc):
        "Stop the event loop when the socket is closed"
        print("Socket closed, stop the event loop")
        loop = asyncio.get_event_loop()
        loop.stop()


class Player:
    "Play received audio and keep sync"

    def __init__(self, receiver, output_socket,
                 tolerance, sink_latency):
        """
        receiver      -- data source; provides `chunk_list', `chunk_available',
                         `CMD_AUDIO'/`CMD_DROPS' and network statistics.
        output_socket -- path to the unix socket the audio is written to.
        tolerance     -- accepted playback time error, in seconds.
        sink_latency  -- latency of the output sink, in seconds.
        """
        # Our data source
        self.receiver = receiver

        # Path to the output
        self.output_socket = output_socket

        # Configuration
        self.tolerance = tolerance
        self.sink_latency = sink_latency

        # Required for silence insertion
        self.last_chunk_size = 0

        # Stats
        self.stat_drops = 0
        self.stat_total_delay = 0

        # Silence frames (zeroed) of appropriate sizes for chunks
        self.silence_cache = {}

    def get_silent_chunk(self, size):
        "Generate and cache silent chunks"
        try:
            return self.silence_cache[size]
        except KeyError:
            silent_chunk = self.silence_cache[size] = b'\x00' * size
            return silent_chunk

    def flush_output(self):
        """
        Drop queued chunks which are too old to be played on time.

        Everything due before `now + tolerance + sink_latency / 2' is removed
        from the receiver's list; the first chunk due later stays queued.
        """
        now = datetime.utcnow().timestamp()
        # Time we will return to playback
        back_on = now + self.tolerance + self.sink_latency / 2

        chunk_list = self.receiver.chunk_list
        while chunk_list:
            # Peek first - the chunk ending the flush has to stay in the list
            cmd, item = chunk_list[0]
            if cmd == self.receiver.CMD_AUDIO:
                mark, _chunk = item
                desired_time = mark - self.sink_latency
                if desired_time >= back_on:
                    break
                self.stat_drops += 1
            # Non-audio entries (drop notifications) in the flushed range are
            # meaningless once the audio around them is gone.
            chunk_list.popleft()

    async def pump_audio(self):
        "Reads asynchronously chunks from the list and feeds into Unix socket"

        # Self-tuning scheduling latency
        sched_latency = 0.001
        cnt = 0

        # Connect to unix socket
        conn = await asyncio.open_unix_connection(self.output_socket)
        _pulse_reader, pulse_writer = conn

        receiver = self.receiver

        while True:
            if len(receiver.chunk_list) == 0:
                receiver.chunk_available.clear()
                print("Queue empty, waiting for initial data...")
                await receiver.chunk_available.wait()
                continue

            cmd, item = receiver.chunk_list.popleft()

            if cmd == receiver.CMD_DROPS:
                # Fill the gap left by lost chunks with silence
                silent_chunk = self.get_silent_chunk(self.last_chunk_size)
                for _ in range(item):
                    pulse_writer.write(silent_chunk)
                continue

            # CMD_AUDIO

            mark, chunk = item
            desired_time = mark - self.sink_latency

            # 0) We got the next chunk to be played
            cnt += 1
            now = datetime.utcnow().timestamp()
            delay = desired_time - now

            self.stat_total_delay += delay

            # 1) We are falling behind and need to drop something to get back
            # on track.
            if delay < -self.tolerance:
                s = "Falling behind - resync: delay=%.3fms <= %.3fms"
                s = s % (delay*1000, -self.tolerance*1000)
                print(s)
                self.stat_drops += 1
                continue

            # 2) Is it too much in the future? Wait for it!
            if delay > 3 * sched_latency:
                # Wait until we reach tolerated range
                await asyncio.sleep(delay - sched_latency)

                # Update scheduling latency every X times
                if cnt % 15 == 0:
                    # Calculate the "overshoot"
                    new_now = datetime.utcnow().timestamp()
                    jump_error = new_now - desired_time

                    # jump_error > 0 - we overshoot
                    # jump_error < 0 - we are early

                    # Correct overshoot more eagerly than being early
                    if jump_error > 0:
                        sched_latency += jump_error
                    else:
                        sched_latency += jump_error / 4

                    # A negative latency makes no sense and would make the
                    # code above sleep for chunks which are already due.
                    sched_latency = max(sched_latency, 0.0)

                    if sched_latency > self.tolerance:
                        print("sched_latency too high")
                        sched_latency = self.tolerance

            # Let's hear the music.
            # NOTE: Detect pulseaudio failure and stop.
            pulse_writer.write(chunk)
            self.last_chunk_size = len(chunk)
            if cnt % 5 == 0:
                start = time()
                await pulse_writer.drain()
                took = time() - start
                if took > self.tolerance:
                    print("WARNING: Output pipe seems slow! Drain took %.3fms" % (took * 1000))

            # Main status line
            if cnt % 50 == 0:
                s = ("STATUS: queue=%-3d latency[ms]: sched=%-4.1f net=%-5.1f "
                     "avg_delay=%-5.2f drops: time=%d net=%d")
                s = s % (
                    len(receiver.chunk_list),
                    1000.0 * sched_latency,
                    1000.0 * receiver.stat_network_latency,
                    1000.0 * self.stat_total_delay/cnt,
                    self.stat_drops,
                    receiver.stat_network_drops
                )
                print(s)
                if receiver.stat_network_latency > 4:
                    print("WARNING: Your network latency seems HUGE. Are the clocks synchronized?")
                elif receiver.stat_network_latency < 0:
                    print("WARNING: You either exceeded the speed of light or have unsynchronised clocks")


def _parse_arguments(argv=None):
    """
    Parse program arguments.

    argv -- list of arguments to parse; defaults to the process command line
            (`sys.argv[1:]').

    Returns the argparse namespace with `ip_list' converted into a list of
    (address, port) tuples. Invalid input ends the program through
    `ArgumentParser.error'.
    """
    import argparse

    version = ".".join(str(p) for p in VERSION)

    dst = "WaveSync %s - multi-room sound synchronization system"
    dst = dst % version
    p = argparse.ArgumentParser(description=dst)
    snd = p.add_argument_group('sender options')
    rcv = p.add_argument_group('receiver options')
    opt = p.add_argument_group('common')
    act = p.add_argument_group('actions')

    act.add_argument("--tx",
                     metavar="INPUT",
                     action="store",
                     help="transmit sound from a given unix socket")

    act.add_argument("--rx",
                     metavar="OUTPUT",
                     action="store",
                     help="receive sound and sent to a unix socket")

    snd.add_argument("--latency",
                     dest="latency_msec",
                     metavar="MSEC",
                     action="store",
                     default=1500,
                     type=int,
                     help="assumed system latency in miliseconds")

    snd.add_argument("--payload-size",
                     metavar="BYTES",
                     action="store",
                     type=int,
                     default=1500 - 80,
                     help="UDP payload size, 1420 is the safe default")

    snd.add_argument("--ttl",
                     metavar="TTL",
                     action="store",
                     type=int,
                     default=2,
                     help="multicast TTL")

    snd.add_argument("--sample-size",
                     metavar="BYTES",
                     action="store",
                     type=int,
                     default=4,
                     help="sample size in bytes (16-bit 2-channel = 4 bytes)")

    snd.add_argument("--compress",
                     action="store",
                     default=False,
                     type=int,
                     help="compression level (1-9))")

    snd.add_argument("--no-loop",
                     dest="multicast_loop",
                     action="store_false",
                     default=True,
                     help="Do not loop multicast packets back to the sender")

    snd.add_argument("--broadcast",
                     action="store_true",
                     help="Use transmission to broadcast target")

    rcv.add_argument("--tolerance",
                     dest='tolerance_msec',
                     metavar="MSEC",
                     action="store",
                     type=int,
                     default=10,
                     help="error tolerance")

    rcv.add_argument("--sink-latency",
                     dest="sink_latency_msec",
                     metavar="MSEC",
                     action="store",
                     type=int,
                     default=30,
                     help="sink latency")

    opt.add_argument("--channel",
                     dest="ip_list",
                     metavar="ADDRESS:PORT",
                     action="append",
                     default=[],
                     help="multicast group or a unicast address, "
                          "may be given multiple times with --tx")

    opt.add_argument("--debug",
                     action="store_true",
                     help="enable debugging code")

    args = p.parse_args(argv)

    if (args.tx is None) == (args.rx is None):
        p.error('Exactly one action: --tx or --rx must be specified')

    # Time marks are offsets within a minute (see TimeMachine), so a latency
    # of a minute or more cannot be expressed.
    if not 0 <= args.latency_msec < 60000:
        p.error('Latency must be between 0 and 59999 miliseconds')

    if args.sink_latency_msec > args.latency_msec:
        p.error("Sink latency cannot exceed system latency! Leave some margin too.")

    if not args.ip_list:
        args.ip_list.append('224.0.0.57:45300')

    if args.rx is not None and len(args.ip_list) > 1:
        p.error('Receiver must have only a single channel (IP)')

    # Parse IP addresses
    parsed_ip_list = []
    for arg in args.ip_list:
        tmp = arg.split(':')
        if len(tmp) != 2:
            p.error('TX/RX channel not in format IP_ADDRESS:PORT: ' + arg)
        address, port = tmp

        try:
            port = int(port)
        except ValueError:
            p.error('Port is not a number in channel: ' + arg)

        # Fail here with a readable message rather than later in the socket
        # layer.
        if not 0 < port < 65536:
            p.error('Port is out of range (1-65535) in channel: ' + arg)

        parsed_ip_list.append((address, port))
    args.ip_list = parsed_ip_list

    return args


def start_tx(args, loop, time_machine):
    """
    Initialize sender.

    Builds the sample reader and the packetizer and schedules both the
    packetizing coroutine and the unix socket connection on `loop'. Nothing
    runs until the loop is started by the caller.

    args         -- parsed program arguments.
    loop         -- event loop to schedule the work on.
    time_machine -- TimeMachine used to mark the chunks.
    """

    # Sound sample reader
    sample_reader = SampleReader()
    sample_reader.set_chunk_size(args.payload_size, args.sample_size)

    # Packet splitter / sender
    packetizer = Packetizer(sample_reader, time_machine,
                            args.latency_msec,
                            compress=args.compress)

    packetizer.create_socket(args.ip_list,
                             args.ttl,
                             args.multicast_loop,
                             args.broadcast)

    connection = loop.create_unix_connection(lambda: sample_reader, args.tx)

    # Schedule the work. `asyncio.async' cannot even be parsed since `async'
    # became a keyword (Python 3.7); `ensure_future' is its replacement.
    asyncio.ensure_future(packetizer.packetize(), loop=loop)
    asyncio.ensure_future(connection, loop=loop)


def start_rx(args, loop, time_machine):
    """
    Initialize receiver.

    Builds the network receiver and the player and schedules both the datagram
    endpoint and the audio pumping coroutine on `loop'. Nothing runs until the
    loop is started by the caller.

    args         -- parsed program arguments.
    loop         -- event loop to schedule the work on.
    time_machine -- TimeMachine used to decode the time marks.
    """

    # Network receiver with its connection
    channel = args.ip_list[0]
    receiver = Receiver(time_machine,
                        channel=channel)

    connection = loop.create_datagram_endpoint(lambda: receiver,
                                               family=socket.AF_INET,
                                               local_addr=channel)

    # Coroutine pumping audio into PA
    player = Player(receiver,
                    output_socket=args.rx,
                    tolerance=args.tolerance_msec / 1000.0,
                    sink_latency=args.sink_latency_msec / 1000.0)

    pump = player.pump_audio()

    # Schedule the work. `asyncio.async' cannot even be parsed since `async'
    # became a keyword (Python 3.7); `ensure_future' is its replacement.
    asyncio.ensure_future(connection, loop=loop)
    asyncio.ensure_future(pump, loop=loop)


def main():
    "Parse arguments and start the event loop"
    args = _parse_arguments()

    # Create the loop explicitly and make it the current one: relying on
    # get_event_loop() to create a loop implicitly is deprecated. Installing it
    # as current keeps get_event_loop() calls elsewhere on the same loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    if args.debug:
        loop.set_debug(True)

    time_machine = TimeMachine()

    if args.tx is not None:
        start_tx(args, loop, time_machine)
    elif args.rx is not None:
        start_rx(args, loop, time_machine)

    try:
        loop.run_forever()
    finally:
        print('closing event loop')
        loop.close()


# Run the program only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
