Source code for pjrpc.server.specs.extractors.pydantic

import inspect
from typing import Any, Dict, Iterable, List, Optional, Type

import pydantic as pd

from pjrpc.common import UNSET, MaybeSet
from pjrpc.common.exceptions import JsonRpcError
from pjrpc.common.typedefs import MethodType
from pjrpc.server.specs.extractors import BaseSchemaExtractor, Error, Schema


class PydanticSchemaExtractor(BaseSchemaExtractor):
    """
    Pydantic method specification extractor.

    Builds JSON schemas for method parameters, results and errors by creating
    throw-away pydantic models (``pd.create_model``) and rendering them with
    ``model_json_schema``.

    :param ref_template: template pydantic uses to render ``$ref`` links to nested
                         model definitions (``{model}`` is replaced with the model name)
    """

    def __init__(self, ref_template: str = '#/components/schemas/{model}'):
        self._ref_template = ref_template

    def extract_params_schema(self, method: MethodType, exclude: Iterable[str] = ()) -> Dict[str, Schema]:
        """
        Extract a schema for every parameter of `method`.

        Only positional-or-keyword and keyword-only parameters are described;
        positional-only and variadic (``*args`` / ``**kwargs``) parameters are skipped.
        Unannotated parameters are typed as ``Any``; parameters without a default are required.

        :param method: method whose signature is inspected
        :param exclude: parameter names to leave out of the schema
        :returns: mapping of parameter name to its schema
        """

        exclude = set(exclude)
        signature = inspect.signature(method)

        field_definitions: Dict[str, Any] = {}
        for param in signature.parameters.values():
            if param.name in exclude:
                continue

            if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY):
                field_definitions[param.name] = (
                    param.annotation if param.annotation is not inspect.Parameter.empty else Any,
                    # `...` tells pydantic the field has no default, i.e. it is required
                    param.default if param.default is not inspect.Parameter.empty else ...,
                )

        params_model = pd.create_model('RequestModel', **field_definitions)
        model_schema = params_model.model_json_schema(ref_template=self._ref_template)

        # computed once instead of scanning the `required` list for every parameter
        required_params = set(model_schema.get('required', []))
        definitions = model_schema.get('$defs', UNSET)

        parameters_schema: Dict[str, Schema] = {}
        for param_name, param_schema in model_schema.get('properties', {}).items():
            parameters_schema[param_name] = Schema(
                schema=param_schema,
                summary=param_schema.get('title', UNSET),
                description=param_schema.get('description', UNSET),
                deprecated=param_schema.get('deprecated', UNSET),
                required=param_name in required_params,
                definitions=definitions,
            )

        return parameters_schema

    def extract_result_schema(self, method: MethodType) -> Schema:
        """
        Extract the schema of the value returned by `method`.

        A missing return annotation is described as ``Any``; an explicit ``None``
        return annotation is described as ``Optional[None]``.

        :param method: method whose return annotation is inspected
        :returns: result schema
        """

        signature = inspect.signature(method)

        if signature.return_annotation is inspect.Parameter.empty:
            return_annotation = Any
        elif signature.return_annotation is None:
            return_annotation = Optional[None]
        else:
            return_annotation = signature.return_annotation

        result_model = pd.create_model('ResultModel', result=(return_annotation, ...))
        model_schema = result_model.model_json_schema(ref_template=self._ref_template)

        result_schema = model_schema['properties']['result']
        required = 'result' in model_schema.get('required', [])
        if not required:
            # OpenAPI `nullable` is a boolean keyword, not the string 'true'
            result_schema['nullable'] = True

        return Schema(
            schema=result_schema,
            summary=result_schema.get('title', UNSET),
            description=result_schema.get('description', UNSET),
            deprecated=result_schema.get('deprecated', UNSET),
            required=required,
            definitions=model_schema.get('$defs', UNSET),
        )

    def extract_errors_schema(
        self,
        method: MethodType,
        errors: Optional[Iterable[Type[JsonRpcError]]] = None,
    ) -> MaybeSet[List[Error]]:
        """
        Extract schemas for the errors a method may return.

        For each error class a pydantic model is built from its public (non ``_``-prefixed)
        annotations, including inherited ones, using class attribute values as field defaults.
        The ``data`` property of that model becomes the error data schema.

        :param method: method the errors belong to (not inspected by this extractor)
        :param errors: error classes to describe
        :returns: list of error schemas, or ``UNSET`` if `errors` is empty or not provided
        """

        if not errors:
            return UNSET

        errors_schema: List[Error] = []
        for error in errors:
            field_definitions: Dict[str, Any] = {}
            for field_name, annotation in self._get_annotations(error).items():
                if field_name.startswith('_'):
                    continue
                # a class attribute becomes the default; otherwise the field is required (`...`)
                field_definitions[field_name] = (annotation, getattr(error, field_name, ...))

            error_model = pd.create_model(error.message, **field_definitions)
            model_schema = error_model.model_json_schema(ref_template=self._ref_template)

            errors_schema.append(
                Error(
                    code=error.code,
                    message=error.message,
                    data=model_schema.get('properties', {}).get('data', UNSET),
                    data_required='data' in model_schema.get('required', []),
                    title=error.message,
                    description=inspect.cleandoc(error.__doc__) if error.__doc__ is not None else UNSET,
                    deprecated=model_schema.get('deprecated', UNSET),
                    definitions=model_schema.get('$defs', UNSET),
                ),
            )

        return errors_schema

    @staticmethod
    def _extract_field_schema(model_schema: Dict[str, Any], field_name: str) -> Dict[str, Any]:
        """
        Return the schema of `field_name`, resolving a top-level ``$ref`` to its definition.

        pydantic v2 stores nested model definitions under ``$defs`` keyed by model name,
        while ``$ref`` holds the rendered ref template (e.g. ``#/components/schemas/Model``),
        so the definition is looked up by the last path segment of the reference.

        :raises KeyError: if the field or the referenced definition is missing
        """

        field_schema = model_schema['properties'][field_name]
        ref = field_schema.get('$ref')
        if ref is not None:
            definition_name = ref.rsplit('/', 1)[-1]
            field_schema = model_schema['$defs'][definition_name]

        return field_schema

    @staticmethod
    def _get_annotations(cls: Type[Any]) -> Dict[str, Any]:
        """
        Collect the class-level annotations of `cls` and all of its bases.

        Classes are visited from the most generic (``object``) to the most specific (`cls`),
        so an annotation re-declared in a subclass overrides the one inherited from its parent.
        """

        annotations: Dict[str, Any] = {}
        for klass in reversed(cls.mro()):
            annotations.update(getattr(klass, '__annotations__', {}))

        return annotations