#!/usr/bin/env python3
# vim: filetype=python

"""
WaveSync - multi-room sound synchronization system.

Sender pipeline:
Unix socket/sound source
  --- samples stream ---> [Sample Reader]
  ---  sample_queue  ---> [Packetizer]
---> unicast/multicast UDP

Receiver pipeline:
UDP packets
  ---> [Receiver.datagram_received]
  --- chunk_list ---> [Receiver.pump_audio]
---> Unix socket/sound sink
"""
from time import time
import struct
import zlib
import socket
from datetime import datetime, timedelta

import asyncio
from functools import partial

from collections import deque

# Program version as a (major, minor, patch) tuple
VERSION = (0, 1, 1)

class TimeMachine:
    """
    Handle fast millisecond precision timestamps

    A `timemark' marks a time in the future - a certain number of
    milliseconds ahead of the current time. It is encoded as the millisecond
    offset within the current UTC minute (0-59999), packed as a big-endian
    uint16. Sender and receiver therefore need synchronized clocks.

    NOTE: Absolute timestamps are produced with `datetime.utcnow().timestamp()',
    i.e. a naive UTC datetime interpreted as local time. The receiver compares
    them against the same expression, so the offset cancels out.
    """

    # Marks wrap around once per minute.
    MINUTE_MSEC = 60000

    def get_timemark(self, latency):
        "Get a 2-byte timemark `latency' ms in the future"
        now = datetime.utcnow()
        stamp = (now.second * 1000) + (now.microsecond // 1000)
        # Wrap at a full minute so the stamp stays within 0-59999 (fits in
        # uint16_t). Wrapping at 59999 shifted wrapped marks by 1 ms.
        stamp = (stamp + latency) % self.MINUTE_MSEC
        return struct.pack('>H', stamp)

    def to_absolute_timestamp(self, mark, past_window=5.0):
        """
        Convert a 2-byte timemark into an absolute POSIX-like timestamp.

        The mark only carries the position within a minute, so the minute
        is chosen to place the result within [-past_window, 60 - past_window)
        seconds of now. A slightly late mark therefore resolves to the recent
        past (and gets dropped by the player), instead of almost a minute
        ahead.

        :param mark: 2 bytes, big-endian millisecond offset within a minute
        :param past_window: how many seconds in the past a mark may resolve to
        :return: float timestamp (see class NOTE about the time base)
        """
        mark_msec = struct.unpack('>H', mark)[0]

        now = datetime.utcnow()
        absolute_mark = (now.replace(second=0, microsecond=0)
                         + timedelta(milliseconds=mark_msec))

        offset = (absolute_mark - now).total_seconds()
        if offset < -past_window:
            # Wrapped around - the mark belongs to the next minute
            absolute_mark += timedelta(minutes=1)
        elif offset >= 60 - past_window:
            # A (late) mark from the end of the previous minute
            absolute_mark -= timedelta(minutes=1)

        return absolute_mark.timestamp()


class SampleReader(asyncio.Protocol):
    """Read samples over the network, chunk them and put into a queue"""

    # Number of consecutive silent-looking chunks before silence is detected.
    SILENCE_TRESHOLD = 30

    def __init__(self, sample_queue, chunk_size):
        super().__init__()
        self.chunk_size = chunk_size
        self.sample_queue = sample_queue
        # Consecutive chunks whose first and last byte are zero
        self.silence_detect = 0
        # True while a detected silence suppresses transmission
        self.in_silence = False
        self.buffer = bytearray()

    def connection_made(self, transport):
        "Initialize stream buffer"
        self.buffer = bytearray()

    def data_received(self, data):
        "Split the incoming stream into fixed-size chunks and queue them"
        self.buffer += data

        size = self.chunk_size
        usable = len(self.buffer) - len(self.buffer) % size
        for offset in range(0, usable, size):
            self._handle_chunk(bytes(self.buffer[offset:offset + size]))
        # Drop all consumed bytes at once; the remainder waits for more data
        del self.buffer[:usable]

        # Warning - might happen on slow UDP output sink
        if self.sample_queue.qsize() > 30:
            print("System too slow; unable to send UDP packets fast enough!")

    def _handle_chunk(self, chunk):
        "Queue a single chunk unless it belongs to a detected silence"
        # Detect the end of current silence
        if self.in_silence:
            if not any(chunk):
                # Still silence
                return
            self.in_silence = False
            self.silence_detect = 0
            print("Silence - end")

        # Heuristic detection of silence start
        if chunk[0] == 0 and chunk[-1] == 0:
            self.silence_detect += 1
        else:
            self.silence_detect = 0

        # Silence too long - stop transmission
        if self.silence_detect > self.SILENCE_TRESHOLD:
            if any(chunk):  # Accurate check
                self.silence_detect = 0
            else:
                print("Silence - start")
                self.in_silence = True
                return

        self.sample_queue.put_nowait(chunk)

    def connection_lost(self, exc):
        "Stop the event loop once the sample source goes away"
        print("The pulse was lost. I should go.")
        # Closing a running loop raises RuntimeError; stop it instead and let
        # whoever called run_forever() close it.
        loop = asyncio.get_event_loop()
        loop.stop()


class Packetizer:
    """Read chunks from queue, add timestamp marks and send over UDP."""

    # Our per-packet header: 2 bytes of flags + 2 bytes of time mark
    HEADER_SIZE = 4
    # Pessimistic allowance for IP (up to 60 bytes) and UDP (8 bytes)
    # headers, with some extra margin.
    IP_UDP_OVERHEAD = 80

    def __init__(self, latency_msec, time_machine, sample_queue, compress=False):
        """
        :param latency_msec: how far in the future (ms) the time marks point
        :param time_machine: produces the 2-byte time marks
        :param sample_queue: queue the sample reader puts chunks into
        :param compress: zlib level (int or numeric string); False/None
                         disables compression, as does level 0
        """
        self.sample_queue = sample_queue
        self.time_machine = time_machine
        self.latency_msec = latency_msec
        # argparse may deliver the level as a string; zlib requires an int.
        if compress is None or compress is False:
            self.compress = False
        else:
            self.compress = int(compress)

    def create_socket(self, channels, ttl, multicast_loop):
        "Create a UDP socket sending to every (address, port) in `channels'"
        self.sock = socket.socket(socket.AF_INET,
                                  socket.SOCK_DGRAM,
                                  socket.IPPROTO_UDP)
        self.sock.setsockopt(socket.IPPROTO_IP,
                             socket.IP_MULTICAST_TTL,
                             ttl)
        # Set the flag explicitly both ways: multicast loopback is enabled by
        # default on e.g. Linux, so only ever setting 1 made --no-loop a no-op.
        self.sock.setsockopt(socket.IPPROTO_IP,
                             socket.IP_MULTICAST_LOOP,
                             1 if multicast_loop else 0)

        self.destinations = [
            (address, port)
            for address, port in channels
        ]

    def get_chunk_size(self, mtu, sample_size):
        """
        Calculate the optimal chunk (payload) size for a given MTU.

        The size leaves room for IP/UDP headers and our own 4-byte header, and
        is a multiple of `sample_size' so a dropped packet never swaps
        channels.

        :raises ValueError: for non-positive `sample_size' or too small `mtu'
        """
        if sample_size <= 0:
            raise ValueError("Sample size must be positive: %r" % sample_size)

        # e.g. 1500 MTU -> 1500 - 80 - 4 = 1416 bytes of samples
        chunk_size = mtu - self.IP_UDP_OVERHEAD - self.HEADER_SIZE

        # To fit always the same amount of both channels (to not swap them in case
        # of a packet drop) ensure the amount of space is divisible by sample_size
        chunk_size -= chunk_size % sample_size
        if chunk_size <= 0:
            raise ValueError("MTU %d too small for sample size %d"
                             % (mtu, sample_size))
        return chunk_size

    async def packetize(self):
        """
        Forever: take chunks from the queue, prefix them with a header and a
        time mark and send them to every destination.

        The header is 0x8000 for a zlib-compressed payload and 0x0000 for raw
        samples. A status line is printed every 100 chunks.
        """
        start = time()
        pkts = 0
        bytes_sent = 0
        bytes_raw = 0
        cancelled_compressions = 0

        # Current speed measurement
        recent = 0
        recent_bytes = 0
        recent_start = start

        while True:
            chunk = await self.sample_queue.get()
            mark = self.time_machine.get_timemark(self.latency_msec)

            chunk_len = len(chunk)
            dgram = None
            if self.compress:
                chunk_compressed = zlib.compress(chunk, self.compress)
                if len(chunk_compressed) < chunk_len:
                    # Go with compressed
                    dgram = b'\x80\x00' + mark + chunk_compressed
                else:
                    # Cancel - compressed might not fit to packet
                    cancelled_compressions += 1
            if dgram is None:
                dgram = b'\x00\x00' + mark + chunk

            dgram_len = len(dgram)
            for destination in self.destinations:
                self.sock.sendto(dgram, destination)
                bytes_sent += dgram_len
                recent_bytes += dgram_len
                bytes_raw += chunk_len + self.HEADER_SIZE
                pkts += 1

            recent += 1

            if recent >= 100:
                # Main status line
                now = time()
                # Guard against zero intervals on coarse clocks
                took_total = max(now - start, 1e-9)
                took_recent = max(now - recent_start, 1e-9)
                s = ("STATE: dsts=%d total: pkts=%d kB=%d time=%d "
                     "kB/s: avg=%.3f cur=%.3f")
                s = s % (
                    len(self.destinations),
                    pkts,
                    bytes_sent / 1024, took_total,
                    bytes_sent / took_total / 1024,
                    recent_bytes / took_recent / 1024,
                )
                if self.compress:
                    ratio = bytes_sent / bytes_raw if bytes_raw else 0.0
                    s += ' compress_ratio=%.3f cancelled=%d'
                    s = s % (ratio, cancelled_compressions)
                print(s)

                recent_start = now
                recent_bytes = 0
                recent = 0


class Receiver(asyncio.DatagramProtocol):
    """
    Packet receiver

    - Receive packets
    - decode headers
    - store (absolute timestamp, chunk) pairs in the chunk list.
    - pump_audio coroutine reads chunks from the list and feeds them into a
      Unix socket (e.g. PulseAudio) at the time given by their marks.

    Datagram layout: 2-byte header (0x8000 = zlib-compressed payload,
    0x0000 = raw payload), 2-byte time mark, then the payload.
    """

    # Header + time mark length in bytes
    HEADER_SIZE = 4

    def __init__(self, time_machine, channel, tolerance, sink_latency):
        """
        :param time_machine: converts 2-byte marks into absolute timestamps
        :param channel: (address, port) tuple this receiver listens on
        :param tolerance: seconds a chunk may be late and still be played
        :param sink_latency: seconds subtracted from each mark to compensate
                             for the output's own latency
        """
        # Store config
        self.channel = channel

        self.tolerance = tolerance
        self.sink_latency = sink_latency

        self.time_machine = time_machine

        # NOTE: On LAN an unsorted deque works for me. Might need
        # a packet ordering based on time mark eventually.
        self.chunk_list = deque()

        self.chunk_available = asyncio.Event()

        super().__init__()

    def connection_made(self, transport):
        "Join the multicast group if the channel address is a multicast one"
        import ipaddress

        sock = transport.get_extra_info('socket')
        # NOTE: The endpoint is already bound when this runs, so this option
        # cannot influence that bind.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        group, port = self.channel
        try:
            multicast = ipaddress.IPv4Address(group).is_multicast
        except ValueError:
            # Not a literal IPv4 address (e.g. a host name)
            multicast = False

        # If not multicast - end
        if not multicast:
            print("Assuming unicast reception on %s:%d" % (group, port))
            return

        # Multicast - join group
        print("Joining multicast group", group)

        # struct ip_mreq: group address + local interface (INADDR_ANY lets the
        # kernel choose); '=' avoids native padding so it is exactly 8 bytes.
        mreq = struct.pack('=4sL', socket.inet_aton(group), socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)

    def datagram_received(self, data, addr):
        "Decode a datagram and queue its chunk; malformed packets are dropped"
        if len(data) < self.HEADER_SIZE:
            print("Packet too short!")
            return

        header = data[:2]
        mark = data[2:4]
        chunk = data[4:]
        if header == b'\x80\x00':
            try:
                chunk = zlib.decompress(chunk)
            except zlib.error:
                print("Invalid compressed chunk!")
                return
        elif header != b'\x00\x00':
            # Unknown format - do not feed garbage to the sound sink
            print("Invalid header!")
            return

        mark = self.time_machine.to_absolute_timestamp(mark)
        self.chunk_list.append((mark, chunk))
        self.chunk_available.set()

    def error_received(self, exc):
        print('Error received:', exc)

    def connection_lost(self, exc):
        print("Socket closed, stop the event loop")
        loop = asyncio.get_event_loop()
        loop.stop()

    async def pump_audio(self, pulse_socket):
        """
        Forever: write queued chunks to the Unix socket `pulse_socket' at
        their desired time (mark - sink_latency).

        Chunks later than `tolerance' are dropped. Too-early chunks are waited
        for, waking up `sched_latency' early; that value self-tunes to the
        observed wake-up overshoot.
        """
        # Self-tuning scheduling latency
        sched_latency = 0.001
        cnt = 0
        drops = 0
        total_delay = 0

        _pulse_reader, pulse_writer = \
            await asyncio.open_unix_connection(pulse_socket)
        while True:
            if not self.chunk_list:
                self.chunk_available.clear()
                print("Queue empty, waiting for initial data...")
                await self.chunk_available.wait()
                continue

            mark, chunk = self.chunk_list.popleft()

            desired_time = mark - self.sink_latency

            # 0) We got the next chunk to be played
            cnt += 1
            now = datetime.utcnow().timestamp()
            delay = desired_time - now

            total_delay += delay

            # 1) Is it too much in the past? Drop it!
            if delay < -self.tolerance:
                s = "Drop - falling behind: delay=%.3fms <= %.3fms"
                s = s % (delay*1000, -self.tolerance*1000)
                print(s)
                drops += 1
                continue

            # 2) Is it too much in the future? Wait for it!
            if delay > 3 * sched_latency:
                # Wait until we reach tolerated range
                await asyncio.sleep(delay - sched_latency)

                # Update scheduling latency every X times
                if cnt % 15 == 0:
                    # jump_error > 0 - we overshoot
                    # jump_error < 0 - we are early
                    jump_error = datetime.utcnow().timestamp() - desired_time

                    # Correct overshoot more eagerly than being early
                    if jump_error > 0:
                        sched_latency += jump_error
                    else:
                        sched_latency += jump_error / 4

                    if sched_latency > self.tolerance:
                        print("sched_latency too high")
                        sched_latency = self.tolerance
                    elif sched_latency < 0:
                        # A negative value would make us sleep past the
                        # desired time on purpose.
                        sched_latency = 0.0

            # Let's hear the music.
            # NOTE: Detect pulseaudio failure and stop.
            pulse_writer.write(chunk)
            if cnt % 5 == 0:
                start = time()
                await pulse_writer.drain()
                took = time() - start
                if took > self.tolerance:
                    print("Output pipe seems slow!")

            # Main status line
            if cnt % 50 == 0:
                s = "STATE: queue=%-3d sched_latency=%5.2fms avg_delay=%6.2fms drops=%d" % (
                    len(self.chunk_list),
                    1000.0 * sched_latency,
                    1000.0 * total_delay/cnt,
                    drops
                )
                print(s)


def _parse_arguments():
    """
    Parse and validate program arguments.

    Returns the argparse namespace with `ip_list' converted to a list of
    (address, port) tuples. Exits through argparse on invalid input.
    """
    import argparse

    version = ".".join(str(part) for part in VERSION)
    description = "WaveSync %s - multi-room sound synchronization system"
    description = description % version
    p = argparse.ArgumentParser(description=description)
    snd = p.add_argument_group('sender options')
    rcv = p.add_argument_group('receiver options')
    opt = p.add_argument_group('common')
    act = p.add_argument_group('actions')

    act.add_argument("--tx",
                     metavar="INPUT",
                     action="store",
                     help="transmit sound from a given unix socket")

    act.add_argument("--rx",
                     metavar="OUTPUT",
                     action="store",
                     help="receive sound and send to a unix socket")

    snd.add_argument("--latency",
                     dest="latency_msec",
                     metavar="MSEC",
                     action="store",
                     default=1500,
                     type=int,
                     help="assumed system latency in milliseconds")

    snd.add_argument("--mtu",
                     metavar="MTU",
                     action="store",
                     type=int,
                     default=1500,
                     help="network MTU")

    snd.add_argument("--ttl",
                     metavar="TTL",
                     action="store",
                     type=int,
                     default=2,
                     help="multicast TTL")

    snd.add_argument("--sample-size",
                     metavar="BYTES",
                     action="store",
                     type=int,
                     default=4,
                     help="sample size in bytes (16-bit 2-channel = 4 bytes)")

    # type=int: zlib.compress() rejects a string level
    snd.add_argument("--compress",
                     metavar="LEVEL",
                     action="store",
                     type=int,
                     default=False,
                     help="compression level (1-9)")

    snd.add_argument("--no-loop",
                     dest="multicast_loop",
                     action="store_false",
                     default=True,
                     help="Do not loop multicast packets back to the sender")

    rcv.add_argument("--tolerance",
                     dest='tolerance_msec',
                     metavar="MSEC",
                     action="store",
                     type=int,
                     default=5,
                     help="error tolerance")

    rcv.add_argument("--sink-latency",
                     dest="sink_latency_msec",
                     metavar="MSEC",
                     action="store",
                     type=int,
                     default=30,
                     help="sink latency")

    opt.add_argument("--channel",
                     dest="ip_list",
                     metavar="ADDRESS:PORT",
                     action="append",
                     default=[],
                     help="multicast group or a unicast address, "
                          "may be given multiple times with --tx")

    opt.add_argument("--debug",
                     action="store_true",
                     help="enable debugging code")

    args = p.parse_args()

    if (args.tx is None) == (args.rx is None):
        p.error('Exactly one action: --tx or --rx must be specified')

    if args.sink_latency_msec > args.latency_msec:
        p.error("Sink latency cannot exceed system latency! Leave some margin too.")

    if args.compress is not False and not 1 <= args.compress <= 9:
        p.error('Compression level must be between 1 and 9')

    if not args.ip_list:
        args.ip_list.append('224.0.0.57:45300')

    if args.rx is not None and len(args.ip_list) > 1:
        p.error('Receiver must have only a single channel (IP)')

    # Parse IP addresses
    parsed_ip_list = []
    for arg in args.ip_list:
        tmp = arg.split(':')
        if len(tmp) != 2:
            p.error('TX/RX channel not in format IP_ADDRESS:PORT: ' + arg)
        address, port = tmp

        try:
            port = int(port)
        except ValueError:
            p.error('Port is not a number in channel: ' + arg)

        if not 0 < port < 65536:
            p.error('Port out of range (1-65535) in channel: ' + arg)

        parsed_ip_list.append((address, port))
    args.ip_list = parsed_ip_list

    return args


def start_tx(args, loop, time_machine):
    """
    Initialize sender: read samples from the `args.tx' Unix socket and send
    them to every channel in `args.ip_list'.

    Schedules the packetizer and the Unix socket connection on `loop'. If
    either task fails (e.g. the socket does not exist), the error is printed
    and the loop is stopped instead of idling forever.
    """
    sample_queue = asyncio.Queue()

    # Packet splitter / sender
    packetizer = Packetizer(args.latency_msec, time_machine,
                            sample_queue, compress=args.compress)

    packetizer.create_socket(args.ip_list,
                             args.ttl,
                             args.multicast_loop)

    # Sound sample reader
    data_size = packetizer.get_chunk_size(args.mtu, args.sample_size)
    sample_reader = SampleReader(sample_queue, data_size)

    connection = loop.create_unix_connection(lambda: sample_reader, args.tx)

    def stop_on_failure(task):
        "Report a failed task and stop the loop"
        if task.cancelled() or task.exception() is None:
            return
        print("Task failed:", task.exception())
        loop.stop()

    # Start loop. NOTE: `asyncio.async' is a syntax error on Python >= 3.7.
    for coro in (packetizer.packetize(), connection):
        loop.create_task(coro).add_done_callback(stop_on_failure)


def start_rx(args, loop, time_machine):
    """
    Initialize receiver: listen on the single channel in `args.ip_list' and
    pump received audio into the `args.rx' Unix socket.

    If the endpoint setup or the audio pump fails (e.g. the output socket
    does not exist), the error is printed and the loop is stopped.
    """
    # Network receiver with its connection
    channel = args.ip_list[0]
    receiver = Receiver(time_machine,
                        channel=channel,
                        tolerance=args.tolerance_msec / 1000.0,
                        sink_latency=args.sink_latency_msec / 1000.0)

    connection = loop.create_datagram_endpoint(lambda: receiver,
                                               family=socket.AF_INET,
                                               local_addr=channel)

    # Coroutine pumping audio into PA
    pump = receiver.pump_audio(pulse_socket=args.rx)

    def stop_on_failure(task):
        "Report a failed task and stop the loop"
        if task.cancelled() or task.exception() is None:
            return
        print("Task failed:", task.exception())
        loop.stop()

    # NOTE: `asyncio.async' is a syntax error on Python >= 3.7.
    for coro in (connection, pump):
        loop.create_task(coro).add_done_callback(stop_on_failure)


def main():
    "Parse arguments and start the event loop"
    args = _parse_arguments()

    # Create the loop explicitly: calling asyncio.get_event_loop() with no
    # running loop is deprecated in recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    if args.debug:
        loop.set_debug(True)

    time_machine = TimeMachine()

    if args.tx is not None:
        start_tx(args, loop, time_machine)
    elif args.rx is not None:
        start_rx(args, loop, time_machine)

    try:
        loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop; no traceback needed
        print('interrupted')
    finally:
        print('closing event loop')
        loop.close()


if __name__ == '__main__':
    # Run only when executed as a script, not when imported
    main()
