Source code for spectree.plugins.starlette_plugin

import inspect
from collections import namedtuple
from functools import partial
from json import JSONDecodeError
from typing import Any, Callable, Optional, get_type_hints

from starlette.convertors import CONVERTOR_TYPES
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import compile_path

from spectree._pydantic import BaseModel, ValidationError, serialize_model_instance
from spectree._types import ModelType
from spectree.plugins.base import (
    BasePlugin,
    Context,
    RawResponsePayload,
    validate_response,
)
from spectree.response import Response

METHODS = {"get", "post", "put", "patch", "delete"}
Route = namedtuple("Route", ["path", "methods", "func"])


class _PydanticResponseModel(BaseModel):
    __root__: Any


def PydanticResponse(content):
    class _PydanticResponse(JSONResponse):
        def render(self, content) -> bytes:
            self._model_class = content.__class__
            return super().render(
                serialize_model_instance(_PydanticResponseModel(__root__=content))
            )

    return _PydanticResponse(content)


[docs] class StarlettePlugin(BasePlugin): ASYNC = True
[docs] def __init__(self, spectree): super().__init__(spectree) self.conv2type = {conv: typ for typ, conv in CONVERTOR_TYPES.items()}
[docs] def register_route(self, app): self.app = app self.app.add_route( self.config.spec_url, lambda request: JSONResponse(self.spectree.spec), ) for ui in self.config.page_templates: self.app.add_route( f"/{self.config.path}/{ui}", lambda request, ui=ui: HTMLResponse( self.config.page_templates[ui].format( spec_url=self.config.spec_url, spec_path=self.config.path, **self.config.swagger_oauth2_config(), ) ), )
async def request_validation(self, request, query, json, form, headers, cookies): has_data = request.method not in ("GET", "DELETE") content_type = request.headers.get("content-type", "").lower() use_json = json and has_data and content_type == "application/json" use_form = ( form and has_data and any([x in content_type for x in self.FORM_MIMETYPE]) ) request.context = Context( query.parse_obj(request.query_params) if query else None, json.parse_obj(await request.json() or {}) if use_json else None, form.parse_obj(await request.form() or {}) if use_form else None, headers.parse_obj(request.headers) if headers else None, cookies.parse_obj(request.cookies) if cookies else None, )
[docs] async def validate( self, func: Callable, query: Optional[ModelType], json: Optional[ModelType], form: Optional[ModelType], headers: Optional[ModelType], cookies: Optional[ModelType], resp: Optional[Response], before: Callable, after: Callable, validation_error_status: int, skip_validation: bool, *args: Any, **kwargs: Any, ): if isinstance(args[0], Request): instance, request = None, args[0] else: instance, request = args[:2] response = None req_validation_error = resp_validation_error = json_decode_error = None try: await self.request_validation(request, query, json, form, headers, cookies) if self.config.annotations: annotations = get_type_hints(func) for name in ("query", "json", "form", "headers", "cookies"): if annotations.get(name): kwargs[name] = getattr(request.context, name) except ValidationError as err: req_validation_error = err response = JSONResponse(err.errors(), validation_error_status) except JSONDecodeError as err: json_decode_error = err self.logger.info( "%s Validation Error", validation_error_status, extra={"spectree_json_decode_error": str(err)}, ) response = JSONResponse({"error_msg": str(err)}, validation_error_status) before(request, response, req_validation_error, instance) if req_validation_error or json_decode_error: return response if inspect.iscoroutinefunction(func): response = await func(*args, **kwargs) else: response = func(*args, **kwargs) if ( not skip_validation and resp and response and not ( isinstance(response, JSONResponse) and hasattr(response, "_model_class") and response._model_class == resp.find_model(response.status_code) ) ): try: validate_response( validation_model=resp.find_model(response.status_code), response_payload=RawResponsePayload(payload=response.body), ) except ValidationError as err: response = JSONResponse(err.errors(), 500) after(request, response, resp_validation_error, instance) return response
[docs] def find_routes(self): routes = [] def parse_route(app, prefix=""): # :class:`starlette.staticfiles.StaticFiles` doesn't have routes if not app.routes: return for route in app.routes: if route.path.startswith(f"/{self.config.path}"): continue func = route.app if isinstance(func, partial): try: func = func.__wrapped__ except AttributeError as err: self.logger.warning( "failed to get the wrapped func %s: %s", func, err ) if inspect.isclass(func): for method in METHODS: if getattr(func, method, None): routes.append( Route( f"{prefix}{route.path}", {method.upper()}, getattr(func, method), ) ) elif inspect.isfunction(func): routes.append( Route(f"{prefix}{route.path}", route.methods, route.endpoint) ) else: parse_route(route, prefix=f"{prefix}{route.path}") parse_route(self.app) return routes
[docs] def bypass(self, func, method): if method in ["HEAD", "OPTIONS"]: return True return False
[docs] def parse_func(self, route): for method in route.methods or ["GET"]: yield method, route.func
[docs] def parse_path(self, route, path_parameter_descriptions): _, path, variables = compile_path(route.path) parameters = [] for name, conv in variables.items(): schema = None typ = self.conv2type[conv] if typ == "int": schema = {"type": "integer", "format": "int32"} elif typ == "float": schema = { "type": "number", "format": "float", } elif typ == "path": schema = { "type": "string", "format": "path", } elif typ == "str": schema = {"type": "string"} description = ( path_parameter_descriptions.get(name, "") if path_parameter_descriptions else "" ) parameters.append( { "name": name, "in": "path", "required": True, "schema": schema, "description": description, } ) return path, parameters