#!/usr/bin/env python3
# comp - Curses Omni Media Player
#
# comp is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# comp program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with comp.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright (C) 2017  Nguyễn Gia Phong <vn.mcsinyx@gmail.com>

import curses
import json
import re
from argparse import ArgumentParser
from collections import deque
from configparser import ConfigParser
from functools import reduce
from gettext import bindtextdomain, gettext as _, textdomain
from os import makedirs
from os.path import abspath, dirname, expanduser
from threading import Thread

from mpv import MPV
from pkg_resources import resource_filename
from youtube_dl import YoutubeDL

from omp import extract_info, Omp

# Init gettext: register the locale directory shipped inside the omp
# package for the 'omp' translation domain, then make 'omp' the global
# domain consulted by _().
bindtextdomain(domain='omp', localedir=resource_filename('omp', 'locale'))
textdomain(domain='omp')

# Global constants
SYSTEM_CONFIG = '/etc/comp/settings.ini'
USER_CONFIG = expanduser('~/.config/comp/settings.ini')
# Play modes, in the order the m/M keys cycle through them.
MODES = ("play-current", "play-all", "play-selected", "repeat-current",
         "repeat-all", "repeat-selected", "shuffle-all", "shuffle-selected")
# Width of the longest translated mode name (part of the minimum
# terminal width computed in Comp.resize).
MODE_STR_LEN = max(map(len, map(_, MODES)))
# The duration column fits its translated header and at least 8 columns,
# the width of an 'HH:MM:SS' timestamp.
DURATION_COL_LEN = max(8, len(_("Duration")))


def justified(s, width):
    """Left-align s in a field of the given width.

    For a non-negative width the result is exactly width characters:
    short strings are padded on the right with spaces and long strings
    are truncated.
    """
    padded = s + ' ' * width
    return padded[:width]


class Comp(Omp):
    """Meta object for drawing and playing.

    Screen layout: row 0 holds the column header, rows 1 .. LINES-3 list
    playlist entries, row LINES-2 is the status bar and row LINES-1 is
    used for messages and prompts.

    Attributes:
        entries (list): list of all tracks (dicts)
        json_file (str): path to save JSON playlist
        mode (str): the mode to pick and play tracks (see MODES)
        mp (MPV): an mpv instance
        play_backward (bool): flag show if to play the previous track
        play_list (list): list of tracks according to mode
        played (list): list of previously played tracks
        playing (int): index of playing track in played
        playlist (iterator): iterator of tracks according to mode
        reading (bool): flag show if user input is being read
        search_res (deque): entries matching the last search; the
            currently selected match is at index 0
        scr (curses WindowObject): curses window object
        start (int): index of the first track to be printed on screen
        vid (str): mpv video channel; 'no' disables video output
        y (int): screen row of the cursor (1 is the first entry row)
    """
    def __new__(cls, entries, json_file, mode, mpv_vid, mpv_vo, ytdlf):
        """Allocate the object and set up its state, mpv and curses.

        Every attribute is assigned here, i.e. before __init__ (and thus
        Omp.__init__) runs.  mpv_vo is not used at this stage.
        """
        self = object.__new__(cls)
        self.play_backward, self.reading = False, False
        self.playing, self.start, self.y = -1, 0, 1
        self.json_file, self.mode, self.vid = json_file, mode, mpv_vid
        self.entries, self.played = entries, []
        self.playlist, self.search_res = iter(()), deque()
        self.mp = MPV(input_default_bindings=True, input_vo_keyboard=True,
                      ytdl=True, ytdl_format=ytdlf)
        self.scr = curses.initscr()
        return self

    def adds(self, s, y, x=0, X=-1, attr=curses.A_NORMAL, lpad=1):
        """Paint the string s, added lpad spaces to the left, from
        (y, x) to (y, X) with attributes attr, overwriting anything
        previously on the display.

        Negative coordinates count from the bottom/right edge.  Nothing
        is painted while user input is being read.
        """
        if self.reading: return
        y %= curses.LINES
        x %= curses.COLS
        # On the last screen row stop one column short: curses raises an
        # error when the bottom-right cell is written.
        length = X % curses.COLS - x + (y != curses.LINES - 1)
        self.scr.addstr(y, x, (' '*lpad + s).ljust(length)[:length], attr)

    def update_status(self, message='', msgattr=curses.A_NORMAL):
        """Update the status lines at the bottom of the screen.

        Row LINES-2 shows position, duration, pause state, media title,
        play mode and the audio (A) / video (V) indicators.  If message
        is given it is written to row LINES-1 with attributes msgattr.
        """
        def add_status_str(s, x=0, X=-1, attr=curses.color_pair(12), lpad=1):
            self.adds(s, curses.LINES - 2, x=x, X=X, attr=attr, lpad=lpad)

        # Record the duration mpv reports on the track being played.
        if self.mp.osd.duration is not None:
            self.played[self.playing]['duration'] = self.mp.osd.duration
        add_status_str(self.mp.osd.time_pos or '00:00:00', X=8)
        add_status_str('/', x=9, X=10)
        add_status_str(self.mp.osd.duration or '00:00:00', x=11, X=19)
        add_status_str('|' if self.mp.pause else '>', x=20, X=21)
        add_status_str((self.mp.media_title or b'').decode(), x=22,
                       attr=curses.color_pair(12)|curses.A_BOLD)
        add_status_str(_(self.mode), x=-5-len(_(self.mode)))
        if not self.mp.mute: add_status_str('A', x=-4, X=-3)
        if self.vid != 'no': add_status_str('V', x=-2, lpad=0)
        if message: self.adds(message, curses.LINES-1, attr=msgattr, lpad=0)
        self.scr.refresh()

    def setno(self, *keys):
        """Set all keys of each entry in entries to False."""
        for entry in self.entries:
            for key in keys:
                entry[key] = False

    def play(self, force=False):
        """Play the next track in a daemon thread.

        The previous track in played is chosen when play_backward is set,
        otherwise the next already-played one (if we stepped back) or a
        new one taken from playlist.  Nothing happens when playlist is
        exhausted.
        """
        def mpv_play(entry, force):
            self.setno('playing')
            entry['playing'] = True
            self.mp.vid = self.vid
            try:
                self.mp.play(entry['filename'])
            except Exception:
                entry['error'] = True
            self.print(entry)
            if force: self.mp.pause = False
            self.mp.wait_for_playback()
            self.play()
            entry['playing'] = False
            self.print(entry)

        if self.play_backward and -self.playing < len(self.played):
            self.playing -= 1
            t = self.played[self.playing], force
        elif self.playing < -1:
            self.playing += 1
            t = self.played[self.playing], force
        else:
            try:
                self.played.append(next(self.playlist))
            except StopIteration:
                return
            else:
                t = self.played[-1], force

        self.play_backward = False
        play_thread = Thread(target=mpv_play, args=t, daemon=True)
        play_thread.start()

    def _writeln(self, y, title, duration, attr):
        """Write a title/duration row at screen row y with attributes attr.

        The title is painted across the whole width first; the duration
        then overwrites the right-hand column.
        """
        title_len = curses.COLS - DURATION_COL_LEN - 3
        self.adds(title, y, attr=attr)
        self.adds(duration, y, x=title_len+1, attr=attr)
        self.scr.refresh()

    def print(self, entry=None, y=None):
        """Print the entry in the line y.

        With neither argument the entry under the cursor is printed; if
        only one is given the other is derived from it.  Nothing is drawn
        for an empty entry or a row outside the list area.
        """
        if entry is y is None:
            entry = self.current()
            y = self.idx() - self.start + 1
        elif entry is None:
            entry = self.entries[self.start + y - 1]
        elif y is None:
            y = self.idx(entry) - self.start + 1
        # current() yields {} when the cursor is outside the list.
        if not entry or y < 1 or y > curses.LINES - 3: return

        # Color pair number: XOR of the flags set on the entry, plus 8 for
        # the entry under the cursor (pairs are initialised in __init__).
        c = {'error': 1, 'playing': 3, 'selected': 5}
        color = ((8 if entry is self.current() else 0)
            | reduce(int.__xor__, (c.get(i, 0) for i in entry if entry[i]),
                     0))
        if color:
            self._writeln(y, entry['title'], entry['duration'],
                          curses.color_pair(color) | curses.A_BOLD)
        else:
            self._writeln(y, entry['title'], entry['duration'],
                          curses.A_NORMAL)

    def redraw(self):
        """Redraw the whole screen."""
        self._writeln(0, _("Title"), _("Duration"),
                      curses.color_pair(10) | curses.A_BOLD)
        for i, entry in enumerate(self.entries[self.start:][:curses.LINES-3]):
            self.print(entry, i + 1)
        self.scr.clrtobot()
        self.update_status()

    def __init__(self, entries, json_file, mode, mpv_vid, mpv_vo, ytdlf):
        """Initialise Omp, configure the terminal and draw the screen.

        Omp.__init__ receives a callback (name, val) that refreshes the
        status bar; presumably Omp invokes it on mpv property changes.
        """
        Omp.__init__(self, entries, lambda name, val: self.update_status(),
                     json_file, mode, mpv_vid, mpv_vo, ytdlf)
        curses.noecho()
        curses.cbreak()
        self.scr.keypad(True)
        curses.curs_set(False)
        curses.start_color()
        curses.use_default_colors()
        # Pairs 1-7: foreground color i on the default background.
        for i in range(1, 8): curses.init_pair(i, i, -1)
        # Pairs 8-15: default foreground on a colored background, used for
        # the entry under the cursor (8 | flags), the header and the status
        # bar.  print() can request up to 8 | 1^3^5 == 15, so pair 15 must
        # be initialised as well.
        curses.init_pair(8, -1, 7)
        for i in range(1, 8): curses.init_pair(i + 8, -1, i)
        self.redraw()

    def __enter__(self):
        """Return self; the terminal is restored by __exit__."""
        return self

    def idx(self, entry=None):
        """Return the index of entry in entries (by equality), or the
        index of the entry under the cursor when entry is None.

        Raises:
            ValueError: if entry is given but not in entries.
        """
        if entry is None:
            return self.start + self.y - 1
        return self.entries.index(entry)

    def current(self):
        """Return the current entry, or {} if the cursor is off the list."""
        try:
            return self.entries[self.idx()]
        except IndexError:
            return {}

    def gets(self, prompt):
        """Print the prompt string at the bottom of the screen then read
        from standard input.

        The terminal is put back into no-echo, hidden-cursor mode even if
        reading fails.  Undecodable input bytes are replaced rather than
        raising.
        """
        self.adds(prompt, curses.LINES - 1, lpad=0)
        self.reading = True
        curses.curs_set(True)
        curses.echo()
        try:
            b = self.scr.getstr(curses.LINES - 1, len(prompt))
        finally:
            self.reading = False
            curses.curs_set(False)
            curses.noecho()
        return b.decode(errors='replace')

    def seek(self, amount, reference='relative', precision='default-precise'):
        """Wrap mp.seek with a try clause to avoid crash when nothing is
        being played.
        """
        try:
            self.mp.seek(amount, reference, precision)
        except Exception:
            pass

    def move(self, delta):
        """Move the cursor delta entries down (up if negative).

        The view scrolls when the cursor would leave the screen; only the
        two affected rows are repainted if it does not scroll.
        """
        if not (self.entries and delta): return
        start, prev_entry = self.start, self.current()
        maxy = min(len(self.entries), curses.LINES - 3)

        if self.idx() + delta <= 0:                     # to the first entry
            self.start, self.y = 0, 1
        elif self.idx() + delta >= len(self.entries):   # to the last entry
            self.start, self.y = len(self.entries) - maxy, maxy
        elif self.y + delta < 1:                        # scroll up
            self.start += self.y + delta - 1
            self.y = 1
        elif self.y + delta > curses.LINES - 3:         # scroll down
            self.start += self.y + delta - maxy
            self.y = maxy
        else:                                           # same page
            self.y += delta

        if self.start == start:
            self.print(prev_entry)
            self.print()
        else:
            self.redraw()

    def search(self, backward=False):
        """Prompt for a regular expression and jump to the first entry
        whose title matches it (case-insensitively), scanning from the
        current entry forward, or backward when backward is true.
        """
        pattern = self.gets('/')
        try:
            p = re.compile(pattern, re.IGNORECASE)
        except re.error:
            # A mistyped pattern must not bring the whole player down.
            self.update_status(_("'{}': Invalid pattern").format(pattern),
                               curses.color_pair(1))
            return
        entries = deque(self.entries)
        # Rotate so that the scan starts at the current entry.
        entries.rotate(-self.idx())
        self.search_res = deque(filter(
            lambda entry: p.search(entry['title']) is not None, entries))
        if backward: self.search_res.reverse()
        if self.search_res:
            self.move(self.idx(self.search_res[0]) - self.idx())
        else:
            self.update_status(_("Pattern not found"), curses.color_pair(1))

    def next_search(self, backward=False):
        """Repeat previous search.

        Matches that are no longer in entries (deleted, or replaced by
        opening another playlist) are dropped instead of crashing.
        """
        if self.search_res:
            self.search_res.rotate(1 if backward else -1)
        while self.search_res:
            try:
                target = self.idx(self.search_res[0])
            except ValueError:
                self.search_res.popleft()
                # Forward, the next candidate is now at the head already;
                # backward, it is at the tail and must be rotated in.
                if backward: self.search_res.rotate(1)
                continue
            self.move(target - self.idx())
            return
        self.update_status(_("Pattern not found"), curses.color_pair(1))

    def resize(self):
        """Adapt the view to the current terminal size and redraw."""
        curses.update_lines_cols()
        self.scr.clear()
        rows = curses.LINES - 3     # rows available for entries
        if curses.COLS < MODE_STR_LEN + 42 or rows < 1:    # too small
            sizeerr = _("Current size: {}x{}. Minimum size: {}x4.").format(
                curses.COLS, curses.LINES, MODE_STR_LEN + 42)
            self.scr.addstr(0, 0, sizeerr[:curses.LINES*curses.COLS-1])
            self.scr.refresh()
        elif self.y > rows:    # shorter than the current entry
            self.start += self.y - rows
            self.y = rows
            self.redraw()
        elif 0 < self.start > len(self.entries) - rows:    # longer than list
            # Scroll up so the list end sits on the last row, but never
            # above the first entry.
            idx, self.start = self.idx(), max(0, len(self.entries) - rows)
            self.y = idx - self.start + 1
            if self.y > rows:
                self.start += self.y - rows
                self.y = rows
            self.redraw()
        else:
            self.redraw()

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the terminal, then let Omp clean up."""
        curses.nocbreak()
        self.scr.keypad(False)
        curses.echo()
        curses.endwin()
        Omp.__exit__(self, exc_type, exc_value, traceback)


parser = ArgumentParser(description=_("Curses Omni Media Player"))
parser.add_argument('-e', '--extractor', default='youtube-dl',
                    choices=('json', 'mpv', 'youtube-dl'), required=False,
                    help=_("playlist extractor, default is youtube-dl"))
parser.add_argument('file', help=_("path or URL to the playlist to be opened"))
parser.add_argument('-c', '--config', default=USER_CONFIG, required=False,
                    help=_("path to the configuration file"))
parser.add_argument('--vid', required=False,
                    help=_("initial video channel. auto selects the default,\
                            no disables video"))
parser.add_argument('--vo', required=False, metavar='DRIVER',
                    help=_("specify the video output backend to be used. See\
                            VIDEO OUTPUT DRIVERS in mpv(1) man page for\
                            details and descriptions of available drivers"))
parser.add_argument('-f', '--format', required=False, metavar='YTDL_FORMAT',
                    help=_("video format/quality to be passed to youtube-dl"))
args = parser.parse_args()
entries = extract_info(args.file, args.extractor)
if entries is None:
    # Report on stderr and exit with a failure status (the site-provided
    # exit() builtin is not guaranteed to exist and exits with status 0).
    parser.exit(1, _("'{}': Can't extract playlist").format(args.file) + '\n')
json_file = args.file if args.extractor == 'json' else ''
# System-wide settings are read first so the user's file overrides them;
# files that do not exist are silently skipped by ConfigParser.read.
config = ConfigParser()
config.read((SYSTEM_CONFIG, args.config))
vid = args.vid or config.get('mpv', 'video', fallback='auto')
vo = args.vo or config.get('mpv', 'video-output', fallback=None)
mode = config.get('comp', 'play-mode', fallback='play-current')
# An unknown mode would make the m/M handlers crash in MODES.index.
if mode not in MODES: mode = 'play-current'
ytdlf = args.format or config.get('youtube-dl', 'format', fallback='best')

with Comp(entries, json_file, mode, vid, vo, ytdlf) as comp:
    # Main event loop: dispatch on the key code returned by getch until
    # the user presses q.
    c = comp.scr.getch()
    while c != 113:     # letter q
        if c == 10:     # curses.KEY_ENTER doesn't work
            comp.update_playlist()
            comp.next(force=True)
        elif c == 32:   # space
            comp.current()['selected'] = not comp.current().get('selected')
            comp.move(1)
        elif c == 47:   # /
            comp.search()
        elif c == 60:   # <
            # Restart the track, or go to the previous one if it has only
            # just started; silently ignored when nothing is playing.
            try:
                if comp.mp._get_property('time-pos', float) < 1:
                    comp.next(backward=True)
                else:
                    comp.seek(0, 'absolute')
            except Exception:
                pass
        elif c == 62:   # >
            comp.next()
        elif c == 63:   # ?
            comp.search(backward=True)
        elif c == 65:   # letter A
            comp.mp.mute ^= True    # hack to toggle bool value
        elif c == 77:   # letter M
            comp.mode = MODES[(MODES.index(comp.mode) - 1) % len(MODES)]
            comp.update_status()
        elif c == 78:  # letter N
            comp.next_search(backward=True)
        elif c == 86:   # letter V
            comp.vid = 'auto' if comp.vid == 'no' else 'no'
            comp.mp.vid = comp.vid
            comp.update_status()
        elif c == 87:   # letter W
            s = comp.gets(_("Save playlist to [{}]: ").format(comp.json_file))
            if s: comp.json_file = s
            try:
                makedirs(dirname(abspath(comp.json_file)), exist_ok=True)
                # ensure_ascii=False writes titles verbatim, so use UTF-8
                # (the JSON encoding) rather than the locale's encoding.
                with open(comp.json_file, 'w', encoding='utf-8') as f:
                    json.dump(comp.entries, f, ensure_ascii=False,
                              indent=2, sort_keys=True)
            except Exception:
                errmsg = _("'{}': Can't open file for writing").format(
                    comp.json_file)
                comp.update_status(errmsg, curses.color_pair(1))
            else:
                comp.update_status(_("'{}' written").format(comp.json_file))
        elif c == 100:  # letter d
            if comp.entries:
                comp.entries.pop(comp.idx())
                # If the last entry row is now empty and the view can
                # scroll up, do so; otherwise keep the cursor on the list.
                if 0 < len(comp.entries) - curses.LINES + 4 == comp.start:
                    comp.start -= 1
                elif comp.idx() == len(comp.entries):
                    comp.y -= 1
                comp.redraw()
        elif c == 105:   # letter i
            extractor = comp.gets(_("Playlist extractor: "))
            filename = comp.gets(_("Insert: "))
            entries = extract_info(filename, extractor)
            if entries is None:
                comp.update_status(
                    _("'{}': Can't extract playlist").format(filename))
            else:
                bottom = comp.entries[comp.idx():]
                comp.entries = comp.entries[:comp.idx()]
                comp.entries.extend(entries)
                comp.entries.extend(bottom)
                comp.redraw()
        elif c == 109:  # letter m
            comp.mode = MODES[(MODES.index(comp.mode) + 1) % len(MODES)]
            comp.update_status()
        elif c == 110:  # letter n
            comp.next_search()
        elif c == 111:   # letter o
            extractor = comp.gets(_("Playlist extractor: "))
            filename = comp.gets(_("Open: "))
            entries = extract_info(filename, extractor)
            if entries is None:
                comp.update_status(
                    _("'{}': Can't extract playlist").format(filename))
            else:
                comp.entries, comp.start, comp.y = entries, 0, 1
                comp.redraw()
        elif c == 112:  # letter p
            comp.mp.pause ^= True
        elif c in (curses.KEY_UP, 107):     # up arrow or letter k
            comp.move(-1)
        elif c in (curses.KEY_DOWN, 106):   # down arrow or letter j
            comp.move(1)
        elif c in (curses.KEY_LEFT, 104):   # left arrow or letter h
            comp.seek(-5, precision='exact')
        elif c in (curses.KEY_RIGHT, 108):  # right arrow or letter l
            comp.seek(5, precision='exact')
        elif c == curses.KEY_HOME:  # home
            comp.move(-len(comp.entries))
        elif c == curses.KEY_END:   # end
            comp.move(len(comp.entries))
        elif c == curses.KEY_NPAGE:     # page down
            comp.move(curses.LINES - 4)
        elif c == curses.KEY_PPAGE:     # page up
            comp.move(4 - curses.LINES)
        elif c in (curses.KEY_F5, curses.KEY_RESIZE):
            comp.resize()
        c = comp.scr.getch()
