#!/usr/bin/env python3
#
# Display local desktop on remote host over LAN without VGA/HDMI cable.
# Copyright (c) 2019, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import os
import socket
import struct
import sys
import zlib

from Xlib import X, display
from perlcompat import die, warn, getopts
import pygame
import pygame.gfxdraw
import pygame.transform

# Verbose flag; main() switches it on for -v and debug() consults it.
DEBUG = False

# Wire format: each frame starts with a fixed-size header, and the whole
# message is cut into datagrams of at most MAX_SEGMENT_SIZE bytes.
HEADER_SIZE = 48
MAX_SEGMENT_SIZE = 50000

# Receiver-side drawing parameters (colors are 4-component tuples).
BLANK_COLOR = (0, 0, 255, 255)  # window fill while no frame is available
POINTER_COLOR = (255, 255, 0, 128)  # virtual mouse pointer
POINTER_SIZE = 10  # both radii of the pointer ellipse
FONT_NAME = '/usr/share/fonts/truetype/freefont/FreeSans.ttf'
FONT_SIZE = 20  # also used to derive the label margin
LABEL_COLOR = (255, 255, 255, 128)  # sender label text

def usage():
    """Pass the command-line synopsis to die().

    The option letters listed here mirror the option string that main()
    hands to getopts(): width and height are the lowercase -w and -h.
    """
    die("""\
usage: {} [-vr] [-s host] [-p port] [-w width] [-h height]
  -v       verbose mode
  -r       receiver mode
  -s host  specify receiver hostname/address
  -p port  port number
  -w #     screen width (default: 800)
  -h #     screen height (default: 600)
""".format(sys.argv[0]))

class FrameError(Exception):
    """Raised when a frame header is invalid or its payload cannot be
    decompressed."""

def debug(msg):
    """Hand MSG to warn() when the module-level DEBUG flag is set; do
    nothing otherwise."""
    if not DEBUG:
        return
    warn(msg)

def pointer_position():
    """Obtain the current position of the pointer (i.e., mouse) on X11 screen.
    This code is contributed by Hal.

    Return a tuple (x, y) taken from the root_x/root_y fields of the
    pointer query on the root window.
    """
    dpy = display.Display()
    try:
        stat = dpy.screen().root.query_pointer()
        return stat.root_x, stat.root_y
    finally:
        # This function runs once per captured frame; without an explicit
        # close every call would leave another X connection open.
        dpy.close()

def dump_root_window(width, height):
    """Return the rectangle of the root window of WIDTH x HEIGHT in RGBX format.

    The image is fetched from the origin (0, 0) of the root window as a
    ZPixmap and returned as a bytearray.  The conversion assumes four bytes
    per pixel in BGRX order (NOTE: this holds for the usual 24/32-bit
    visuals; confirm for other depths).
    """
    dpy = display.Display()
    try:
        img = dpy.screen().root.get_image(0, 0, width, height, X.ZPixmap,
                                          0xffffffff)
        # copy the pixel data before the connection is released
        buf = bytearray(img.data)
    finally:
        # Called once per frame; close the connection instead of leaking it.
        dpy.close()
    # format conversion from BGRX to RGBX: swap the red and blue bytes of
    # every pixel with one slice exchange instead of a per-pixel Python loop
    nbytes = width * height * 4
    buf[0:nbytes:4], buf[2:nbytes:4] = buf[2:nbytes:4], buf[0:nbytes:4]
    return buf

class Frame:
    """A single screen frame: geometry, pointer position, label and image.

    On the sender side a frame is filled by capture_image() and serialized
    by message().  On the receiver side parse_header() and store_image()
    rebuild it from the received bytes.
    """

    # Explicit little-endian layout with standard sizes and explicit padding.
    # It is byte-for-byte what the former native format produced on
    # little-endian hosts, but no longer varies with the host architecture.
    _HEADER = struct.Struct('<4s B3x I HH HH 28s')
    _LABEL_SIZE = 28

    def __init__(self, width=800, height=600, label='NO LABEL'):
        self.width = width
        self.height = height
        self.label = label
        self.type_ = 0  # payload type; overwritten by parse_header()
        self.pnt_x = None
        self.pnt_y = None
        self.size = None
        self.img = None

    def header(self, size):
        """Compose a 48-byte frame header.  The payload size must be specified
        by SIZE.

        Multi-byte fields are little-endian.  The label is UTF-8, padded
        with NUL bytes, and cut at a character boundary when it does not
        fit.  The pointer position (attributes PNT_X and PNT_Y) must have
        been set, e.g., by capture_image().

        sendscreen frame header format:
        name        size    description
        identifier  4       'SSCR'
        type        1       payload type
        reserved    3
        size        4       payload size (int)
        width       2       frame width (short)
        height      2       frame height (short)
        pnt_x       2       pointer x (short)
        pnt_y       2       pointer y (short)
        label       28
        """
        # struct would silently cut the label after 28 bytes, possibly in the
        # middle of a multi-byte UTF-8 sequence, which the receiver cannot
        # decode.  Drop such a trailing partial character beforehand.
        raw = self.label.encode('utf-8')[:self._LABEL_SIZE]
        label = raw.decode('utf-8', errors='ignore').encode('utf-8')
        return self._HEADER.pack(b'SSCR', 0, size, self.width, self.height,
                                 self.pnt_x, self.pnt_y, label)

    def parse_header(self, buf):
        """Parse byte string BUF and store the header values in object
        attributes.  Return True on success.

        Raise FrameError if BUF does not have the size of a header or does
        not start with the identifier 'SSCR'."""
        try:
            fields = self._HEADER.unpack(buf)
        except struct.error as err:
            # wrong length: report it like any other malformed header
            raise FrameError('invalid header.') from err
        ident, type_, size, width, height, pnt_x, pnt_y, label = fields
        if ident != b'SSCR':
            raise FrameError('invalid header.')
        self.type_ = type_
        self.size = size
        self.width = width
        self.height = height
        self.pnt_x = pnt_x
        self.pnt_y = pnt_y
        # strip the NUL padding; undecodable bytes from a foreign sender are
        # replaced rather than aborting the whole frame
        self.label = label.rstrip(b'\x00').decode('utf-8', errors='replace')
        return True

    def capture_image(self):
        """Capture the current screen and store the position of the current
        mouse pointer in object attributes."""
        self.img = dump_root_window(self.width, self.height)
        self.pnt_x, self.pnt_y = pointer_position()

    def store_image(self, zimg):
        """Decode zlib-compressed image data ZIMG and record the size of the
        (compressed) image in attribute SIZE.

        Raise FrameError if ZIMG cannot be decompressed."""
        try:
            self.img = zlib.decompress(zimg)
        except zlib.error as err:
            raise FrameError('decompression failed.') from err
        self.size = len(zimg)

    def message(self):
        """Return the message composed of the header and the payload
        (Zlib-compressed frame image)."""
        zimg = zlib.compress(self.img)
        header = self.header(len(zimg))
        return header + zimg

class Sender:
    """Capture the local desktop and stream it to a receiver over UDP."""

    def __init__(self, server='localhost', port=5000, width=800, height=600):
        self.server = server
        self.port = port
        self.width = width
        self.height = height

    def compose_label(self):
        """Generate a label displayed on the screen.

        The label has the form 'user @ hostname'.  The user name is taken
        from the environment variable USER, then LOGNAME, and falls back to
        'unknown' so that the label never reads 'None @ hostname'."""
        user = os.getenv('USER') or os.getenv('LOGNAME') or 'unknown'
        hostname = socket.gethostname()
        return '{} @ {}'.format(user, hostname)

    def mainloop(self):
        """Repeatedly send the rectangle of the current desktop to remote host
        using UDP.

        Each message (header plus compressed image) is cut into datagrams of
        at most MAX_SEGMENT_SIZE bytes.  The loop only ends through an
        exception (e.g., KeyboardInterrupt); the socket is closed then."""
        label = self.compose_label()
        addr = self.server, self.port
        self.sk = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
        try:
            while True:
                frame = Frame(width=self.width,
                              height=self.height,
                              label=label)
                frame.capture_image()
                msg = frame.message()
                debug('sending message ({:,} bytes) to {}:{}...'.format(
                    len(msg), *addr))
                for i in range(0, len(msg), MAX_SEGMENT_SIZE):
                    self.sk.sendto(msg[i:i + MAX_SEGMENT_SIZE], addr)
        finally:
            self.sk.close()

class Receiver:
    """Receive frames over UDP and render them in a pygame window."""

    def __init__(self, port=5000, width=800, height=600):
        self.port = port
        self.width = width
        self.height = height
        self.ndrawn = 0  # number of frames drawn so far (drives the blink)

    def init_display(self):
        """Initialize display settings and create a window for display."""
        debug('initializing display...')
        # disable screen saver and display power management
        # FIXME: restore power management settings
        os.system('xset s off')
        os.system('xset -dpms')

        # initialize pygame and create a window
        pygame.display.init()
        pygame.font.init()
        self.screen = pygame.display.set_mode((self.width, self.height))
        self.font = pygame.font.Font(FONT_NAME, FONT_SIZE)

    def recv(self):
        """Receive an incoming message and decode the message.  Return the
        frame decoded, or False if no datagram arrives within one second.

        Datagrams that arrive before any header segment, malformed headers
        and payloads that fail to decompress are discarded; reception then
        continues with the next header segment."""
        frame = Frame()
        zimg = None  # stays None until a header segment has been accepted
        addr = None, None
        self.sk.settimeout(1.)
        while True:
            try:
                debug('waiting incoming UDP packet...')
                buf, addr = self.sk.recvfrom(MAX_SEGMENT_SIZE)
            except socket.timeout:
                return False
            # The sender puts the header at the very beginning of the first
            # segment; matching the identifier anywhere in the datagram
            # would misread compressed data that happens to contain it.
            if buf.startswith(b'SSCR'):
                try:
                    frame.parse_header(buf[:HEADER_SIZE])
                except (FrameError, struct.error, UnicodeDecodeError) as err:
                    debug('discarding malformed header ({}).'.format(err))
                    frame, zimg = Frame(), None
                    continue
                zimg = buf[HEADER_SIZE:]
            elif zimg is not None:
                zimg += buf
            else:
                # continuation segment of a frame whose header was missed
                continue
            if len(zimg) < frame.size:
                continue
            debug('message received ({:,} bytes) from {}:{}...'.format(
                len(zimg), *addr))
            try:
                frame.store_image(zimg)
            except FrameError as err:
                # e.g., segments lost or reordered in transit
                debug('discarding corrupted frame ({}).'.format(err))
                frame, zimg = Frame(), None
                continue
            return frame

    def draw_img(self, screen, img, width, height):
        """Draw an image IMG with width WIDTH and height HEIGHT at the origin
        (0, 0) on screen SCREEN.  An image that pygame rejects with
        ValueError is skipped."""
        try:
            # convert RGBX image to pygame.Surface object
            img = pygame.image.fromstring(img, (width, height), 'RGBX')
        except ValueError as err:
            debug('cannot convert image ({}).'.format(err))
            return
        screen.blit(img, (0, 0))

    def draw_pointer(self, screen, x, y):
        """Draw a virtual pointer at (X, Y) on screen SCREEN."""
        pygame.gfxdraw.filled_ellipse(screen, x, y, POINTER_SIZE, POINTER_SIZE,
                                      POINTER_COLOR)

    def draw_label(self, screen, font, label):
        """Draw a text LABEL using font FONT around the upper-right corner of
        the screen SCREEN."""
        text = font.render(label, 1, LABEL_COLOR)
        x = screen.get_width() - text.get_width() - FONT_SIZE // 2
        y = FONT_SIZE // 2
        screen.blit(text, (x, y))

    def draw(self, frame):
        """Draw frame FRAME at the origin (0, 0) of the current window."""
        self.screen.fill((0, 0, 0))
        self.draw_img(self.screen, frame.img, frame.width, frame.height)
        # blink the pointer as keep-alive indicator
        if self.ndrawn % 2 == 0:
            self.draw_pointer(self.screen, frame.pnt_x, frame.pnt_y)
        self.draw_label(self.screen, self.font, frame.label)
        pygame.display.update()
        self.ndrawn += 1

    def mainloop(self):
        """Repeatedly receive frames from a client, and display the frame in
        the window.  Return when the window is closed."""
        self.init_display()
        self.sk = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self.sk.bind(('', self.port))  # all available interfaces
            while True:
                # Drain the event queue on every pass: it keeps the window
                # responsive and lets the user close it.
                events = pygame.event.get()
                if any(event.type == pygame.QUIT for event in events):
                    break
                frame = self.recv()
                if frame and frame.img:
                    self.draw(frame)
                else:
                    self.screen.fill(BLANK_COLOR)
                    pygame.display.update()
        finally:
            self.sk.close()

def main():
    """Parse the command-line options and run as receiver (-r) or, by
    default, as sender."""
    global DEBUG
    opt = getopts('vrs:p:w:h:') or usage()
    if opt.v:
        DEBUG = True
    # options that were not given fall back to their defaults
    server = opt.s or 'localhost'
    port = int(opt.p or 5000)
    width = int(opt.w or 800)
    height = int(opt.h or 600)
    if not opt.r:
        Sender(server, port, width, height).mainloop()
        return
    Receiver(port, width, height).mainloop()

# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
