Source code for spectree.plugins.flask_plugin

from typing import Any, Callable, Mapping, Optional, Tuple, get_type_hints

import flask
from flask import Blueprint, abort, current_app, jsonify, make_response, request
from werkzeug.routing import parse_converter_args

from spectree._pydantic import ValidationError
from spectree._types import ModelType
from spectree.plugins.base import BasePlugin, Context, validate_response
from spectree.response import Response
from spectree.utils import (
    flask_response_unpack,
    get_multidict_items,
    werkzeug_parse_rule,
)


class FlaskPlugin(BasePlugin):
    """Spectree plugin for Flask applications and blueprints.

    Discovers routes from the Flask URL map, translates Werkzeug rule syntax
    into OpenAPI paths and path parameters, validates requests and responses
    against the models passed to ``validate``, and registers the OpenAPI spec
    route plus one documentation-page route per configured UI template.
    """

    # Set by the ``app.record`` callback in ``register_route`` when the plugin
    # is attached to a Blueprint; stays ``None`` when attached to a Flask app.
    blueprint_state = None

    def find_routes(self):
        """Yield the URL rules that should appear in the generated spec.

        Skipped rules:

        * rules starting with ``/<config.path>`` (doc pages) or ``/static``;
        * rules whose endpoint name starts with ``openapi``;
        * when bound to a blueprint with a URL prefix, rules outside that
          prefix and the blueprint's own doc-page rules.
        """
        for rule in current_app.url_map.iter_rules():
            rule_path = str(rule)
            if any(
                rule_path.startswith(path)
                for path in (f"/{self.config.path}", "/static")
            ):
                continue
            if rule.endpoint.startswith("openapi"):
                continue
            if self.blueprint_state and self.blueprint_state.url_prefix:
                url_prefix = self.blueprint_state.url_prefix
                # Join the prefix the same way ``register_route`` builds the
                # blueprint spec URL (and Flask joins blueprint rules): strip the
                # prefix's trailing "/". Otherwise a prefix such as "/api/" gives
                # "/api//<path>", which never matches, and doc pages leak in.
                doc_prefix = "/".join([url_prefix.rstrip("/"), self.config.path])
                if not rule_path.startswith(url_prefix) or rule_path.startswith(
                    doc_prefix
                ):
                    continue
            yield rule

    def bypass(self, func, method):
        """Return whether ``method`` should be skipped (``HEAD`` and ``OPTIONS``)."""
        return method in ["HEAD", "OPTIONS"]

    def parse_func(self, route: Any):
        """Yield ``(method, view)`` pairs for the view(s) serving ``route``.

        For class-based views (callables exposing ``view_class``), each HTTP
        method of the rule is paired with the same-named lowercase attribute of
        the class, when it exists; otherwise every method is paired with the
        registered view function itself.
        """
        if self.blueprint_state:
            func = self.blueprint_state.app.view_functions[route.endpoint]
        else:
            func = current_app.view_functions[route.endpoint]

        # view class: https://flask.palletsprojects.com/en/1.1.x/views/
        view_cls = getattr(func, "view_class", None)
        if view_cls:
            for method in route.methods:
                view = getattr(view_cls, method.lower(), None)
                if view:
                    yield method, view
        else:
            for method in route.methods:
                yield method, func

    def parse_path(
        self,
        route: Optional[Mapping[str, str]],
        path_parameter_descriptions: Optional[Mapping[str, str]],
    ) -> Tuple[str, list]:
        """Convert a Werkzeug rule into an OpenAPI path and path parameters.

        Every ``<converter(args):name>`` segment is rewritten to ``{name}`` and
        produces one required ``in: path`` parameter whose schema is derived
        from the converter and its arguments. Converters not handled below get
        a ``None`` schema.

        :param route: the Werkzeug rule; only ``str(route)`` is used.
        :param path_parameter_descriptions: optional ``name -> description``
            mapping; names missing from it get an empty description.
        :returns: ``(openapi_path, parameters)``.
        """
        subs = []
        parameters = []

        for converter, arguments, variable in werkzeug_parse_rule(str(route)):
            if converter is None:
                # static part of the rule, copied verbatim
                subs.append(variable)
                continue
            subs.append(f"{{{variable}}}")

            args: tuple = ()
            kwargs: dict = {}
            if arguments:
                args, kwargs = parse_converter_args(arguments)

            schema = None
            if converter == "any":
                schema = {
                    "type": "string",
                    "enum": args,
                }
            elif converter == "int":
                schema = {
                    "type": "integer",
                    "format": "int32",
                }
                if "max" in kwargs:
                    schema["maximum"] = kwargs["max"]
                if "min" in kwargs:
                    schema["minimum"] = kwargs["min"]
            elif converter == "float":
                schema = {
                    "type": "number",
                    "format": "float",
                }
            elif converter == "uuid":
                schema = {
                    "type": "string",
                    "format": "uuid",
                }
            elif converter == "path":
                schema = {
                    "type": "string",
                    "format": "path",
                }
            elif converter in ("string", "default"):
                schema = {"type": "string"}
                # Werkzeug's string converter takes ``length``, ``minlength``
                # and ``maxlength``; translate them to the OpenAPI keywords.
                # ``length`` is not a schema keyword: it fixes both bounds.
                if "length" in kwargs:
                    schema["minLength"] = schema["maxLength"] = kwargs["length"]
                if "minlength" in kwargs:
                    schema["minLength"] = kwargs["minlength"]
                if "maxlength" in kwargs:
                    schema["maxLength"] = kwargs["maxlength"]

            description = (
                path_parameter_descriptions.get(variable, "")
                if path_parameter_descriptions
                else ""
            )
            parameters.append(
                {
                    "name": variable,
                    "in": "path",
                    "required": True,
                    "schema": schema,
                    "description": description,
                }
            )

        return "".join(subs), parameters

    def request_validation(self, request, query, json, form, headers, cookies):
        """Parse the current request into the given models on ``request.context``.

        Raw inputs:

        * ``req_query``: items of ``request.args`` (werkzeug ``ImmutableMultiDict``)
        * JSON body: ``request.get_json(silent=True)`` (dict, ``{}`` if absent)
        * ``req_headers``: ``request.headers`` (werkzeug ``EnvironHeaders``)
        * ``req_cookies``: items of ``request.cookies`` (werkzeug ``ImmutableMultiDict``)

        JSON and form bodies are only parsed for methods other than ``GET`` and
        ``DELETE``; the request mimetype decides which of the two applies. A
        model that is ``None`` (or not applicable) leaves its field ``None``.

        :raises ValidationError: when a model rejects its input.
        """
        req_query = get_multidict_items(request.args, query)
        req_headers = dict(iter(request.headers)) or {}
        req_cookies = get_multidict_items(request.cookies)
        has_data = request.method not in ("GET", "DELETE")
        # flask Request.mimetype is already normalized
        use_json = json and has_data and request.mimetype not in self.FORM_MIMETYPE
        use_form = form and has_data and request.mimetype in self.FORM_MIMETYPE

        request.context = Context(
            query.parse_obj(req_query) if query else None,
            json.parse_obj(request.get_json(silent=True) or {}) if use_json else None,
            form.parse_obj(self._fill_form(request)) if use_form else None,
            headers.parse_obj(req_headers) if headers else None,
            cookies.parse_obj(req_cookies) if cookies else None,
        )

    def _fill_form(self, request) -> dict:
        """Merge form fields and uploaded files into one dict (files win on clashes)."""
        req_data = get_multidict_items(request.form)
        req_data.update(get_multidict_items(request.files) if request.files else {})
        return req_data

    def validate(
        self,
        func: Callable,
        query: Optional[ModelType],
        json: Optional[ModelType],
        form: Optional[ModelType],
        headers: Optional[ModelType],
        cookies: Optional[ModelType],
        resp: Optional[Response],
        before: Callable,
        after: Callable,
        validation_error_status: int,
        skip_validation: bool,
        *args: Any,
        **kwargs: Any,
    ):
        """Validate the request, call the view, then validate its response.

        * A request validation failure becomes a JSON error response with
          status ``validation_error_status``; ``before`` is called with it and
          the request is aborted with that response.
        * When ``config.annotations`` is enabled, each of ``query``, ``json``,
          ``form``, ``headers`` and ``cookies`` that has a type annotation on
          the view is passed to it as a keyword argument.
        * Unless ``skip_validation`` is set or ``resp`` is empty, the view's
          payload is validated against ``resp.find_model(status)``; a failure
          becomes a JSON error response with status 500.
        * ``after`` receives the final response together with the response
          validation error (``None`` when validation passed or was skipped).
        """
        response, req_validation_error, resp_validation_error = None, None, None
        try:
            self.request_validation(request, query, json, form, headers, cookies)
            if self.config.annotations:
                annotations = get_type_hints(func)
                for name in ("query", "json", "form", "headers", "cookies"):
                    if annotations.get(name):
                        kwargs[name] = getattr(request.context, name)
        except ValidationError as err:
            req_validation_error = err
            response = make_response(jsonify(err.errors()), validation_error_status)

        before(request, response, req_validation_error, None)
        if req_validation_error:
            assert response  # make mypy happy
            abort(response)

        result = func(*args, **kwargs)

        payload, status, additional_headers = flask_response_unpack(result)
        if isinstance(payload, flask.Response):
            payload, resp_status, resp_headers = (
                payload.get_json(),
                payload.status_code,
                payload.headers,
            )
            # the inner flask.Response.status_code only takes effect when there is
            # no other status code
            if status == 200:
                status = resp_status
            additional_headers.update(resp_headers)

        if not skip_validation and resp:
            try:
                response_validation_result = validate_response(
                    validation_model=resp.find_model(status),
                    response_payload=payload,
                )
            except ValidationError as err:
                # Hand the failure to the ``after`` hook as well.
                resp_validation_error = err
                # ``jsonify`` like the request-validation path above; a bare
                # list body is only accepted by ``make_response`` in Flask >= 2.2.
                response = make_response(jsonify(err.errors()), 500)
            else:
                response = make_response(
                    (
                        response_validation_result.payload,
                        status,
                        additional_headers,
                    )
                )
        else:
            response = make_response(result)

        after(request, response, resp_validation_error, None)

        return response

    def register_route(self, app):
        """Register the OpenAPI spec route and one doc-page route per UI template.

        ``app`` may be a Flask app or a ``Blueprint``. For a blueprint, the doc
        pages prefix ``spec_url`` with the blueprint's URL prefix and the setup
        state is stored in ``blueprint_state`` through ``app.record``.
        """
        app.add_url_rule(
            rule=self.config.spec_url,
            endpoint=f"openapi_{self.config.path}",
            view_func=lambda: jsonify(self.spectree.spec),
        )

        if isinstance(app, Blueprint):

            def gen_doc_page(ui):
                # Runs inside the view, i.e. after the blueprint was registered
                # and ``blueprint_state`` was recorded.
                spec_url = self.config.spec_url
                if self.blueprint_state.url_prefix is not None:
                    spec_url = "/".join(
                        (
                            self.blueprint_state.url_prefix.rstrip("/"),
                            self.config.spec_url.lstrip("/"),
                        )
                    )

                return self.config.page_templates[ui].format(
                    spec_url=spec_url,
                    spec_path=self.config.path,
                    **self.config.swagger_oauth2_config(),
                )

            for ui in self.config.page_templates:
                app.add_url_rule(
                    rule=f"/{self.config.path}/{ui}/",
                    # Flask blueprints reject endpoint names containing "."
                    endpoint=f"openapi_{self.config.path}_{ui.replace('.', '_')}",
                    # ``ui=ui`` binds the current value (avoids late binding)
                    view_func=lambda ui=ui: gen_doc_page(ui),
                )

            app.record(lambda state: setattr(self, "blueprint_state", state))
        else:
            for ui in self.config.page_templates:
                app.add_url_rule(
                    rule=f"/{self.config.path}/{ui}/",
                    endpoint=f"openapi_{self.config.path}_{ui}",
                    # ``ui=ui`` binds the current value (avoids late binding)
                    view_func=lambda ui=ui: self.config.page_templates[ui].format(
                        spec_url=self.config.spec_url,
                        spec_path=self.config.path,
                        **self.config.swagger_oauth2_config(),
                    ),
                )