#!/usr/bin/env python3
#
# MIT License
# 
# s3-pit-restore, a point in time restore tool for Amazon S3
#
# Copyright (c) [2016] [Madisoft S.p.A.]
#
# Author: Matteo Moretti <matteo.moretti@madisoft.it>
# Author: Angelo Compagnucci <angelo.compagnucci@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os, sys, time, signal, argparse, boto3, botocore, \
        unittest, concurrent.futures, shutil, uuid, time
from datetime import datetime, timezone
from dateutil.parser import parse

args = None
executor = None
transfer = None
futures = {}

class TestS3PitRestore(unittest.TestCase):

    def generate_tree(self, path, contents):
        for i, content in enumerate(contents):
            folder_path = os.path.join(path, "folder%d" % i)
            os.makedirs(folder_path, exist_ok=True)
            file_path = os.path.join(folder_path, "file%d" % i)

            with open(file_path, 'w') as outfile:
                outfile.write(content)
            print(file_path, content)

    def check_tree(self, path, contents):
        for i, content in enumerate(contents):
            folder_path = os.path.join(path, "folder%d" % i)
            os.makedirs(folder_path, exist_ok=True)
            file_path = os.path.join(folder_path, "file%d" % i)

            in_content=""
            with open(file_path, 'r') as infile:
                in_content = infile.read()
                print(file_path, content, "<>", in_content)
                if in_content != content:
                    return False
        return True

    def upload_directory(self, resource, path, bucketname):
        for root,dirs,files in os.walk(path):
            for f in files:
                base_path = os.path.basename(os.path.normpath(path))
                local_path = os.path.join(root, f)
                relative_path = os.path.relpath(local_path, path)
                s3_path = os.path.join(base_path, relative_path)                
                resource.meta.client.upload_file(os.path.join(root,f), bucketname, s3_path)

    def test_restore(self):
        contents_before = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())]
        contents_after = [str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4())]
        path = os.path.join(args.dest, "test-s3-pit-restore")
        s3 = boto3.resource('s3')

        print("Before ...")
        self.generate_tree(path, contents_before)
        self.upload_directory(s3, path, args.bucket)
        time.sleep(1)
        time_before = datetime.now(timezone.utc)

        print("Upload and owerwriting ...")
        self.generate_tree(path, contents_after)
        self.upload_directory(s3, path, args.bucket)
        shutil.rmtree(path)

        args.timestamp = str(time_before)
        args.prefix = os.path.basename(os.path.normpath(path))
        restore_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')
        do_restore()
        sys.stdout = restore_stdout
        print("Restoring and checking ...")
        self.assertTrue(self.check_tree(path, contents_before))

def signal_handler(signal, frame):
    executor.shutdown(wait=False)
    for future in list(futures.keys()):
        if not future.running():
            future.cancel()
            futures.pop(future, None)
    print("Gracefully exiting ...")

def print_obj(obj, optional_message=""):
    if args.verbose:
        print('"%s" %s %s %s %s %s' % (obj["LastModified"], obj["VersionId"], obj["Size"], obj["StorageClass"], obj["Key"], optional_message))
    else:
        print(obj["Key"])

def handled_by_glacier(obj):
    if obj["StorageClass"] == "GLACIER" and not args.enable_glacier:
        print_obj(obj)
        return true
    elif obj["StorageClass"] == "GLACIER" and args.enable_glacier:
        s3 = boto3.resource('s3')
        s3_obj = s3.Object(args.bucket, obj["Key"])
        if s3_obj.restore is None:
            print_obj(obj, optional_message='requesting')
            if not args.dry_run:
                try:
                    s3_obj.restore_object(VersionId=obj["VersionId"], RestoreRequest={'Days': 3},)
                except botocore.exceptions.ClientError as ex:
                    # sometimes s3_obj.restore returns None also if restore is in progress
                    pass
            return True
        # Print out objects whose restoration is on-going
        elif 'ongoing-request="true"' in s3_obj.restore:
            print_obj(obj, optional_message='in-progress')
            return True
        # Print out objects whose restoration is complete
        elif 'ongoing-request="false"' in s3_obj.restore:
            return False
    else:
        return False

def handled_by_standard(obj):
    if obj["Key"].endswith("/"):
        if not os.path.exists(obj["Key"]):
            os.makedirs(key)
        return True
    key_path = os.path.dirname(obj["Key"])
    if key_path and not os.path.exists(key_path):
            os.makedirs(key_path)
    try:
        if not args.dry_run:
            future = executor.submit(download_file, obj)
            global futures
            futures[future] = obj
        else:
            print_obj(obj)
    except RuntimeError:
        return False
    return True

def download_file(obj):
    transfer.download_file(args.bucket, obj["Key"], obj["Key"], extra_args={"VersionId": obj["VersionId"]})
    unixtime = time.mktime(obj["LastModified"].timetuple())
    os.utime(obj["Key"],(unixtime, unixtime))

def do_restore():
    pit_date = (parse(args.timestamp) if args.timestamp else datetime.now())
    client = boto3.client('s3')
    global transfer
    transfer = boto3.s3.transfer.S3Transfer(client)
    dest = args.dest
    last_obj = {}
    objects = {}
    last_obj["Key"] = ""
    objects["IsTruncated"] = True
    objects["NextKeyMarker"] = ""

    if args.debug: boto3.set_stream_logger()

    global executor
    executor = concurrent.futures.ThreadPoolExecutor(args.max_workers)

    if not os.path.exists(dest):
        os.makedirs(dest)

    os.chdir(dest)

    # AWS gives us versions chunks of maximum 1000 element, cycling here to obtain more
    while (objects["IsTruncated"] == True):
        objects = client.list_object_versions(Bucket=args.bucket, Prefix=args.prefix, KeyMarker=objects["NextKeyMarker"])
        if not "Versions" in objects:
            print("No versions matching criteria, exiting ...")
            sys.exit(1)
        for obj in objects["Versions"]:
            if last_obj["Key"] != obj["Key"] and obj["LastModified"] <= pit_date:
                last_obj = obj

                if handled_by_glacier(obj):
                    continue

                if not handled_by_standard(obj):
                    objects["IsTruncated"] == False
                    break

        for future in concurrent.futures.as_completed(futures):
            if future in futures:
                print_obj(futures[future])
                del(futures[future])

if __name__=='__main__':
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--bucket', help='s3 bucket to restore from', required=True)
    parser.add_argument('-p', '--prefix', help='s3 path to restore from', default="")
    parser.add_argument('-t', '--timestamp', help='point in time to restore at')
    parser.add_argument('-d', '--dest', help='path where recovering to', required=True)
    parser.add_argument('-e', '--enable-glacier', help='enable recovering from glacier', action='store_true')
    parser.add_argument('-v', '--verbose', help='print verbose informations from s3 objects', action='store_true')
    parser.add_argument('--dry-run', help='execute query without transferring files', action='store_true')
    parser.add_argument('--debug', help='enable debug output', action='store_true')
    parser.add_argument('--test', help='s3 pit restore testing', action='store_true')
    parser.add_argument('--max-workers', help='max number of concurrent download requests', default=10, type=int)
    args = parser.parse_args()

    if not args.test:
        do_restore()
    else:
        runner = unittest.TextTestRunner()
        itersuite = unittest.TestLoader().loadTestsFromTestCase(TestS3PitRestore)
        runner.run(itersuite)

    sys.exit(0)
