#!/usr/bin/env python3
# -*- coding: utf-8 -*-


# Copyright (c) 2018 Wolfgang Rohdewald <wolfgang@rohdewald.de>
# See LICENSE for details.

"""implement a server using the mapmytracks protocol.

There is one notable difference:

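https://github.com/MapMyTracks/api/blob/master/services/stop_activity.md
says stop_activity has no parameter activity_id. Oruxmaps sends it anyway,
and our server uses it to find the right tracker if it is given. Without it,
the tracker is looked up by the sender's IP address; if that fails, all
active trackers are stopped. Maybe the MMT API definition is wrong.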
See https://github.com/MapMyTracks/api/issues/25

"""

# PYTHON_ARGCOMPLETE_OK
# for command line argument completion, put this into your .bashrc:
# eval "$(register-python-argcomplete gpxity_server)"
# or see https://argcomplete.readthedocs.io/en/latest/


import argparse
import base64
import datetime
import hmac
import logging
import logging.handlers
import os
import sys
import traceback

from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs
from xml.sax.saxutils import escape as xml_escape

from gpxpy import gpx as mod_gpx

# Short module-level aliases for the gpxpy classes.
GPX = mod_gpx.GPX
GPXTrack = mod_gpx.GPXTrack
GPXTrackSegment = mod_gpx.GPXTrackSegment
GPXTrackPoint = mod_gpx.GPXTrackPoint
GPXXMLSyntaxException = mod_gpx.GPXXMLSyntaxException

# Prefer a development checkout over the installed gpxity: if the parent of
# this script's directory (sys.path[0], or sys.path[1] when sys.path[0] is
# empty) contains gpxity/__init__.py, put that parent first on sys.path so
# the gpxity import below finds it.
_ = os.path.dirname(sys.path[0] or sys.path[1])
if os.path.exists(os.path.join(_, 'gpxity', '__init__.py')):
    sys.path.insert(0, _)
# pylint: disable=wrong-import-position

from gpxity import Track, MMT, Authenticate, Backend, Lifetrack  # noqa pylint: disable=no-name-in-module

try:
    import argcomplete
    # pylint: disable=unused-import
    from argcomplete import ChoicesCompleter  # noqa
except ImportError:
    pass


class TrackingMessage:

    """Live tracking: represent a received message.

    Args:
        command: the HTTP command, like POST
        from_ip: the sender
        parsed: the parsed POST data (a dict), or None if there was no body
        result: the POST answer
        response: None or tuple(HTTP status code, message) of an error answer

    """

    def __init__(self, command, from_ip, parsed, result, response):
        self.time = datetime.datetime.now()
        self.command = command
        self.from_ip = from_ip
        self.parsed = parsed
        self.result = result
        self.response = response

    def log(self, prefix=''):
        """Log the message: as error if there was an error response, else as info."""
        if self.response:
            logging.error('E %s%s', prefix, self)
        else:
            logging.info('I %s%s', prefix, self)

    def __str__(self):
        response = '{} {}'.format(*self.response) if self.response else ''
        # parseRequest() returns None for requests without a body
        parsed = self.parsed or dict()
        point_msg = ''
        if 'points' in parsed:
            try:
                points = MMTHandler.parse_points(parsed['points'])
            except Exception:  # pylint: disable=broad-except
                # parse_points raises a plain Exception for malformed data;
                # a malformed request must still be loggable.
                point_msg = ' with unparseable points'
            else:
                if points:
                    point_msg = ' with {} points: {}'.format(
                        len(points), ','.join(str(round(x.elevation)) for x in points))
        return '{} {} from {} id={} request={} {} --> {}  {}'.format(
            self.time, self.command, self.from_ip, parsed.get('activity_id'), parsed.get('request'),
            point_msg, self.result, response).strip()


class MMTHandler(BaseHTTPRequestHandler):

    """Handles all HTTP requests of the mapmytracks protocol.

    POST requests are dispatched to the method ``xml_<request>`` where
    ``<request>`` is the value of the POST field ``request``.

    Attributes:
        users: dict user name -> password, read from the file .users in the
            first target. None until loaded.
        login_user: the user who passed basic authentication, or None.
        error_response: None or tuple(HTTP status code, reason) of the error
            answer already sent for the current request. Callers check it so
            they never send a second answer.

    """

    users = None
    login_user = None
    error_response = None

    # Sent with successful answers and with 401 answers.
    _AUTH_HEADER = ('WWW-Authenticate', 'Basic realm="MMTracks API"')

    def log_info(self, format, *args):  # pylint: disable=redefined-builtin
        """Override: Redirect into logger."""
        self.server.logger.info(format % args)

    def log_error(self, format, *args):  # pylint: disable=redefined-builtin
        """Override: redirect into logger."""
        self.server.logger.error(format % args)

    def check_basic_auth_pw(self):
        """Check HTTP basic authentication against the users from .users.

        On success, the user name is stored in self.login_user. Otherwise a 401
        answer with a WWW-Authenticate header is sent and self.error_response
        is set; the caller must check that.
        """
        if 'Authorization' not in self.headers:
            self.return_error(401, 'Authorization required', extra_headers=[self._AUTH_HEADER])
            return

        if self.users is None:
            self.load_users()
        given = self.headers['Authorization'].encode('utf-8')
        for user, password in self.users.items():
            expect = b'Basic ' + base64.b64encode('{}:{}'.format(user, password).encode('utf-8'))
            # constant time comparison: do not leak password prefixes by timing
            if hmac.compare_digest(expect, given):
                self.login_user = user
                return
        self.return_error(401, 'Authorization failed', extra_headers=[self._AUTH_HEADER])

    def load_users(self):
        """Load legal user auth from serverdirectory/.users.

        Every non-empty line has the form user:password. Only the first ':'
        separates, so passwords may contain ':'.
        """
        self.users = dict()
        with open(os.path.join(self.server.targets[0].url, '.users'), encoding='utf-8') as user_file:
            for line in user_file:
                line = line.strip()
                if not line:
                    continue
                user, password = line.split(':', 1)
                self.users[user] = password

    def _send_xml(self, code, reason, xml, extra_headers=()):
        """Send a complete HTTP answer with an XML body.

        Args:
            code: HTTP status code
            reason: text for the HTTP status line
            xml: the body as str. Content-Length is the length of its UTF-8
                encoding, not of the str.
            extra_headers: (keyword, value) pairs sent before Content-Type
        """
        body = xml.encode('utf-8')
        self.send_response(code, reason)
        for keyword, value in extra_headers:
            self.send_header(keyword, value)
        self.send_header('Content-Type', 'text/xml; charset=UTF-8')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def return_error(self, code, reason, extra_headers=()):
        """Answer the client with an xml formatted error message.

        The error is remembered in self.error_response.

        Args:
            code: HTTP status code
            reason: the error message. It may contain client supplied text, so
                it is XML-escaped in the body and reduced to one latin-1 line
                for the status line.
            extra_headers: additional (keyword, value) header pairs
        """
        self.error_response = (code, reason)
        status_reason = ' '.join(reason.splitlines()).encode('latin-1', 'replace').decode('latin-1')
        xml = '<?xml version="1.0" encoding="UTF-8"?><message>{}</message>'.format(
            '<type>error</type><reason>{}</reason>'.format(xml_escape(reason)))
        try:
            self._send_xml(code, status_reason, xml, extra_headers)
        except Exception as exc:  # pylint: disable=broad-except
            # e.g. the client is gone; the error is still recorded above
            logging.error('return_error failed: %s', exc)

    def parseRequest(self):  # noqa pylint: disable=invalid-name
        """Read and parse the url-encoded request body.

        Every key must appear only once. Otherwise one error answer is sent
        (see self.error_response), but the first values are still returned
        so the request can be logged.

        Returns:
            A dict with the parsed results or None if there is no body

        """
        if 'Content-Length' not in self.headers:
            return None
        data_length = int(self.headers['Content-Length'])
        data = self.rfile.read(data_length).decode('utf-8')
        raw = parse_qs(data)
        repeated = [key for key, values in raw.items() if len(values) != 1]
        if repeated:
            self.return_error(400, '{} must appear only once'.format(', '.join(repeated)))
        parsed = {key: values[0] for key, values in raw.items()}
        self._fix_headers(parsed)
        return parsed

    def homepage(self):
        """Return what the client needs: the index of login_user among all sorted users.

        Raises:
            ValueError: if login_user is not a known user

        """
        self.load_users()
        names = sorted(self.users)
        if self.login_user not in names:
            raise ValueError('Unknown user {}'.format(self.login_user))
        return """
            <input type="hidden" value="{}" name="mid" id="mid" />
            """.format(names.index(self.login_user))

    @staticmethod
    def answer_with_categories():
        """Return all categories."""
        all_cat = MMT.legal_categories
        return ''.join('<li><input name="add-activity-x">&nbsp;{}</li>'.format(x) for x in all_cat)

    def do_GET(self):  # noqa pylint: disable=invalid-name
        """Override standard: answer GET requests like the mapmytracks web site.

        Any failure is answered with exactly one error response.
        """
        self.error_response = None
        self.server.logger.info(
            '%s GET %s %s %s', self.server.second(), self.client_address[0], self.server.server_port, self.path)
        self.parseRequest()  # consume a body, if any
        if self.error_response:
            return
        try:
            xml = self._answer_get()
        except Exception as exc:  # pylint: disable=broad-except
            self.return_error(400, 'GET {}: {}'.format(self.path, exc))
            self.server.logger.debug(traceback.format_exc())
            return
        if self.error_response:
            return
        self.server.logger.info('%s  returning %s', self.server.second(), xml)
        self._send_xml(200, 'OK', xml, [self._AUTH_HEADER])

    def _answer_get(self):
        """Compute the body for do_GET.

        ``/`` requires basic authentication and returns the homepage, a path
        ending in ``/explore/wall`` returns the legal categories,
        ``//assets/php/gpx.php?tid=ID`` returns track ID of the first target
        as GPX. Everything else gets an empty answer.

        Returns:
            the body. If authentication failed, an error has already been sent.

        """
        if self.path == '/':
            self.check_basic_auth_pw()
            if self.error_response:
                return ''
            return self.homepage()
        if self.path.endswith('/explore/wall'):
            # the client wants to find out legal categories
            return self.answer_with_categories()
        if self.path.startswith('//assets/php/gpx.php'):
            request = parse_qs(self.path.partition('?')[2])
            if 'tid' not in request:
                raise ValueError('parameter tid is missing')
            return self.server.targets[0][request['tid'][0]].to_xml()
        return ''

    def do_POST(self):  # noqa pylint: disable=invalid-name
        """Override standard: answer a mapmytracks API request.

        Every request, successful or not, is logged and appended to
        self.server.history as a TrackingMessage.
        """
        self.error_response = None
        answer = ''
        parsed = dict()
        message = None
        try:
            parsed = self.parseRequest() or dict()
            message = TrackingMessage('POST', self.client_address[0], parsed, answer, self.error_response)
            answer = self._answer_post(parsed)
        finally:
            try:
                if message is None:
                    message = TrackingMessage('POST', self.client_address[0], parsed, answer, self.error_response)
                else:
                    message.result = answer
                    message.response = self.error_response
                message.log(prefix='  ')
                self.server.history.append(message)
            except Exception as exc:  # pylint: disable=broad-except
                # bookkeeping must never break request handling
                logging.error(exc)

    def _answer_post(self, parsed):
        """Check and execute one POST request and send exactly one answer.

        Returns:
            The answer of the xml_ method, or '' if it was not called or failed

        """
        if self.error_response:
            return ''
        self.check_basic_auth_pw()
        if self.error_response:
            return ''
        if self.path != '/':
            self.return_error(400, 'Url {}: Path must be empty'.format(self.path))
            return ''
        request = parsed.get('request')
        if not request:
            self.return_error(400, 'No request given in {}'.format(parsed))
            return ''
        method = getattr(self, 'xml_{}'.format(request), None)
        if method is None:
            self.return_error(400, 'Unknown request {}'.format(request))
            return ''
        try:
            answer = method(parsed) or ''
        except Exception as exc:  # pylint: disable=broad-except
            self.return_error(400, '{}: {}'.format(request, exc))
            self.server.logger.debug(traceback.format_exc())
            return ''
        if self.error_response:
            # the xml_ method has already sent an error answer
            return answer
        xml = '<?xml version="1.0" encoding="UTF-8"?><message>{}</message>'.format(answer)
        self._send_xml(200, 'OK', xml, [self._AUTH_HEADER])
        return answer

    @staticmethod
    def xml_get_time(_) -> str:
        """Get server time as defined by the mapmytracks API.

        Returns:
            Our answer

        """
        return '<type>time</type><server_time>{}</server_time>'.format(
            int(datetime.datetime.now().timestamp()))

    def xml_get_tracks(self, parsed) -> str:
        """List all tracks as defined by the mapmytracks API.

        Only offset 0 returns entries; all tracks of the first target are on it.

        Returns:
            Our answer

        """
        a_list = list()
        if parsed['offset'] == '0':
            for idx, track in enumerate(self.server.targets[0], start=1):
                a_list.append(
                    '<track{}><id>{}</id>'
                    '<title><![CDATA[ {} ]]></title>'
                    '<activity_type>{}</activity_type>'
                    '<date>{}</date>'
                    '</track{}>'.format(
                        idx, track.id_in_backend, track.title, track.category,
                        int(track.time.timestamp()), idx))
        return '<tracks>{}</tracks>'.format(''.join(a_list))

    @staticmethod
    def parse_points(raw):
        """Convert raw data back into list(GPXTrackPoint).

        Args:
            raw: whitespace separated values, four per point:
                latitude longitude elevation unix_timestamp

        Returns:
            list(GPXTrackPoint). A point with an unusable time stamp gets time None.

        Raises:
            Exception: if the number of values is not a multiple of 4
            ValueError: if latitude, longitude or elevation is not a number

        """
        values = raw.split()
        if len(values) % 4:
            raise Exception('Point element count {} is not a multiple of 4'.format(len(values)))
        result = list()
        for idx in range(0, len(values), 4):
            latitude, longitude, elevation, timestamp = values[idx:idx + 4]
            try:
                time = datetime.datetime.utcfromtimestamp(float(timestamp))
            except (ValueError, OverflowError, OSError):
                logging.error('Point has illegal time stamp %s', values)
                # gpxpy expects a datetime or None, never a number
                time = None
            result.append(GPXTrackPoint(
                latitude=float(latitude),
                longitude=float(longitude),
                elevation=float(elevation),
                time=time))
        return result

    def xml_upload_activity(self, parsed) -> str:
        """Upload an activity as defined by the mapmytracks API.

        The GPX data in parsed['gpx_file'] is added to all targets.

        Returns:
            Our answer with the id of the new track in the first target

        """
        track = Track()
        track.parse(parsed['gpx_file'])
        targets = self.server.targets
        targets[0].add(track)
        new_ident = track.id_in_backend
        for target in targets[1:]:
            target.add(track)
        return '<type>success</type><id>{}</id>'.format(new_ident)

    def _fix_headers(self, parsed):
        """Fix some not so nice things in the parsed request, in place.

        - rename the misspelled key privicity to privacy
        - start_activity without title: use an empty title
        - start_activity from OruxMaps: shorten its monster title like
          2018-10-03 00:0020181003_0018 to the first 16 characters
        """
        if 'privicity' in parsed:
            parsed['privacy'] = parsed['privicity']
            del parsed['privicity']
        if parsed.get('request') == 'start_activity':
            if 'title' not in parsed:
                parsed['title'] = ''
            elif parsed.get('source') == 'OruxMaps':
                parsed['title'] = parsed['title'][:16]

    def _new_Tracker(self, parsed) -> 'Lifetrack':  # noqa pylint: disable=invalid-name
        """Create a Lifetrack object and register it under the sender IP.

        Trackers of the same sender IP which are done are forgotten.

        Returns: The new object

        """
        sender_ip = self.client_address[0]
        result = Lifetrack(sender_ip, self.server.targets, ids=parsed.get('activity_id'))
        trackers = self.server.trackers
        same_ip_done_trackers = [x for x in trackers.values() if x.sender_ip == result.sender_ip and x.done]
        for done_tracker in same_ip_done_trackers:
            # the same tracker may be listed twice (under IP and under ids)
            trackers.pop(done_tracker.sender_ip, None)
            trackers.pop(done_tracker.formatted_ids(), None)
        trackers[sender_ip] = result
        return result

    def find_tracker(self, parsed):
        """Find the tracker this request is meant for.

        start_activity: look up by sender IP. A tracker that is done is
        forgotten, otherwise its ids are stored in parsed['activity_id'].
        With activity_id: look up by it and also register the tracker under
        the current sender IP. Without activity_id: look up by sender IP.

        Returns: the wanted tracker or None.

        """
        sender_ip = self.client_address[0]
        trackers = self.server.trackers
        request = parsed['request']
        if request == 'start_activity':
            tracker = trackers.get(sender_ip)
            if tracker is not None:
                if tracker.done:
                    del trackers[sender_ip]
                    tracker = None
                else:
                    parsed['activity_id'] = tracker.formatted_ids()
        elif 'activity_id' in parsed:
            tracker = trackers.get(parsed['activity_id'])
            if tracker:
                # the sender IP may have changed, we need this for standard stop_activity
                # without activity_id
                trackers[sender_ip] = tracker
        else:
            # The standard mapmytracks protocol does not include activity_id in stop_activity, but GPS Forwarder does.
            # This branch covers the standard.
            tracker = trackers.get(sender_ip)
        if tracker is None and trackers:
            logging.error(
                'No tracker found for IP=%s and activity_id=%s in %s', sender_ip, parsed.get('activity_id'),
                ', '.join('{}:{}'.format(k, v.formatted_ids().replace('----', '->')) for k, v in trackers.items()))
        return tracker

    def xml_start_activity(self, parsed) -> str:
        """Start live tracking, or continue the unfinished tracker of this sender.

        Returns:
            Our answer or None if there was an error

        """
        if self.error_response is not None:
            return None
        tracker = self.find_tracker(parsed)
        points = self.parse_points(parsed['points'])
        if tracker is None:
            tracker = self._new_Tracker(parsed)
            title = parsed.get('title', '')
            public = parsed.get('privacy', 'private') == 'public'
            # the MMT API example uses cycling instead of Cycling,
            # and Oruxmaps does so too.
            category = parsed.get('activity', MMT.legal_categories[0]).capitalize()
            tracker.start(points, title, public, category)
            self.server.trackers[tracker.formatted_ids()] = tracker
        else:
            tracker.update(points)
        logging.info(
            'Starting %s, remote software: %s version %s',
            tracker, parsed.get('source'), parsed.get('version'))
        return '<type>activity_started</type><activity_id>{}</activity_id>'.format(tracker.formatted_ids())

    def xml_update_activity(self, parsed) -> str:
        """Get new points.

        If the tracker was already stopped, an error answer is sent.

        Returns:
            Our answer

        """
        tracker = self.find_tracker(parsed)
        if tracker is None:
            tracker = self._new_Tracker(parsed)
            self.server.trackers[tracker.formatted_ids()] = tracker

        tracker.update(self.parse_points(parsed['points']))
        if tracker.done:
            self.return_error(400, 'update_activity: {} was already stopped'.format(tracker))
        return '<type>activity_updated</type>'

    def xml_stop_activity(self, parsed) -> str:
        """Client says stop.

        mapmytracks.com says we do not need to get activity_id here. So if
        find_tracker() finds nothing, we stop all trackers which are not done.
        See https://github.com/MapMyTracks/api/issues/25

        TODO: This will fail if the sender ip changes just before stop_activity
        and I see no way how to fix that.

        Returns:
            Our answer

        """
        found = self.find_tracker(parsed)
        if found is not None:
            trackers = [found]
        else:
            trackers = [x for x in self.server.trackers.values() if not x.done]
        for tracker in trackers:
            # a tracker may be listed twice; end() it only once
            if not tracker.done:
                logging.info('Stopping %s', tracker)
                tracker.end()
        return '<type>activity_stopped</type>'


class LifeServerMMT(HTTPServer):

    """A simple MMT server for live tracking.

    Args:
        options: the parsed command line, see create_parser()

    Attributes:
        targets: the backends receiving the data. The first one is a local
            server directory; it also holds the files .users, auth.cfg and
            gpxity_server.log.
        trackers: A dict of Lifetrack instances. The same tracker may be
            stored under the sender IP and under its formatted activity ids.
            Trackers which are done are removed when the same sender IP
            starts a new one.
        history: TrackingMessage objects for the received POST requests.
        logger: the root logger, see define_logger().

    """

    def __init__(self, options):
        """See class docstring."""
        super().__init__((options.servername, options.port), MMTHandler)
        first_target = options.target[0]
        if not first_target.lower().startswith('serverdirectory:'):
            first_target = 'serverdirectory:' + first_target
        self.targets = [Backend.instantiate(first_target)]
        Authenticate.path = os.path.join(self.targets[0].url, 'auth.cfg')
        # define_logger needs the first target for the log file location
        self.logger = self.define_logger(options)
        self.targets.extend(Backend.instantiate(x, options.timeout) for x in options.target[1:])
        self.start_second = datetime.datetime.now().timestamp()
        self.trackers = dict()
        self.history = list()

    def define_logger(self, options):
        """Configure the root logger.

        Its level comes from options.loglevel. Records passing that level are
        also written UTF-8 encoded to gpxity_server.log in the first target.
        urllib3 only logs fatal messages.

        Returns:
            the root logger

        """
        result = logging.getLogger()
        result.setLevel(options.loglevel.upper())
        logfile = logging.FileHandler(
            os.path.join(self.targets[0].url, 'gpxity_server.log'), encoding='utf-8')
        logfile.setLevel(logging.DEBUG)
        result.addHandler(logfile)
        logging.getLogger('urllib3').level = logging.FATAL
        return result

    def second(self):
        """Return the seconds since server start, formatted with width 10 and 4 decimals."""
        return '{:10.4f}'.format(datetime.datetime.now().timestamp() - self.start_second)

def create_parser():
    """Create the command line parser for gpxity_server.

    Returns:
        argparse.ArgumentParser

    """
    epilog = """
        The MMT server uses BASIC AUTH for login. Please define user:password in  the file .users in the first target.
        Please define authorization for the other targets in the file auth.cfg in the first target."""
    parser = argparse.ArgumentParser('gpxity_server', epilog=epilog)
    parser.add_argument('--servername', help='the name of this server', required=True)
    parser.add_argument('--port', help='listen on PORT', type=int, required=True)
    parser.add_argument('--smtp-port', help='PORT for mailer, default is 25', type=int, default=25)
    parser.add_argument(
        '--loglevel', help='set the loglevel, default is error',
        choices=('debug', 'info', 'warning', 'error'), default='error')
    # the raw string is converted into a list of ints by Main
    parser.add_argument('--timeout', help="""
            either one value in seconds or two comma separated values: The first one is the connection
            timeout, the second one is the read timeout. Default is to wait forever.""", type=str, default=None)
    parser.add_argument(
        'target', help='backends who should receive the data. The first one must be a local directory', nargs='+')
    return parser

class Main:  # pylint: disable=too-few-public-methods

    """Parse the command line and run the server forever."""

    def __init__(self):
        """See class docstring."""
        parser = create_parser()

        try:
            argcomplete.autocomplete(parser)
        except NameError:
            # argcomplete is optional: its import at the top may have failed
            pass

        options = parser.parse_args()
        if options.timeout is not None:
            options.timeout = self._parse_timeout(parser, options.timeout)

        LifeServerMMT(options).serve_forever()

    @staticmethod
    def _parse_timeout(parser, value):
        """Convert the --timeout argument into a list of one or two ints.

        Invalid values are reported with parser.error(), which exits.

        Returns:
            list(int): [timeout] or [connect timeout, read timeout]

        """
        parts = [x.strip() for x in value.split(',')]
        if len(parts) > 2:
            parser.error('--timeout: expected one or two comma separated values, got {!r}'.format(value))
        try:
            return [int(x) for x in parts]
        except ValueError:
            parser.error('--timeout: expected integer seconds, got {!r}'.format(value))
            raise  # not reached: parser.error() exits


if __name__ == '__main__':
    Main()
