#!/usr/bin/env python3
import sys
import os
import shutil
import tempfile
import subprocess
import argparse

# Version string of this wrapper; printed by the -V/--version option.
__version__ = "0.6.1"

def pfd(args, srr_id, extra_args, n_spots):
    """Dump one run with several concurrent ``fastq-dump`` processes.

    The spot range ``1..n_spots`` is split into contiguous blocks. One
    ``fastq-dump`` process is started per block, each writing into its own
    sub-directory of a temporary directory. Once a process has finished, the
    files it produced are appended, in block order, to files of the same name
    in ``args.outdir``.

    Args:
        args: Parsed options; ``threads``, ``tmpdir`` and ``outdir`` are read.
        srr_id: Run accession, passed to ``fastq-dump`` as its last argument.
        extra_args: List of additional arguments forwarded to ``fastq-dump``.
        n_spots: Total number of spots in the run.

    Raises:
        SystemExit: With status 1 if any ``fastq-dump`` process exits with a
            non-zero code.
    """
    # Never start more workers than there are spots (that would create empty
    # ranges such as [1, 0]) and always use at least one worker.
    n_blocks = max(1, min(args.threads, n_spots))

    # The context manager removes the temporary directory on every exit path,
    # including the sys.exit() below.
    with tempfile.TemporaryDirectory(prefix="pfd_", dir=args.tmpdir) as tmp_root:
        sys.stderr.write("tempdir: {}\n".format(tmp_root))

        # Integer arithmetic only; the last block absorbs the remainder.
        block_size, remainder = divmod(n_spots, n_blocks)
        blocks = []
        first_spot = 1
        for _ in range(n_blocks):
            blocks.append([first_spot, first_spot + block_size - 1])
            first_spot += block_size
        blocks[-1][1] += remainder
        sys.stderr.write("blocks: {}\n".format(blocks))

        procs = []
        outputs = {}  # output file name -> open handle in args.outdir
        try:
            for i, (start, stop) in enumerate(blocks):
                block_dir = os.path.join(tmp_root, str(i))
                os.mkdir(block_dir)
                cmd = ["fastq-dump", "-N", str(start), "-X", str(stop), "-O", block_dir]
                procs.append(subprocess.Popen(cmd + extra_args + [srr_id]))

            # Waiting in block order keeps the concatenated output in spot order.
            for i, proc in enumerate(procs):
                exit_code = proc.wait()
                if exit_code != 0:
                    sys.stderr.write("exit code: {}\n".format(exit_code))
                    sys.exit(1)

                # List every block's own directory: blocks do not necessarily
                # produce the same set of files.
                block_dir = os.path.join(tmp_root, str(i))
                for name in sorted(os.listdir(block_dir)):
                    if name not in outputs:
                        outputs[name] = open(os.path.join(args.outdir, name), "wb")
                    with open(os.path.join(block_dir, name), "rb") as part:
                        shutil.copyfileobj(part, outputs[name])
        finally:
            for handle in outputs.values():
                handle.close()
            # Stop workers that are still running (e.g. after a failure) so they
            # do not keep writing into a directory that is about to be removed.
            for proc in procs:
                if proc.poll() is None:
                    proc.terminate()
            for proc in procs:
                proc.wait()

def get_spot_count(sra_id):
    """Return the total number of spots reported by ``sra-stat`` for a run.

    Runs ``sra-stat --meta --quick <sra_id>`` and, for every output line,
    takes the third ``|``-separated field and adds the number before its first
    ``:`` to the total. The total is also logged to stderr.

    Args:
        sra_id: Accession passed to ``sra-stat``.

    Returns:
        The summed spot count as an int.

    Raises:
        SystemExit: With status 1 if ``sra-stat`` exits with a non-zero code.
        IndexError, ValueError: If an output line does not have the expected
            layout; the message includes the offending line.
    """
    p = subprocess.Popen(["sra-stat", "--meta", "--quick", sra_id], stdout=subprocess.PIPE)
    stdout, _ = p.communicate()

    # A failed sra-stat may still have printed partial output; do not turn
    # that into a silently wrong spot count.
    if p.returncode != 0:
        sys.stderr.write("sra-stat exit code: {}\n".format(p.returncode))
        sys.exit(1)

    total = 0
    for line in stdout.decode().rstrip().split("\n"):
        try:
            total += int(line.split("|")[2].split(":")[0])
        except (IndexError, ValueError) as err:
            # Same exception type as before, but with enough context to debug.
            raise type(err)(
                "unexpected sra-stat output for {}: {!r}".format(sra_id, line)
            ) from err
    sys.stderr.write("{} spots: {}\n".format(sra_id, total))
    return total

def partition(f, l):
    """Split the items of ``l`` by the predicate ``f``.

    Returns a 2-tuple of lists: first the items for which ``f(item)`` is
    truthy, then the remaining items. Input order is kept within each list.
    """
    accepted, rejected = [], []
    for item in l:
        target = accepted if f(item) else rejected
        target.append(item)
    return (accepted, rejected)

def main():
    """Command-line entry point.

    Parses the wrapper's own options and collects everything else as extra
    arguments. With ``-V`` it prints this wrapper's version, runs
    ``fastq-dump -V`` and exits with status 0. Without any ``-s/--sra-id`` it
    prints the help text and exits with status 1. Otherwise it creates the
    output directory if needed and, for every run id, looks up the spot count
    and runs the parallel dump.
    """
    parser = argparse.ArgumentParser(description="parallel fastq-dump wrapper, extra args will be passed through")
    parser.add_argument("-s","--sra-id", help="SRA id", action="append")
    parser.add_argument("-t","--threads", help="number of threads", default=1, type=int)
    parser.add_argument("-O","--outdir", help="output directory", default=".")
    parser.add_argument("--tmpdir", help="temporary directory", default=None)
    parser.add_argument("-V", "--version", help="shows version", action="store_true")
    args, extra = parser.parse_known_args()

    if args.version:
        print("parallel-fastq-dump : {}".format(__version__))
        subprocess.Popen(["fastq-dump", "-V"]).wait()
        sys.exit(0)

    if not args.sra_id:
        parser.print_help()
        sys.exit(1)

    # The spot range is divided by the thread count, so it must be positive.
    if args.threads < 1:
        parser.error("--threads must be a positive integer")

    # Unknown arguments that look like run accessions are treated as further
    # ids; everything else is forwarded to fastq-dump. ERR/DRR are the run
    # accession prefixes of ENA and DDBJ, alongside SRR.
    extra_srrs, extra_args = partition(lambda s: s.startswith(("SRR", "ERR", "DRR")), extra)
    args.sra_id.extend(extra_srrs)
    sys.stderr.write("SRR ids: {}\n".format(args.sra_id))
    sys.stderr.write("extra args: {}\n".format(extra_args))

    if args.outdir:
        # Creates missing parent directories too and tolerates an existing one.
        os.makedirs(args.outdir, exist_ok=True)

    for si in args.sra_id:
        n_spots = get_spot_count(si)
        pfd(args, si, extra_args, n_spots)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

