import inspect
import re
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    get_type_hints,
)

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN

from spectree._pydantic import ValidationError
from spectree._types import ModelType
from spectree.plugins.base import BasePlugin, validate_response
from spectree.response import Response
class OpenAPI:
    """Falcon resource that returns the stored OpenAPI spec as response media."""

    def __init__(self, spec: Mapping[str, str]):
        # The spec mapping is kept as-is and handed out on every GET.
        self.spec = spec

    def on_get(self, _: Any, resp: Any):
        """Respond to GET by assigning the stored spec to ``resp.media``."""
        document = self.spec
        resp.media = document
class DocPage:
    """Falcon resource serving a pre-rendered HTML documentation page."""

    def __init__(self, html: str, **kwargs: Any):
        # The template is rendered once, at construction time, with the
        # given keyword arguments substituted via str.format.
        rendered = html.format(**kwargs)
        self.page = rendered

    def on_get(self, _: Any, resp: Any):
        """Respond to GET with the rendered page as ``text/html``."""
        resp.text = self.page
        resp.content_type = "text/html"
class OpenAPIAsgi(OpenAPI):
    """ASGI (coroutine) variant of the OpenAPI spec resource."""

    async def on_get(self, req: Any, resp: Any):
        # The parent handler is synchronous and returns None; just delegate.
        return super().on_get(req, resp)
class DocPageAsgi(DocPage):
    """ASGI (coroutine) variant of the HTML documentation page resource."""

    async def on_get(self, req: Any, resp: Any):
        # The parent handler is synchronous and returns None; just delegate.
        return super().on_get(req, resp)
# Class names of spectree's own spec/doc-page resources (sync and ASGI).
# Router nodes whose resource class name is listed here are not treated as
# user endpoints when routes are discovered.
DOC_CLASS: List[str] = [
    resource_cls.__name__
    for resource_cls in (DocPage, OpenAPI, DocPageAsgi, OpenAPIAsgi)
]

# Status line set on the response when the endpoint's output fails
# validation against its declared response model.
HTTP_500: str = "500 Internal Service Response Validation Error"
class FalconPlugin(BasePlugin):
    """Spectree plugin for Falcon WSGI applications.

    Mounts the OpenAPI spec and doc-page resources on the app, discovers the
    app's routes from its router, translates Falcon URI templates into OpenAPI
    paths and path parameters, and validates requests/responses around the
    decorated responders.
    """

    OPEN_API_ROUTE_CLASS = OpenAPI
    DOC_PAGE_ROUTE_CLASS = DocPage

    def __init__(self, spectree):
        super().__init__(spectree)
        # Falcon media errors (malformed body / unsupported content type) that
        # are turned into "no body" so the model validation reports them,
        # instead of letting the HTTPError propagate.
        self.FALCON_MEDIA_ERROR_CODE = (HTTP_400, HTTP_415)
        # NOTE from `falcon.routing.compiled.CompiledRouterNode`
        self.ESCAPE = r"[\.\(\)\[\]\?\$\*\+\^\|]"
        self.ESCAPE_TO = r"\\\g<0>"
        # Replacement for FALCON_FIELD_PATTERN matches: re-emit capture group 2
        # in braces (presumably the bare field name -- confirm against falcon's
        # `_FIELD_PATTERN`), dropping any converter spec.
        self.EXTRACT = r"{\2}"
        # NOTE this regex is copied from werkzeug.routing._converter_args_re and
        # modified to support only int args; `-?` allows negative bounds such
        # as `min=-5`, which would otherwise be mis-read as a positional `5`.
        self.INT_ARGS = re.compile(
            r"""
            ((?P<name>\w+)\s*=\s*)?
            (?P<value>-?\d+)\s*
            """,
            re.VERBOSE,
        )
        # Order in which positional `int` converter arguments are assigned.
        self.INT_ARGS_NAMES = ("num_digits", "min", "max")

    def register_route(self, app: Any):
        """Mount the spec resource and one doc-page resource per UI template."""
        self.app = app
        self.app.add_route(
            self.config.spec_url, self.OPEN_API_ROUTE_CLASS(self.spectree.spec)
        )
        for ui in self.config.page_templates:
            self.app.add_route(
                f"/{self.config.path}/{ui}",
                self.DOC_PAGE_ROUTE_CLASS(
                    self.config.page_templates[ui],
                    spec_url=self.config.spec_url,
                    spec_path=self.config.path,
                    **self.config.swagger_oauth2_config(),
                ),
            )

    def find_routes(self):
        """Return every router node that has a non-spectree resource attached.

        Walks the tree under the app router's private ``_roots`` recursively;
        nodes whose resource class name is in ``DOC_CLASS`` are skipped.
        """
        routes = []

        def find_node(node):
            if node.resource and node.resource.__class__.__name__ not in DOC_CLASS:
                routes.append(node)
            for child in node.children:
                find_node(child)

        for route in self.app._router._roots:
            find_node(route)
        return routes

    def parse_func(self, route: Any) -> Iterable[Tuple[str, Callable]]:
        """Return the ``(HTTP method, responder)`` pairs of a router node."""
        return route.method_map.items()

    def parse_path(self, route, path_parameter_descriptions):
        """Convert a Falcon URI template into an OpenAPI path and parameters.

        :param route: router node whose ``uri_template`` is converted.
        :param path_parameter_descriptions: optional mapping of field name to
            description; missing names get an empty description.
        :returns: ``(path, parameters)`` where *path* uses ``{name}``
            placeholders and *parameters* is a list of OpenAPI path-parameter
            objects, one per template field.
        """
        descriptions = path_parameter_descriptions or {}
        subs, parameters = [], []
        for segment in route.uri_template.strip("/").split("/"):
            # `finditer` returns a lazy iterator that is always truthy, so it
            # must be materialized for the static-segment check to work.
            matches = list(FALCON_FIELD_PATTERN.finditer(segment))
            if not matches:
                subs.append(segment)
                continue

            # NOTE(review): regex metacharacters in templated segments end up
            # escaped in the OpenAPI path (behavior kept for compatibility).
            escaped = re.sub(self.ESCAPE, self.ESCAPE_TO, segment)
            subs.append(FALCON_FIELD_PATTERN.sub(self.EXTRACT, escaped))
            for field in matches:
                variable, converter, argstr = (
                    field.group(name) for name in ("fname", "cname", "argstr")
                )
                parameters.append(
                    {
                        "name": variable,
                        "in": "path",
                        "required": True,
                        "schema": self._converter_schema(converter, argstr),
                        "description": descriptions.get(variable, ""),
                    }
                )
        return f'/{"/".join(subs)}', parameters

    def _converter_schema(self, converter, argstr):
        """Map a Falcon field converter (and its args) to an OpenAPI schema."""
        if converter == "int":
            return self._int_converter_schema(argstr or "")
        if converter == "uuid":
            return {"type": "string", "format": "uuid"}
        if converter == "dt":
            return {"type": "string", "format": "date-time"}
        # no converter specified or customized converters
        return {"type": "string"}

    def _int_converter_schema(self, argstr):
        """Build an integer schema from an ``int`` converter argument string."""
        arg_values = [None, None, None]
        for position, match in enumerate(self.INT_ARGS.finditer(argstr)):
            name = match.group("name")
            index = self.INT_ARGS_NAMES.index(name) if name else position
            arg_values[index] = match.group("value")
        num_digits, minimum, maximum = arg_values
        schema = {
            "type": "integer",
            # NOTE(review): Falcon's `num_digits` is a digit count, not a bit
            # width, so `int<N>` is likely not a meaningful OpenAPI format;
            # kept unchanged for compatibility.
            "format": f"int{num_digits}" if num_digits else "int32",
        }
        # JSON Schema requires numeric bounds, not the raw regex strings.
        if minimum is not None:
            schema["minimum"] = int(minimum)
        if maximum is not None:
            schema["maximum"] = int(maximum)
        return schema

    def request_validation(self, req, query, json, form, headers, cookies):
        """Parse the request parts with their models into ``req.context``.

        Each model that is given is fed the matching request data and the
        result is stored on ``req.context`` under the same name.

        :raises ValidationError: if any part fails its model. A Falcon media
            error listed in ``FALCON_MEDIA_ERROR_CODE`` is treated as a missing
            body, which the model then rejects.
        """
        if query:
            req.context.query = query.parse_obj(req.params)
        if headers:
            req.context.headers = headers.parse_obj(req.headers)
        if cookies:
            req.context.cookies = cookies.parse_obj(req.cookies)
        if json:
            try:
                media = req.media
            except HTTPError as err:
                if err.status not in self.FALCON_MEDIA_ERROR_CODE:
                    raise
                media = None
            req.context.json = json.parse_obj(media)
        if form:
            # Parsing errors may surface either from `get_media()` or while
            # iterating the parts, so both live inside the `try`.
            try:
                # TODO - possible to pass the BodyPart here?
                # req_form = {x.name: x for x in req.get_media()}
                req_form = {x.name: x.stream.read() for x in req.get_media()}
            except HTTPError as err:
                if err.status not in self.FALCON_MEDIA_ERROR_CODE:
                    raise
                req_form = None
            req.context.form = form.parse_obj(req_form)

    @staticmethod
    def _status_code(status: Any) -> int:
        """Return the numeric HTTP status code for a Falcon ``resp.status``.

        Falcon 3+ allows ``resp.status`` to be a status line (``str`` or
        ``bytes``, e.g. ``"200 OK"``), a plain ``int`` or an ``HTTPStatus``
        member, so it cannot always be sliced as a string.
        """
        if isinstance(status, int):  # also covers http.HTTPStatus (IntEnum)
            return int(status)
        if isinstance(status, bytes):
            status = status.decode("ascii")
        return int(str(status)[:3])

    def validate(
        self,
        func: Callable,
        query: Optional[ModelType],
        json: Optional[ModelType],
        form: Optional[ModelType],
        headers: Optional[ModelType],
        cookies: Optional[ModelType],
        resp: Optional[Response],
        before: Callable,
        after: Callable,
        validation_error_status: int,
        skip_validation: bool,
        *args: Any,
        **kwargs: Any,
    ):
        """Validate the request, call the responder, then validate its output.

        On request validation failure the response is set to
        ``validation_error_status`` with the error list, ``before`` is called
        and the responder is skipped (returns ``None``). After the responder
        runs, its ``resp.media`` is checked against the model for the response
        status unless the body was set via ``text``/``data``, validation is
        skipped, or no response spec is given; a failure sets ``HTTP_500``.

        :returns: whatever the responder returned, or ``None`` when skipped.
        """
        # falcon endpoint method arguments: (self, req, resp)
        _self, _req, _resp = args[:3]
        req_validation_error, resp_validation_error = None, None
        try:
            self.request_validation(_req, query, json, form, headers, cookies)
            if self.config.annotations:
                annotations = get_type_hints(func)
                for name in ("query", "json", "form", "headers", "cookies"):
                    if annotations.get(name):
                        kwargs[name] = getattr(_req.context, name)
        except ValidationError as err:
            req_validation_error = err
            _resp.status = f"{validation_error_status} Validation Error"
            _resp.media = err.errors()

        before(_req, _resp, req_validation_error, _self)
        if req_validation_error:
            return None

        result = func(*args, **kwargs)

        if not self._data_set_manually(_resp) and not skip_validation and resp:
            status = self._status_code(_resp.status)
            try:
                response_validation_result = validate_response(
                    validation_model=resp.find_model(status),
                    response_payload=_resp.media,
                )
            except ValidationError as err:
                resp_validation_error = err
                _resp.status = HTTP_500
                _resp.media = err.errors()
            else:
                _resp.media = response_validation_result.payload

        after(_req, _resp, resp_validation_error, _self)
        # `falcon` doesn't use this return value. However, some users may have
        # their own processing logics that depend on this return value.
        return result

    @staticmethod
    def _data_set_manually(resp):
        """True when the body was set via ``text``/``data`` rather than ``media``."""
        return (resp.text is not None or resp.data is not None) and resp.media is None

    def bypass(self, func, method):
        """Return True for responders that should not be documented.

        These are ``functools.partial`` objects and plain functions, as opposed
        to bound resource methods (presumably Falcon's default responders for
        unsupported methods -- confirm).
        """
        return isinstance(func, partial) or inspect.isfunction(func)
class FalconAsgiPlugin(FalconPlugin):
    """Light wrapper around default Falcon plug-in to support Falcon 3.0 ASGI apps

    Uses the ASGI doc/spec resources and awaits request media and (when it is
    a coroutine function) the responder itself.
    """

    ASYNC = True
    OPEN_API_ROUTE_CLASS = OpenAPIAsgi
    DOC_PAGE_ROUTE_CLASS = DocPageAsgi

    async def request_validation(self, req, query, json, form, headers, cookies):
        """Parse the request parts with their models into ``req.context``.

        :raises ValidationError: if any part fails its model. A Falcon media
            error listed in ``FALCON_MEDIA_ERROR_CODE`` is treated as a missing
            body, which the model then rejects (same as the JSON path).
        """
        if query:
            req.context.query = query.parse_obj(req.params)
        if headers:
            req.context.headers = headers.parse_obj(req.headers)
        if cookies:
            req.context.cookies = cookies.parse_obj(req.cookies)
        if json:
            try:
                media = await req.get_media()
            except HTTPError as err:
                if err.status not in self.FALCON_MEDIA_ERROR_CODE:
                    raise
                media = None
            req.context.json = json.parse_obj(media)
        if form:
            # Previously a media error stored `None` without validating,
            # letting an unparseable body reach the responder; now the model
            # rejects it. Part iteration is inside the `try` because parsing
            # errors may surface there too.
            try:
                form_data = await req.get_media()
                res_data = {}
                async for part in form_data:
                    res_data[part.name] = part
                    await part.data  # TODO - how to avoid this?
            except HTTPError as err:
                if err.status not in self.FALCON_MEDIA_ERROR_CODE:
                    raise
                res_data = None
            req.context.form = form.parse_obj(res_data)

    async def validate(
        self,
        func: Callable,
        query: Optional[ModelType],
        json: Optional[ModelType],
        form: Optional[ModelType],
        headers: Optional[ModelType],
        cookies: Optional[ModelType],
        resp: Optional[Response],
        before: Callable,
        after: Callable,
        validation_error_status: int,
        skip_validation: bool,
        *args: Any,
        **kwargs: Any,
    ):
        """Async counterpart of ``FalconPlugin.validate``.

        Awaits request validation, and awaits the responder when it is a
        coroutine function; otherwise calls it synchronously.

        :returns: whatever the responder returned, or ``None`` when skipped.
        """
        # falcon endpoint method arguments: (self, req, resp)
        _self, _req, _resp = args[:3]
        req_validation_error, resp_validation_error = None, None
        try:
            await self.request_validation(_req, query, json, form, headers, cookies)
            if self.config.annotations:
                annotations = get_type_hints(func)
                for name in ("query", "json", "form", "headers", "cookies"):
                    if annotations.get(name):
                        kwargs[name] = getattr(_req.context, name)
        except ValidationError as err:
            req_validation_error = err
            _resp.status = f"{validation_error_status} Validation Error"
            _resp.media = err.errors()

        before(_req, _resp, req_validation_error, _self)
        if req_validation_error:
            return None

        if inspect.iscoroutinefunction(func):
            result = await func(*args, **kwargs)
        else:
            result = func(*args, **kwargs)

        if not self._data_set_manually(_resp) and not skip_validation and resp:
            # Falcon 3+ accepts str/bytes status lines, ints and HTTPStatus
            # members for `resp.status`; normalize to the numeric code instead
            # of assuming a sliceable string.
            raw_status = _resp.status
            if isinstance(raw_status, bytes):
                raw_status = raw_status.decode("ascii")
            status = (
                int(raw_status)
                if isinstance(raw_status, int)
                else int(str(raw_status)[:3])
            )
            try:
                response_validation_result = validate_response(
                    validation_model=resp.find_model(status),
                    response_payload=_resp.media,
                )
            except ValidationError as err:
                resp_validation_error = err
                _resp.status = HTTP_500
                _resp.media = err.errors()
            else:
                _resp.media = response_validation_result.payload

        after(_req, _resp, resp_validation_error, _self)
        return result