import io
import os
import gzip
import mimetypes
from typing import List, Union

from fastapi import status, FastAPI, File, Form, Request, UploadFile, APIRouter, HTTPException
from fastapi.responses import PlainTextResponse

import json
from fastapi.responses import StreamingResponse
from starlette.datastructures import Headers
from starlette.types import Send
from base64 import b64encode
from typing import Optional, Mapping, Iterator, Tuple
import secrets

# Application object, plus the router that the pipeline endpoint below is
# registered on (the router is attached to the app at the bottom of the file).
app = FastAPI()
router = APIRouter()


{% set default_response_type = optional_param_value_map.pop("response_type", None) %}
{% if default_response_type %}
def is_expected_response_type(media_type, response_type):
    """
    Report whether the pipeline's return type conflicts with the media type.

    Note that, despite the name, ``True`` means there is a *mismatch*:
    "application/json" was requested but the result is neither a dict nor a
    list, or "text/csv" was requested but the result is not a str. Any other
    media type is never reported as a conflict.
    """
    return bool(
        (media_type == "application/json" and response_type not in [dict, list])
        or (media_type == "text/csv" and response_type != str)
    )
{% endif %}
{{script}}

def get_validated_mimetype(file):
    """
    Return a file's mimetype, either via the file.content_type or the mimetypes lib if that's too
    generic. If the user has set UNSTRUCTURED_ALLOWED_MIMETYPES, validate against this list and
    return HTTP 400 for an invalid type.

    ``file`` only needs ``content_type`` and ``filename`` attributes; either may be None.
    The result is None when no type could be determined and no allow-list is configured.

    Raises:
        HTTPException: status 400 when an allow-list is configured and the detected
            type is not in it.
    """
    # Uploads are not guaranteed to carry a filename; normalise to "" so the
    # extension checks below cannot fail on None.
    filename = str(file.filename) if file.filename else ""

    content_type = file.content_type
    if not content_type or content_type == "application/octet-stream":
        content_type = mimetypes.guess_type(filename)[0]

        # Some filetypes missing for this library, just hardcode them for now.
        # Compared case-insensitively, like mimetypes.guess_type itself.
        if not content_type:
            lowered_name = filename.lower()
            if lowered_name.endswith(".md"):
                content_type = "text/markdown"
            elif lowered_name.endswith(".msg"):
                content_type = "message/rfc822"

    allowed_mimetypes_str = os.environ.get("UNSTRUCTURED_ALLOWED_MIMETYPES")
    if allowed_mimetypes_str is not None:
        # Tolerate whitespace around the commas ("text/plain, application/pdf").
        allowed_mimetypes = [entry.strip() for entry in allowed_mimetypes_str.split(",")]

        if content_type not in allowed_mimetypes:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Unable to process {file.filename}: "
                    f"File type {content_type} is not supported."
                )
            )

    return content_type

{% if accepts_text or accepts_file %}

class MultipartMixedResponse(StreamingResponse):
    """
    Streaming response that sends every item of the body iterator as one part
    of a ``multipart/mixed`` message.

    Each part carries its own Content-Length, a base64 Content-Transfer-Encoding
    and, when ``content_type`` was given, a Content-Type header. The payload of
    every part is base64-encoded, and the message is terminated with the closing
    boundary delimiter.
    """
    CRLF = b"\r\n"

    def __init__(self, *args, content_type: Optional[str] = None, **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to ``StreamingResponse``.
            content_type: media type advertised in the headers of each part;
                omitted from the part headers when None.
        """
        super().__init__(*args, **kwargs)
        self.content_type = content_type

    def init_headers(self, headers: Optional[Mapping[str, str]] = None) -> None:
        """Build the response headers and add the multipart content type with a fresh boundary."""
        super().init_headers(headers)
        # Random boundary so that it is unlikely to collide with part content.
        self.boundary_value = secrets.token_hex(16)
        content_type = f'multipart/mixed; boundary="{self.boundary_value}"'
        self.raw_headers.append((b"content-type", content_type.encode("latin-1")))

    @property
    def boundary(self):
        """Boundary delimiter line (without the trailing CRLF) as bytes."""
        return b'--' + self.boundary_value.encode()

    def _build_part_headers(self, headers: dict) -> bytes:
        """Serialise ``headers`` as CRLF-terminated ``Name: value`` lines."""
        return b"".join(
            f"{header}: {value}".encode() + self.CRLF
            for header, value in headers.items()
        )

    def build_part(self, chunk: bytes) -> bytes:
        """
        Wrap an already-encoded payload into one multipart part: boundary line,
        part headers, blank line, payload, CRLF.
        """
        part_headers = {
            'Content-Length': len(chunk),
            'Content-Transfer-Encoding': 'base64'
        }
        if self.content_type is not None:
            part_headers['Content-Type'] = self.content_type
        return (
            self.boundary + self.CRLF
            + self._build_part_headers(part_headers)
            + self.CRLF + chunk + self.CRLF
        )

    async def stream_response(self, send: Send) -> None:
        """Send the response start, one multipart part per chunk, then the closing delimiter."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        async for chunk in self.body_iterator:
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(self.charset)
            # Every part is declared as base64, so bytes chunks must be encoded
            # as well (previously only str chunks were).
            chunk = b64encode(chunk)
            await send({"type": "http.response.body", "body": self.build_part(chunk), "more_body": True})

        # A multipart body has to end with the closing delimiter "--boundary--".
        closing_delimiter = self.boundary + b"--" + self.CRLF
        await send({"type": "http.response.body", "body": closing_delimiter, "more_body": False})
{% endif %}


def ungz_file(file: UploadFile, gz_uncompressed_content_type=None) -> UploadFile:
    """
    Decompress a gzip upload and return it as a new in-memory UploadFile.

    Args:
        file: the gzip-compressed upload. Its underlying file object is read
            but not closed.
        gz_uncompressed_content_type: content type to assign to the
            decompressed file. When falsy, the type is guessed from the
            filename with its ".gz" suffix removed.

    Returns:
        An UploadFile holding the decompressed bytes. If no content type was
        supplied and none could be guessed, the content-type header is left
        out instead of being set to the literal string "None".
    """
    filename = str(file.filename) if file.filename else ""
    if filename.endswith(".gz"):
        filename = filename[:-3]

    # Close the gzip reader once the payload is read; this does not close the
    # caller's file object, which was passed in rather than opened here.
    with gzip.open(file.file) as gz_reader:
        uncompressed = gz_reader.read()

    content_type = gz_uncompressed_content_type or mimetypes.guess_type(filename)[0]
    header_map = {"content-type": content_type} if content_type else {}

    return UploadFile(
        file=io.BytesIO(uncompressed),
        size=len(uncompressed),
        filename=filename,
        headers=Headers(header_map),
    )



{% set default_response_schema = optional_param_value_map.pop("response_schema", None) %}
@router.post("{{ short_pipeline_path }}")
@router.post("{{pipeline_path}}")
def pipeline_1(request: Request,
gz_uncompressed_content_type: Optional[str] = Form(default=None),
{% if accepts_file %}files: Union[List[UploadFile], None] = File(default=None),{% endif %}
{% if accepts_text %}text_files: Union[List[UploadFile], None] = File(default=None),{% endif %}
{% if default_response_type %}output_format: Union[str, None] = Form(default=None),{% endif %}
{% if default_response_schema %}output_schema: str = Form(default=None),{% endif %}
{% for param in multi_string_param_names %}
{{param}}: List[str] = Form(default=[]),
{% endfor %}
):
    """
    POST endpoint that runs pipeline_api and returns its result(s).

    Which parameters and branches exist depends on how this template was
    rendered (file input, text input, response type, response schema).

    Uploads whose content type is "application/gzip" are decompressed first.
    When a response type is configured, the media type comes from the Accept
    header; if that header is missing, "*/*" or "multipart/mixed", the
    output_format form field or the configured default is used instead.

    Returns:
        With uploads: a MultipartMixedResponse when the Accept header is
        exactly "multipart/mixed"; otherwise the single pipeline result when
        exactly one upload was sent, or a generator of results for several.
        Without uploads (pipeline takes neither files nor text): the
        pipeline result itself.

    Raises:
        HTTPException: 400 when the required upload parameter is missing or
            empty; 406 when the requested media type is unsupported, conflicts
            with a multi-file response, or conflicts with the type of the
            pipeline's return value.
    """
{% if accepts_file %}
    # Replace gzip uploads in place with their decompressed counterparts.
    if files:
        for file_index in range(len(files)):
           if files[file_index].content_type == "application/gzip":
               files[file_index] = ungz_file(files[file_index], gz_uncompressed_content_type)
{% endif %}
{% if accepts_text %}
    # NOTE(review): unlike the files branch, gz_uncompressed_content_type is not
    # forwarded here, so the type is always guessed from the filename - confirm intended.
    if text_files:
        for file_index in range(len(text_files)):
           if text_files[file_index].content_type == "application/gzip":
               text_files[file_index] = ungz_file(text_files[file_index])
{% endif %}
    # Despite its name, content_type holds the request's Accept header.
    content_type = request.headers.get("Accept")
{% if default_response_type %}
    # Fall back to the form field / configured default when the Accept header
    # does not name a concrete media type.
    default_response_type = output_format or "{{default_response_type}}"
    if not content_type or content_type == "*/*" or content_type == "multipart/mixed":
        media_type = default_response_type
    else:
        media_type = content_type
{% endif %}
{% if default_response_schema %}
    default_response_schema = output_schema or "{{default_response_schema}}"
{% endif %}

{% if accepts_text and accepts_file %}
    has_text = isinstance(text_files, list) and len(text_files)
    has_files = isinstance(files, list) and len(files)
    if not has_text and not has_files:
        raise HTTPException(
            detail="One of the request parameters \"text_files\" or \"files\" is required.\n",
            status_code=status.HTTP_400_BAD_REQUEST
        )
    files_list: List = files or []
    text_files_list: List = text_files or []

    if len(files_list) or len(text_files_list):
        # More than one upload can only be answered as multipart or JSON.
        if all([content_type, content_type not in [
            "*/*",
            "multipart/mixed",
            "application/json",
        ], len(files_list) + len(text_files_list) > 1]):
            raise HTTPException(
                detail=(f"Conflict in media type {content_type}"
                            " with response type \"multipart/mixed\".\n"),
                status_code=status.HTTP_406_NOT_ACCEPTABLE,
            )
        # Yields one pipeline result per upload: text files first, then files.
        # NOTE(review): this is a generator, so its body (including the
        # HTTPException raises) only runs once the result is iterated - in the
        # multipart path that is presumably while the response is streamed; confirm.
        def response_generator(is_multipart):
            for text_file in text_files_list:
                text = text_file.file.read().decode("utf-8")

                response = pipeline_api(
                    text=text,
                    file=None,
                    {% for param in multi_string_param_names %}m_{{param}}={{param}}, {% endfor %}
                    {% if default_response_type %}response_type=media_type, {% endif %}
                    {% if default_response_schema %}response_schema=default_response_schema, {% endif %}
                )
                {% if default_response_type %}
                # is_expected_response_type returns True when the types do NOT match.
                if is_expected_response_type(media_type, type(response)):
                    raise HTTPException(
                        detail=(
                            f"Conflict in media type {media_type}"
                            f" with response type {type(response)}.\n"
                        ),
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )
                valid_response_types = ["application/json", "text/csv", "*/*", "multipart/mixed"]
                if media_type in valid_response_types:
                    if is_multipart:
                        # Multipart parts must be str/bytes, so serialise anything else as JSON.
                        if type(response) not in [str, bytes]:
                            response = json.dumps(response)
                    yield response
                else:
                    raise HTTPException(
                        detail=f"Unsupported media type {media_type}.\n",
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )
                {% else %}
                if is_multipart:
                    if type(response) not in [str, bytes]:
                        response = json.dumps(response)
                yield response
                {% endif %}

            for file in files_list:
                _file = file.file

                # Validates the upload's type against the optional allow-list (may raise 400).
                {% if "file_content_type" in optional_param_value_map %}
                file_content_type = get_validated_mimetype(file)
                {% else %}
                get_validated_mimetype(file)
                {% endif %}

                response = pipeline_api(
                    text=None,
                    file=_file,
                    {% for param in multi_string_param_names %}m_{{param}}={{param}}, {% endfor %}
                    {% if default_response_type %}response_type=media_type, {% endif %}
                    {% if default_response_schema %}response_schema=default_response_schema, {% endif %}
                    {% if "filename" in optional_param_value_map %}filename=file.filename, {%endif%}
                    {% if "file_content_type" in optional_param_value_map %}file_content_type=file_content_type, {%endif%}
                )
                {% if default_response_type %}
                # is_expected_response_type returns True when the types do NOT match.
                if is_expected_response_type(media_type, type(response)):
                    raise HTTPException(
                        detail=(
                            f"Conflict in media type {media_type}"
                            f" with response type {type(response)}.\n"
                        ),
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )
                valid_response_types = ["application/json", "text/csv", "*/*", "multipart/mixed"]
                if media_type in valid_response_types:
                    if is_multipart:
                        if type(response) not in [str, bytes]:
                            response = json.dumps(response)
                    yield response
                else:
                    raise HTTPException(
                        detail=f"Unsupported media type {media_type}.\n",
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )
                {% else %}
                if is_multipart:
                    if type(response) not in [str, bytes]:
                        response = json.dumps(response)
                yield response
                {% endif %}
        if content_type == "multipart/mixed":
            return MultipartMixedResponse(
                response_generator(is_multipart=True),
                {% if default_response_type %}content_type=media_type{% endif %}
            )
        else:
            # One upload: return its result directly; several: return the generator itself.
            return list(response_generator(is_multipart=False))[0] if len(files_list + text_files_list) == 1 else response_generator(is_multipart=False)
    else:
        # NOTE(review): appears unreachable - the has_text / has_files check above
        # already raises when both lists are empty.
        raise HTTPException(
            detail='Request parameters "files" or "text_files" are required.\n',
            status_code=status.HTTP_400_BAD_REQUEST,
        )
{% elif accepts_text or accepts_file %}
    {% if accepts_text%}{% set var_name = "text_files" %}{% elif accepts_file %}{% set var_name = "files" %}{% endif %}
    if isinstance({{var_name}}, list) and len({{var_name}}):
        # More than one upload can only be answered as multipart or JSON.
        if len({{var_name}}) > 1:
            if content_type and content_type not in ["*/*", "multipart/mixed", "application/json"]:
                raise HTTPException(
                    detail=(f"Conflict in media type {content_type}"
                                " with response type \"multipart/mixed\".\n"),
                    status_code=status.HTTP_406_NOT_ACCEPTABLE,
                )
        # Yields one pipeline result per upload; runs lazily, on iteration.
        def response_generator(is_multipart):
            for file in {{var_name}}:

                # Validates the upload's type against the optional allow-list (may raise 400).
                {% if "file_content_type" in optional_param_value_map %}
                file_content_type = get_validated_mimetype(file)
                {% else %}
                get_validated_mimetype(file)
                {% endif %}

                {% if accepts_text %}
                text = file.file.read().decode("utf-8")
                {% elif accepts_file %}
                _file = file.file
                {% endif %}

                response = pipeline_api(
                    {% if accepts_text %}text, {% elif accepts_file%}_file, {% endif %}
                    {% for param in multi_string_param_names %}m_{{param}}={{param}}, {% endfor %}
                    {% if default_response_type %}response_type=media_type, {% endif %}
                    {% if default_response_schema %}response_schema=default_response_schema, {% endif %}
                    {% if accepts_file %}
                        {% if "filename" in optional_param_value_map %}filename=file.filename, {%endif%}
                        {% if "file_content_type" in optional_param_value_map %}file_content_type=file_content_type, {%endif%}
                    {% endif %}
                )
                {% if default_response_type %}
                # is_expected_response_type returns True when the types do NOT match.
                if is_expected_response_type(media_type, type(response)):
                    raise HTTPException(
                        detail=(
                            f"Conflict in media type {media_type}"
                            f" with response type {type(response)}.\n"
                        ),
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )

                valid_response_types = ["application/json", "text/csv", "*/*", "multipart/mixed"]
                if media_type in valid_response_types:
                    if is_multipart:
                        # Multipart parts must be str/bytes, so serialise anything else as JSON.
                        if type(response) not in [str, bytes]:
                            response = json.dumps(response)
                    yield response
                else:
                    raise HTTPException(
                        detail=f"Unsupported media type {media_type}.\n",
                        status_code=status.HTTP_406_NOT_ACCEPTABLE,
                    )
                {% else %}
                if is_multipart:
                    if type(response) not in [str, bytes]:
                        response = json.dumps(response)
                yield response
                {% endif %}


        if content_type == "multipart/mixed":
            return MultipartMixedResponse(
                response_generator(is_multipart=True),
                {% if default_response_type %}content_type=media_type{% endif %})
        else:
            # One upload: return its result directly; several: return the generator itself.
            return list(response_generator(is_multipart=False))[0] if len({{var_name}}) == 1 else response_generator(is_multipart=False)
    else:
        raise HTTPException(
            detail="Request parameter \"{{var_name}}\" is required.\n",
            status_code=status.HTTP_400_BAD_REQUEST
        )
{% else %}
    # The pipeline takes neither files nor text: call it once with the form parameters only.
    # NOTE(review): the two calls below pass the multi-string parameters differently
    # (keyword with an m_ prefix vs. positional) - confirm both match pipeline_api.
    {% if default_response_type %}
    response = pipeline_api({% for param in multi_string_param_names %}m_{{param}}={{param}}, {% endfor %} response_type=media_type,)
    {% else %}
    response = pipeline_api({% for param in multi_string_param_names %}{{param}}, {% endfor %})
    {% endif %}

    {% if default_response_type %}
    # The media type is validated only after the pipeline has already run.
    valid_response_types = ["application/json", "text/csv", "*/*"]
    if media_type in valid_response_types:
        return response
    else:
        raise HTTPException(
            detail=f"Unsupported media type {media_type}.\n",
            status_code=status.HTTP_406_NOT_ACCEPTABLE,
        )
    {% else %}
    return response
    {% endif %}
{% endif %}


# Attach the router (and with it the pipeline endpoint) to the application.
app.include_router(router)
