#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import atexit
import os
import platform
import re
import signal
import subprocess
import sys
import tempfile
import threading
from collections import defaultdict
from io import StringIO
from itertools import zip_longest
from pathlib import Path
from typing import List, Tuple


# --------------- distlib ---------------

# Copyright (C) 2012-2017 Vinay Sajip.
# Licensed to the Python Software Foundation under a contributor agreement.
# See LICENSE.txt and CONTRIBUTORS.txt.
#
"""
Parser for the environment markers micro-language defined in PEP 508.
"""

# Note: In PEP 345, the micro-language was Python compatible, so the ast
# module could be used to parse it. However, PEP 508 introduced operators such
# as ~= and === which aren't in Python, necessitating a different approach.


# Finds version-like tokens (digits with optional dotted parts and a word
# suffix), either bare, single-quoted or double-quoted. Used by _get_versions().
_VERSION_PATTERN = re.compile(r'((\d+(\.\d+)*\w*)|\'(\d+(\.\d+)*\w*)\'|\"(\d+(\.\d+)*\w*)\")')
# Marker variables whose values are compared as versions rather than strings.
_VERSION_MARKERS = {'python_version', 'python_full_version'}
# A marker variable name at the start of the input; trailing whitespace is consumed.
IDENTIFIER = re.compile(r'^([\w\.-]+)\s*')
# NOTE(review): VERSION_IDENTIFIER, COMPARE_OP and NON_SPACE are not referenced
# by the code visible in this file.
VERSION_IDENTIFIER = re.compile(r'^([\w\.*+-]+)\s*')
COMPARE_OP = re.compile(r'^(<=?|>=?|={2,3}|[~!]=)\s*')
# Comparison operators accepted inside a marker, including "in" and "not in".
MARKER_OP = re.compile(r'^((<=?)|(>=?)|={2,3}|[~!]=|in|not\s+in)\s*')
# Boolean keywords; \b keeps identifiers such as "order" from matching.
OR = re.compile(r'^or\b\s*')
AND = re.compile(r'^and\b\s*')
NON_SPACE = re.compile(r'(\S+)\s*')
# A run of characters allowed inside a quoted literal (neither quote character
# is part of the class; parse_marker handles quotes itself).
STRING_CHUNK = re.compile(r'([\s\w\.{}()*+#:;,/?!~`@$%^&=|<>\[\]-]+)')
# Case-insensitive version pattern. _pep_440_key() reads its groups by position:
# 0 = epoch ("N!"), 1 = release, 4/5 = pre tag/number, 7/8 = post tag/number,
# 10/11 = dev tag/number, 13 = local segment.
PEP440_VERSION_RE = re.compile(r'^v?(\d+!)?(\d+(\.\d+)*)((a|alpha|b|beta|c|rc|pre|preview)(\d+)?)?(\.(post|r|rev)(\d+)?)?([._-]?(dev)(\d+)?)?(\+([a-zA-Z\d]+(\.[a-zA-Z\d]+)?))?$', re.I)
# Leading "major.minor" of a version string; used by default_context().
_DIGITS = re.compile(r'\d+\.\d+')


class UnsupportedVersionError(ValueError):
	"""Raised when a version string cannot be parsed as a supported version."""


class Version(object):
	"""
	Base class for comparable version objects.

	Subclasses implement :meth:`parse`, which turns the stripped version
	string into a non-empty tuple. That tuple drives equality, ordering
	and hashing. Comparing instances of different classes raises TypeError.
	"""

	def __init__(self, s):
		text = s.strip()
		self._string = text
		key = self.parse(text)
		self._parts = key
		assert isinstance(key, tuple)
		assert len(key) > 0

	def parse(self, s):
		"""Return the comparison key for *s*; must be overridden."""
		raise NotImplementedError('please implement in a subclass')

	def _check_compatible(self, other):
		"""Raise TypeError unless *other* has exactly the same class."""
		if type(self) != type(other):
			raise TypeError('cannot compare %r and %r' % (self, other))

	def __eq__(self, other):
		self._check_compatible(other)
		return other._parts == self._parts

	def __ne__(self, other):
		return not self.__eq__(other)

	def __lt__(self, other):
		self._check_compatible(other)
		return self._parts < other._parts

	def __gt__(self, other):
		# "greater" means neither "less" nor "equal"
		return not self.__lt__(other) and not self.__eq__(other)

	def __le__(self, other):
		return self.__lt__(other) or self.__eq__(other)

	def __ge__(self, other):
		return self.__gt__(other) or self.__eq__(other)

	# See http://docs.python.org/reference/datamodel#object.__hash__
	def __hash__(self):
		return hash(self._parts)

	def __repr__(self):
		class_name = self.__class__.__name__
		return "%s('%s')" % (class_name, self._string)

	def __str__(self):
		return self._string

	@property
	def is_prerelease(self):
		"""Whether the version is a pre-release; must be overridden."""
		raise NotImplementedError('Please implement in subclasses.')


def _pep_440_key(s):
	"""
	Build a sortable key for a version string.

	The string is matched against PEP440_VERSION_RE and turned into the
	tuple ``(epoch, release, pre, post, dev, local)``. Trailing zeroes of
	the release are dropped so that X.Y, X.Y.0 and X.Y.0.0 compare equal.
	Missing pre/post/dev segments are replaced by sentinel tuples chosen
	so that tuple comparison gives the intended ordering.

	Raises UnsupportedVersionError if the string does not match.
	"""
	text = s.strip()
	matched = PEP440_VERSION_RE.match(text)
	if matched is None:
		raise UnsupportedVersionError('Not a valid version: %s' % text)
	fields = matched.groups()

	release = [int(piece) for piece in fields[1].split('.')]
	while len(release) > 1 and release[-1] == 0:
		release.pop()
	nums = tuple(release)

	# the epoch group includes the trailing "!"
	epoch = int(fields[0][:-1]) if fields[0] else 0

	def tagged(label, number):
		# () when the segment is absent, else (tag, serial) with serial defaulting to 0
		if label is None and number is None:
			return ()
		return label, (0 if number is None else int(number))

	pre = tagged(*fields[4:6])
	post = tagged(*fields[7:9])
	dev = tagged(*fields[10:12])

	if fields[13] is None:
		local = ()
	else:
		# to ensure that numeric compares as > lexicographic, avoid
		# comparing them directly, but encode a tuple which ensures
		# correct sorting
		local = tuple(
			(1, int(piece)) if piece.isdigit() else (0, piece)
			for piece in fields[13].split('.')
		)

	if not pre:
		if dev and not post:
			# a bare dev release sorts before any pre-release (before a0)
			pre = ('a', -1)
		else:
			# final release (or post release): sorts after all pre-releases
			pre = ('z',)
	# sentinels for the remaining absent segments
	if not post:
		post = ('_',)   # sort before 'a'
	if not dev:
		dev = ('final',)

	return epoch, nums, pre, post, dev, local


# Name under which NormalizedVersion.parse() looks up the key function.
_normalized_key = _pep_440_key


class NormalizedVersion(Version):
	"""A version parsed with the PEP 440 style pattern PEP440_VERSION_RE.

	Examples of accepted strings:
		1           # a single release number is accepted
		1.2         # compares equal to "1.2.0"
		1.2.0
		1.2a1
		1.2a        # a missing pre-release serial is treated as 0
		1.2.3b1
		1.2.3c1
		1.2.3.4
		1!2.0       # with an epoch

	Strings that do not match raise UnsupportedVersionError.
	"""
	def parse(self, s):
		"""
		Return the comparison key for *s* and remember the release clause.

		The untrimmed release numbers are stored in ``_release_clause``.
		"""
		result = _normalized_key(s)
		# _normalized_key loses trailing zeroes in the release
		# clause, since that's needed to ensure that X.Y == X.Y.0 == X.Y.0.0
		# However, PEP 440 prefix matching needs it: for example,
		# (~= 1.4.5.0) matches differently to (~= 1.4.5.0.0).
		# _normalized_key strips the string before matching, so do the same
		# here; the match must then succeed.
		m = PEP440_VERSION_RE.match(s.strip())
		groups = m.groups()
		self._release_clause = tuple(int(v) for v in groups[1].split('.'))
		return result

	PREREL_TAGS = {'a', 'b', 'c', 'rc', 'dev'}

	@property
	def is_prerelease(self):
		"""
		True if any segment of the key starts with one of PREREL_TAGS.

		NOTE(review): the spelled-out tags accepted by the pattern
		("alpha", "beta", "pre", "preview") are not in PREREL_TAGS, so such
		versions are not reported as pre-releases -- confirm this is intended.
		"""
		# The first element of the key is the integer epoch, which cannot be
		# subscripted; inspecting it made any non-zero epoch raise TypeError.
		return any(
			t[0] in self.PREREL_TAGS
			for t in self._parts
			if isinstance(t, tuple) and t
		)


def _match_prefix(x, y):
	"""
	Tell whether str(y) is a dotted prefix of str(x).

	True when both render to the same text, or when x starts with y and
	the character right after the prefix is a dot.
	"""
	candidate, prefix = str(x), str(y)
	if candidate == prefix:
		return True
	# startswith() plus inequality guarantees there is a character to index
	return candidate.startswith(prefix) and candidate[len(prefix)] == '.'


# Short alias used by _get_versions() and Evaluator.evaluate().
NV = NormalizedVersion


def _is_version_marker(s):
	"""Return True if *s* is a string naming one of the version marker variables."""
	if not isinstance(s, str):
		return False
	return s in _VERSION_MARKERS


def _is_literal(o):
	"""Return True if *o* is a non-empty string that begins with a quote character."""
	# slicing (rather than indexing) makes the empty string yield '' -> False
	return isinstance(o, str) and o[:1] in ("'", '"')


def _get_versions(s):
	"""Return the set of NV objects for every version-like token found in *s*."""
	found = set()
	for hit in _VERSION_PATTERN.finditer(s):
		found.add(NV(hit.group(1)))
	return found


def parse_marker(marker_string):
	"""
	Parse a marker string into a marker expression.

	Returns a tuple ``(expression, remaining)`` where *remaining* is the
	unparsed tail of the input.

	The expression is a dictionary with keys "op", "lhs" and "rhs" for
	non-terminals in the expression grammar, or a string. A string contained
	in quotes is to be interpreted as a literal string, and a string not
	contained in quotes is a variable (such as os_name).

	Raises SyntaxError for malformed input.
	"""
	def marker_var(remaining):
		# either identifier, or literal string
		m = IDENTIFIER.match(remaining)
		if m:
			result = m.groups()[0]
			remaining = remaining[m.end():]
		elif not remaining:
			raise SyntaxError('unexpected end of input')
		else:
			q = remaining[0]
			if q not in '\'"':
				raise SyntaxError('invalid expression: %s' % remaining)
			# oq is the "other" quote character, allowed verbatim inside the literal
			oq = '\'"'.replace(q, '')
			remaining = remaining[1:]
			parts = [q]
			while remaining:
				# either a string chunk, or oq, or q to terminate
				if remaining[0] == q:
					break
				elif remaining[0] == oq:
					parts.append(oq)
					remaining = remaining[1:]
				else:
					m = STRING_CHUNK.match(remaining)
					if not m:
						raise SyntaxError('error in string literal: %s' % remaining)
					parts.append(m.groups()[0])
					remaining = remaining[m.end():]
			else:
				# input exhausted without meeting the closing quote
				s = ''.join(parts)
				raise SyntaxError('unterminated string: %s' % s)
			parts.append(q)
			result = ''.join(parts)
			remaining = remaining[1:].lstrip() # skip past closing quote
		return result, remaining

	def marker_expr(remaining):
		# parenthesised sub-expression, or a chain of comparisons
		if remaining and remaining[0] == '(':
			result, remaining = marker(remaining[1:].lstrip())
			# an empty tail means the input ended before the closing
			# parenthesis; report it as a syntax error, not an IndexError
			if not remaining or remaining[0] != ')':
				raise SyntaxError('unterminated parenthesis: %s' % remaining)
			remaining = remaining[1:].lstrip()
		else:
			lhs, remaining = marker_var(remaining)
			while remaining:
				m = MARKER_OP.match(remaining)
				if not m:
					break
				op = m.groups()[0]
				remaining = remaining[m.end():]
				rhs, remaining = marker_var(remaining)
				lhs = {'op': op, 'lhs': lhs, 'rhs': rhs}
			result = lhs
		return result, remaining

	def marker_and(remaining):
		# "and" binds tighter than "or"
		lhs, remaining = marker_expr(remaining)
		while remaining:
			m = AND.match(remaining)
			if not m:
				break
			remaining = remaining[m.end():]
			rhs, remaining = marker_expr(remaining)
			lhs = {'op': 'and', 'lhs': lhs, 'rhs': rhs}
		return lhs, remaining

	def marker(remaining):
		lhs, remaining = marker_and(remaining)
		while remaining:
			m = OR.match(remaining)
			if not m:
				break
			remaining = remaining[m.end():]
			rhs, remaining = marker_and(remaining)
			lhs = {'op': 'or', 'lhs': lhs, 'rhs': rhs}
		return lhs, remaining

	return marker(marker_string)


class Evaluator(object):
	"""
	Evaluates marker expressions against a context mapping.
	"""

	# operator name -> two-argument callable applied to the evaluated operands
	operations = {
		'==': lambda x, y: x == y,
		'===': lambda x, y: x == y,
		'~=': lambda x, y: x == y or x > y,
		'!=': lambda x, y: x != y,
		'<': lambda x, y: x < y,
		'<=': lambda x, y: x == y or x < y,
		'>': lambda x, y: x > y,
		'>=': lambda x, y: x == y or x > y,
		'and': lambda x, y: x and y,
		'or': lambda x, y: x or y,
		'in': lambda x, y: x in y,
		'not in': lambda x, y: x not in y,
	}

	def evaluate(self, expr, context):
		"""
		Evaluate a marker expression in the specified context.

		A quoted string evaluates to its contents, an unquoted string is
		looked up in *context*, and a dictionary is evaluated recursively
		by applying its "op" to the evaluated "lhs" and "rhs".
		"""
		if isinstance(expr, str):
			if expr[0] in '\'"':
				# literal: drop the surrounding quotes
				return expr[1:-1]
			if expr not in context:
				raise SyntaxError('unknown variable: %s' % expr)
			return context[expr]

		assert isinstance(expr, dict)
		op = expr['op']
		if op not in self.operations:
			raise NotImplementedError('op not implemented: %s' % op)
		left_expr = expr['lhs']
		right_expr = expr['rhs']
		if _is_literal(left_expr) and _is_literal(right_expr):
			raise SyntaxError('invalid comparison: %s %s %s' % (left_expr, op, right_expr))

		left = self.evaluate(left_expr, context)
		right = self.evaluate(right_expr, context)

		left_is_version = _is_version_marker(left_expr)
		right_is_version = _is_version_marker(right_expr)
		comparison_ops = ('<', '<=', '>', '>=', '===', '==', '!=', '~=')
		if (left_is_version or right_is_version) and op in comparison_ops:
			# version variables are compared as versions, not as text
			left = NV(left)
			right = NV(right)
		elif left_is_version and op in ('in', 'not in'):
			# membership: test against every version found in the right side
			left = NV(left)
			right = _get_versions(right)
		return self.operations[op](left, right)


def in_venv():
	"""Return True when the interpreter runs inside a virtual environment."""
	# virtualenv sets sys.real_prefix; PEP 405 venvs make sys.prefix
	# differ from sys.base_prefix
	return hasattr(sys, 'real_prefix') or sys.prefix != getattr(sys, 'base_prefix', sys.prefix)


def default_context():
	"""
	Build the marker variable mapping describing the running interpreter.

	Values come from ``sys``, ``os`` and ``platform``; 'python_version' is
	the leading "major.minor" of the full version string.
	"""
	def render_version(info):
		# "major.minor.micro", plus e.g. "b1" for non-final release levels
		text = '%s.%s.%s' % (info.major, info.minor, info.micro)
		level = info.releaselevel
		if level == 'final':
			return text
		return text + level[0] + str(info.serial)

	implementation_name = ''
	implementation_version = '0'
	if hasattr(sys, 'implementation'):
		implementation_name = sys.implementation.name
		implementation_version = render_version(sys.implementation.version)

	full_version = platform.python_version()
	short_version = _DIGITS.match(full_version).group(0)

	return {
		'implementation_name': implementation_name,
		'implementation_version': implementation_version,
		'os_name': os.name,
		'platform_machine': platform.machine(),
		'platform_python_implementation': platform.python_implementation(),
		'platform_release': platform.release(),
		'platform_system': platform.system(),
		'platform_version': platform.version(),
		'platform_in_venv': str(in_venv()),
		'python_full_version': full_version,
		'python_version': short_version,
		'sys_platform': sys.platform,
	}


# Snapshot of the marker environment, computed once at import time.
DEFAULT_CONTEXT = default_context()
# The factory is removed from the module namespace once the snapshot exists.
del default_context


# Module-level evaluator instance used by interpret_marker().
evaluator = Evaluator()


def interpret_marker(marker, execution_context=None):
	"""
	Interpret a marker and return a result depending on environment.

	:param marker: The marker to interpret.
	:type marker: str
	:param execution_context: Values that override or extend DEFAULT_CONTEXT
	                          for name lookup.
	:type execution_context: mapping
	:return: The result of evaluating the marker expression.
	:raises SyntaxError: If the marker cannot be parsed, or if anything other
	                     than a ``#`` comment follows the parsed expression.
	"""
	try:
		expr, rest = parse_marker(marker)
	except Exception as e:
		# chain explicitly so the original parser error stays attached as the cause
		raise SyntaxError('Unable to interpret marker syntax: %s: %s' % (marker, e)) from e
	if rest and rest[0] != '#':
		raise SyntaxError('unexpected trailing data in marker: %s: %s' % (marker, rest))
	# copy, so that overrides never leak into the shared default context
	context = dict(DEFAULT_CONTEXT)
	if execution_context:
		context.update(execution_context)
	return evaluator.evaluate(expr, context)


# --------------- end of distlib ---------------


def parse_versions(versions):
	"""
	argparse type converter for a comma separated list of versions.

	Each item is split on dots and converted to a tuple of integers; the
	tuples are returned sorted. Raises argparse.ArgumentTypeError when the
	value is empty or an item contains a non-integer component.
	"""
	if not versions:
		raise argparse.ArgumentTypeError("Missing versions attribute")

	def to_tuple(text):
		try:
			return tuple(int(num) for num in text.split('.'))
		except ValueError:
			raise argparse.ArgumentTypeError("Bad version format: %s" % text)

	return sorted(to_tuple(item) for item in versions.split(','))


def get_output_and_args(compile_args):
	"""
	Work out the output file name and replace it with a placeholder.

	*compile_args* is a list whose first element is the open input file
	(only its ``name`` is read here), followed by the arguments forwarded to
	pip-compile. The list itself is not modified.

	If an output option (``-o`` or its long form ``--output-file``) is
	followed by a value, that value becomes the output file name and is
	replaced by the placeholder ``'__OUTPUT__'``. Otherwise the output name
	is the input name with its extension replaced by ``.txt`` and
	``-o __OUTPUT__`` is appended. An output option given last, without a
	value, is dropped before that.

	Returns ``(output_filename, new_compile_args)``.
	"""
	compile_args = compile_args[:]
	output_filename = os.path.splitext(compile_args[0].name)[0] + '.txt'

	# position of the first output option, if any
	out_arg_idx = None
	for idx, arg in enumerate(compile_args):
		if arg in ('-o', '--output-file'):
			out_arg_idx = idx
			break

	if out_arg_idx is not None:
		if out_arg_idx < len(compile_args) - 1:
			output_filename = compile_args[out_arg_idx + 1]
			compile_args[out_arg_idx + 1] = '__OUTPUT__'
		else:
			# the option is the last argument and has no value: discard it
			compile_args.pop()
			out_arg_idx = None

	if out_arg_idx is None:
		compile_args.extend(['-o', '__OUTPUT__'])

	return output_filename, compile_args


def check_proc_status(proc):
	"""Raise subprocess.CalledProcessError if *proc* has a non-zero return code."""
	code = proc.returncode
	if not code:
		return
	command = ' '.join(map(str, proc.args))
	raise subprocess.CalledProcessError(code, command)


def ensure_venv_exists(version_name, virtualenvs_dir):
	"""
	Create the virtualenv ``py<version_name>`` under *virtualenvs_dir* if missing.

	Runs ``python<version_name> -m venv <dir>`` in its own process group and
	raises subprocess.CalledProcessError if that command fails.
	"""
	target = virtualenvs_dir / f'py{version_name}'
	if target.exists():
		return
	command = [f'python{version_name}', '-m', 'venv', str(target)]
	creator = subprocess.Popen(command, preexec_fn=os.setpgrp)
	creator.wait()
	check_proc_status(creator)


class PipCompile(threading.Thread):
	"""
	Thread that runs pip-compile inside one virtualenv.

	The thread first installs/upgrades pip-tools with the virtualenv's pip,
	then runs the virtualenv's pip-compile with *compile_args*, in which every
	``'__OUTPUT__'`` placeholder has been replaced by *output_path*.

	After the thread finishes, ``requirements`` holds the text read from
	*output_path* (or '' if the run was stopped or failed) and ``exc`` holds
	the exception raised inside the thread, if any.
	"""

	def __init__(self, virtualenv_dir, compile_args, output_path):
		super().__init__()
		compile_args = [output_path if arg == '__OUTPUT__' else arg for arg in compile_args]
		self.virtualenv_dir = virtualenv_dir
		self.compile_args = compile_args
		self.requirements = ''
		self.output_path = output_path
		self.exc = None
		self.__stopping = False
		self.__proc = None

	@property
	def python_binary(self):
		"""Path of the virtualenv's python executable, as a string."""
		return str(self.virtualenv_dir / 'bin' / 'python')

	@property
	def pip_binary(self):
		"""Path of the virtualenv's pip executable, as a string."""
		return str(self.virtualenv_dir / 'bin' / 'pip')

	@property
	def pip_compile_binary(self):
		"""Path of the virtualenv's pip-compile executable, as a string."""
		return str(self.virtualenv_dir / 'bin' / 'pip-compile')

	def run(self):
		"""
		Install pip-tools, run pip-compile and read back the output file.

		The stop flag is checked before and after each subprocess so that a
		stop() request ends the thread quietly. Any exception is stored in
		``exc`` instead of propagating out of the thread.
		"""
		try:
			if self.__stopping:
				return
			self.__proc = subprocess.Popen([self.pip_binary, 'install', '--upgrade', 'pip-tools'], preexec_fn=os.setpgrp)
			self.__proc.communicate()
			if self.__stopping:
				return
			check_proc_status(self.__proc)
			self.__proc = None

			if self.__stopping:
				return
			self.__proc = subprocess.Popen([self.pip_compile_binary] + self.compile_args, preexec_fn=os.setpgrp)
			self.__proc.communicate()
			if self.__stopping:
				return
			check_proc_status(self.__proc)
			with self.output_path.open('r') as fp:
				self.requirements = fp.read()
			self.__proc = None
		except BaseException as e:
			self.exc = e

	def stop(self):
		"""
		Ask the thread to stop and terminate its current subprocess, if any.

		The subprocess is sent terminate(), then kill() if it has not exited
		shortly after, and finally its whole process group is sent SIGKILL.
		All of this is best effort: failures are ignored.
		"""
		self.__stopping = True
		proc = self.__proc
		if proc is not None:
			# look the group up first: it can no longer be queried once the
			# process has been reaped
			try:
				group = os.getpgid(proc.pid)
			except Exception:
				group = None
			proc.terminate()
			try:
				proc.wait(0.1)
			except Exception:
				# still running (or wait failed): force it
				proc.kill()
			if group is not None:
				try:
					os.killpg(group, signal.SIGKILL)
				except Exception:
					pass


def split_requirements_and_comments(requirements_txt: str) -> List[Tuple[str, str]]:
	"""
	Split requirements text into ``(requirement, trailing_text)`` pairs.

	A requirement is the leading part of a line up to (not including) a
	``#`` or the end of the line. The trailing text is everything between
	that requirement and the next one: comments, newlines, comment-only
	lines. Text before the first requirement is discarded.
	"""
	pieces = re.split(r'^((?:[ \t]*[^ \t\n#]+)+)', requirements_txt, flags=re.MULTILINE)
	pairs = []
	# odd positions hold the captured requirements, even positions what follows them
	for raw_requirement, trailing in zip_longest(pieces[1::2], pieces[2::2], fillvalue=''):
		cleaned = raw_requirement.strip()
		if cleaned:
			pairs.append((cleaned, trailing))
	return pairs


def versions_to_markers(versions, all_versions):
	"""
	Express *versions* (a subset of *all_versions*) as a python_version marker.

	Versions adjacent in *all_versions* are merged into ranges. Each range
	becomes ``python_version >= "first"`` and/or ``python_version < "next"``,
	with a bound omitted when the range touches that end of *all_versions*.
	Several ranges are parenthesised and joined with ``or``. Returns ''
	when no restriction is needed.
	"""
	positions = [all_versions.index(v) for v in versions]
	last_position = len(all_versions) - 1

	# merge consecutive positions into (first, last) spans
	spans = []
	first = previous = positions[0]
	for position in positions:
		if position - previous > 1:
			spans.append((first, previous))
			first = position
		previous = position
	spans.append((first, previous))

	clauses = []
	for low, high in spans:
		bounds = []
		if low != 0:
			bounds.append(f'python_version >= "{all_versions[low]}"')
		if high != last_position:
			bounds.append(f'python_version < "{all_versions[high + 1]}"')
		# an unbounded span contributes nothing
		if bounds:
			clauses.append(' and '.join(bounds))

	if len(clauses) > 1:
		clauses = [f'({clause})' for clause in clauses]

	return ' or '.join(clauses)


def version_specific_requirements(requirements_txt, version_name):
	"""
	Keep only the requirements that apply to the given python version.

	Each requirement's marker (the text after the first ``;``) is evaluated
	with 'python_version' set to *version_name*; requirements without a
	marker are always kept. Kept requirements are written back together
	with the text that followed them.
	"""
	context = DEFAULT_CONTEXT.copy()
	context['python_version'] = version_name

	kept = StringIO()
	for requirement, trailing in split_requirements_and_comments(requirements_txt):
		_, _, marker_text = requirement.partition(';')
		marker_text = marker_text.strip()
		if marker_text and not interpret_marker(marker_text, context):
			continue
		kept.write(f'{requirement}{trailing}')

	return kept.getvalue()


def pip_compile(versions, compile_args, virtualenvs_dir):
	"""
	Run pip-compile once per python version and merge the results.

	:param versions: Sorted list of version tuples, e.g. ``[(3, 6), (3, 7)]``.
	:param compile_args: List whose first element is the open input file,
	                     followed by the arguments forwarded to pip-compile.
	:param virtualenvs_dir: Directory (a Path) holding one ``py<version>``
	                        virtualenv per version; missing ones are created.

	Steps: the input file and every ``-r``/``-c`` file it references are
	copied into a temporary directory; one PipCompile thread per version
	compiles there, seeded with the requirements currently saved in the
	output file that apply to that version; finally the per-version results
	are merged into the output file, with python_version markers added to
	requirements that are not shared by all versions.

	Raises the exception of the first worker that failed. Raises SystemExit
	without touching the output file if SIGTERM, SIGINT or SIGHUP arrived
	while compiling.
	"""
	processes = []
	first_inputs = []
	version_names = [".".join(str(v) for v in version) for version in versions]

	output_filename, compile_args = get_output_and_args(compile_args)
	input_filename = Path(compile_args[0].name)

	tmp_dir = tempfile.TemporaryDirectory()
	tmp_dir_path = Path(tmp_dir.name)

	first_input = compile_args[0]
	first_input.seek(0)
	first_input_data = first_input.read()
	compile_args[0] = Path(first_input.name)
	first_input_replacement = first_input.name
	first_input.close()

	# 'a+' creates the output file if it does not exist yet
	with open(output_filename, 'a+') as fp:
		fp.seek(0)
		current_saved_requirements = fp.read()

	dependency_number = 0
	dependency_pattern = '__dependenices__%d%s'
	# temporary file path -> original file name relative to the working directory
	dependency_replacements = {}

	def copy_dependencies(file_contents, base_dir):
		# Copy every file referenced through "-r"/"-c" (recursively) into the
		# temporary directory and point the reference at the copy.
		nonlocal dependency_number
		lines = file_contents.splitlines()
		for line_number, line in enumerate(lines):
			line = line.strip()
			if line.startswith('-c ') or line.startswith('-r '):
				cmd = line[:2]
				filename = line[3:].split('#')[0].strip()
				with base_dir.joinpath(filename).open() as fp:
					dependency_contents = fp.read()
				__, file_extension = os.path.splitext(filename)
				dependency_contents = copy_dependencies(dependency_contents, base_dir.joinpath(filename).parent)
				dependency_filename = dependency_pattern % (dependency_number, file_extension)

				input_path = base_dir.joinpath(filename).resolve()
				relative_filename = os.path.relpath(input_path, start=os.getcwd())

				dependency_replacements[str(tmp_dir_path / dependency_filename)] = relative_filename
				dependency_number += 1
				with (tmp_dir_path / dependency_filename).open('w') as fp:
					fp.write(dependency_contents)
				lines[line_number] = f'{cmd} {dependency_filename}'
		return '\n'.join(lines)

	first_input_data = copy_dependencies(first_input_data, input_filename.parent)

	def replace_dependencies(file_contents):
		# Map temporary dependency paths back to the original file names.
		for dependency_filename, original_filename in dependency_replacements.items():
			file_contents = file_contents.replace(dependency_filename, original_filename)
		return file_contents

	for version_name in version_names:
		version_compile_args = compile_args[:] # copy
		version_compile_args[0] = tmp_dir_path / f'py{version_name}.in'
		first_inputs.append(str(version_compile_args[0]))
		with (tmp_dir_path / f'py{version_name}.in').open('w') as fp:
			fp.write(first_input_data)
		output_path = tmp_dir_path / f'py{version_name}.txt'
		# seed the run with the currently saved pins that apply to this version
		with output_path.open('w') as fp:
			fp.write(version_specific_requirements(current_saved_requirements, version_name))

		virtualenv_dir = virtualenvs_dir / f'py{version_name}'
		ensure_venv_exists(version_name, virtualenvs_dir)
		compile_proc = PipCompile(virtualenv_dir, version_compile_args, output_path)
		compile_proc.start()
		processes.append(compile_proc)

	# signal numbers received while compiling
	interrupted = []

	def at_exit(*args):
		# Called with (signum, frame) as a signal handler and without
		# arguments by atexit.
		if args:
			interrupted.append(args[0])
		for proc in processes:
			proc.stop()

	atexit.register(at_exit)
	signal.signal(signal.SIGTERM, at_exit)
	signal.signal(signal.SIGINT, at_exit)
	signal.signal(signal.SIGHUP, at_exit)

	requirements = []
	for proc in processes:
		proc.join()
		if proc.exc:
			raise proc.exc
		requirements.append(proc.requirements)

	if interrupted:
		# Stopped workers finish without an error and with empty results;
		# writing them out would replace the saved requirements with partial
		# data, so abort before touching the output file.
		raise SystemExit(128 + interrupted[0])

	# requirement -> versions it was compiled for, its own markers, trailing text
	version_specs = defaultdict(lambda: {'versions': [], 'markers': '', 'comment': ''})

	for version_name, first_tmp_input, version_requirements in zip(version_names, first_inputs, requirements):
		version_requirements = split_requirements_and_comments(version_requirements)
		for req, comment in version_requirements:
			try:
				req, markers = req.split(';', 1)
				markers = markers.strip()
			except ValueError:
				markers = ''
			req = req.strip()
			# markers and comment of the last version listing the requirement win
			version_specs[req]['markers'] = markers
			version_specs[req]['versions'].append(version_name)
			version_specs[req]['comment'] = comment.replace(first_tmp_input, first_input_replacement)

	with open(output_filename, 'w') as fp:
		for spec, item in version_specs.items():
			fp.write(spec)
			markers = versions_to_markers(item['versions'], version_names)
			if item['markers'] and item['markers'] != markers:
				# combine the requirement's own markers with the version restriction
				if markers:
					markers = f'({item["markers"]}) and ({markers})'
				else:
					markers = item["markers"]
			if markers:
				fp.write(f' ; {markers}')
			fp.write(replace_dependencies(item['comment']))

	# all workers are done; remove the scratch directory explicitly
	tmp_dir.cleanup()


def main():
	"""
	Command line entry point.

	Parses the versions list, the virtualenvs directory, the first input
	file and the remaining arguments (forwarded to pip-compile), then runs
	pip_compile().
	"""
	parser = argparse.ArgumentParser(description="Compile requirements.in to single requirements.txt file with multiple python versions support")
	parser.add_argument('versions', type=parse_versions, help="Comma separated list of python versions")
	parser.add_argument('--virtualenvs_dir', default=str(Path.home() / '.virtualenvs'), help="Where to create virtualenvs")
	parser.add_argument('input', type=argparse.FileType(mode='r'), help="First input file")
	parser.add_argument('pip_compile_args', nargs=argparse.REMAINDER, help="Arguments forwarded to pip-tools")
	args = parser.parse_args()
	# pip_compile() expects the open input file as the first element
	compile_args = [args.input] + args.pip_compile_args
	pip_compile(args.versions, compile_args, Path(args.virtualenvs_dir))


if __name__ == "__main__":
	main()
