Skip to content

Commit 60e09cd

Browse files
committed
Dialect-based schema validators
1 parent 496bcee commit 60e09cd

File tree

16 files changed

+547
-465
lines changed

16 files changed

+547
-465
lines changed

openapi_core/casting/schemas/factories.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ def __init__(
1919

2020
def create(
2121
self,
22+
spec: SchemaPath,
2223
schema: SchemaPath,
2324
format_validators: Optional[FormatValidatorsDict] = None,
2425
extra_format_validators: Optional[FormatValidatorsDict] = None,
2526
) -> SchemaCaster:
2627
schema_validator = self.schema_validators_factory.create(
28+
spec,
2729
schema,
2830
format_validators=format_validators,
2931
extra_format_validators=extra_format_validators,

openapi_core/deserializing/media_types/deserializers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def get_deserializer_callable(
6666
class MediaTypeDeserializer:
6767
def __init__(
6868
self,
69+
spec: SchemaPath,
6970
style_deserializers_factory: StyleDeserializersFactory,
7071
media_types_deserializer: MediaTypesDeserializer,
7172
mimetype: str,
@@ -75,6 +76,7 @@ def __init__(
7576
encoding: Optional[SchemaPath] = None,
7677
**parameters: str,
7778
):
79+
self.spec = spec
7880
self.style_deserializers_factory = style_deserializers_factory
7981
self.media_types_deserializer = media_types_deserializer
8082
self.mimetype = mimetype
@@ -117,6 +119,7 @@ def evolve(
117119
schema_caster = self.schema_caster.evolve(schema)
118120

119121
return cls(
122+
self.spec,
120123
self.style_deserializers_factory,
121124
self.media_types_deserializer,
122125
mimetype=mimetype or self.mimetype,
@@ -221,7 +224,7 @@ def decode_property_style(
221224
prep_encoding, default_location="query"
222225
)
223226
prop_deserializer = self.style_deserializers_factory.create(
224-
prop_style, prop_explode, prop_schema, name=prop_name
227+
self.spec, prop_schema, prop_style, prop_explode, name=prop_name
225228
)
226229
return prop_deserializer.deserialize(location)
227230

openapi_core/deserializing/media_types/factories.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def from_schema_casters_factory(
5858

5959
def create(
6060
self,
61+
spec: SchemaPath,
6162
mimetype: str,
6263
schema: Optional[SchemaPath] = None,
6364
schema_validator: Optional[SchemaValidator] = None,
@@ -89,11 +90,13 @@ def create(
8990
):
9091
schema_caster = (
9192
self.style_deserializers_factory.schema_casters_factory.create(
93+
spec,
9294
schema
9395
)
9496
)
9597

9698
return MediaTypeDeserializer(
99+
spec,
97100
self.style_deserializers_factory,
98101
media_types_deserializer,
99102
mimetype,

openapi_core/deserializing/styles/deserializers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@ def __init__(
1717
style: str,
1818
explode: bool,
1919
name: str,
20-
schema: SchemaPath,
20+
schema_type: str,
2121
caster: SchemaCaster,
2222
deserializer_callable: Optional[DeserializerCallable] = None,
2323
):
2424
self.style = style
2525
self.explode = explode
2626
self.name = name
27-
self.schema = schema
28-
self.schema_type = (schema / "type").read_str("")
27+
self.schema_type = schema_type
2928
self.caster = caster
3029
self.deserializer_callable = deserializer_callable
3130

openapi_core/deserializing/styles/factories.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ def __init__(
2020

2121
def create(
2222
self,
23+
spec: SchemaPath,
24+
schema: SchemaPath,
2325
style: str,
2426
explode: bool,
25-
schema: SchemaPath,
2627
name: str,
2728
) -> StyleDeserializer:
2829
deserialize_callable = self.style_deserializers.get(style)
29-
caster = self.schema_casters_factory.create(schema)
30+
caster = self.schema_casters_factory.create(spec, schema)
31+
schema_type = (schema / "type").read_str("")
3032
return StyleDeserializer(
31-
style, explode, name, schema, caster, deserialize_callable
33+
style, explode, name, schema_type, caster, deserialize_callable
3234
)

openapi_core/unmarshalling/schemas/factories.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333

3434
def create(
3535
self,
36+
spec: SchemaPath,
3637
schema: SchemaPath,
3738
format_validators: Optional[FormatValidatorsDict] = None,
3839
format_unmarshallers: Optional[FormatUnmarshallersDict] = None,
@@ -51,6 +52,7 @@ def create(
5152
if extra_format_validators is None:
5253
extra_format_validators = {}
5354
schema_validator = self.schema_validators_factory.create(
55+
spec,
5456
schema,
5557
format_validators=format_validators,
5658
extra_format_validators=extra_format_validators,

openapi_core/unmarshalling/unmarshallers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090

9191
def _unmarshal_schema(self, schema: SchemaPath, value: Any) -> Any:
9292
unmarshaller = self.schema_unmarshallers_factory.create(
93+
self.spec,
9394
schema,
9495
format_validators=self.format_validators,
9596
extra_format_validators=self.extra_format_validators,
Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
from functools import partial
2-
3-
from lazy_object_proxy import Proxy
1+
from openapi_schema_validator import OAS31_BASE_DIALECT_ID
2+
from openapi_schema_validator import OAS32_BASE_DIALECT_ID
43
from openapi_schema_validator import OAS30ReadValidator
54
from openapi_schema_validator import OAS30WriteValidator
65
from openapi_schema_validator import OAS31Validator
76
from openapi_schema_validator import OAS32Validator
87

9-
from openapi_core.validation.schemas._validators import (
10-
build_forbid_unspecified_additional_properties_validator,
11-
)
8+
from openapi_core.validation.schemas.factories import DialectSchemaValidatorsFactory
129
from openapi_core.validation.schemas.factories import SchemaValidatorsFactory
1310

1411
__all__ = [
@@ -20,44 +17,22 @@
2017

2118
oas30_write_schema_validators_factory = SchemaValidatorsFactory(
2219
OAS30WriteValidator,
23-
Proxy(
24-
partial(
25-
build_forbid_unspecified_additional_properties_validator,
26-
OAS30WriteValidator,
27-
)
28-
),
2920
)
3021

3122
oas30_read_schema_validators_factory = SchemaValidatorsFactory(
3223
OAS30ReadValidator,
33-
Proxy(
34-
partial(
35-
build_forbid_unspecified_additional_properties_validator,
36-
OAS30ReadValidator,
37-
)
38-
),
3924
)
4025

41-
oas31_schema_validators_factory = SchemaValidatorsFactory(
26+
oas31_schema_validators_factory = DialectSchemaValidatorsFactory(
4227
OAS31Validator,
43-
Proxy(
44-
partial(
45-
build_forbid_unspecified_additional_properties_validator,
46-
OAS31Validator,
47-
)
48-
),
28+
OAS31_BASE_DIALECT_ID,
4929
# NOTE: Intentionally use OAS 3.0 format checker for OAS 3.1 to preserve
5030
# backward compatibility for `byte`/`binary` formats.
5131
# See https://github.com/python-openapi/openapi-core/issues/506
5232
format_checker=OAS30ReadValidator.FORMAT_CHECKER,
5333
)
5434

55-
oas32_schema_validators_factory = SchemaValidatorsFactory(
35+
oas32_schema_validators_factory = DialectSchemaValidatorsFactory(
5636
OAS32Validator,
57-
Proxy(
58-
partial(
59-
build_forbid_unspecified_additional_properties_validator,
60-
OAS32Validator,
61-
)
62-
),
37+
OAS32_BASE_DIALECT_ID,
6338
)

openapi_core/validation/schemas/factories.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,38 @@
11
from copy import deepcopy
2+
from typing import Any
23
from typing import Optional
34
from typing import cast
45

56
from jsonschema._format import FormatChecker
67
from jsonschema.protocols import Validator
8+
from jsonschema.validators import validator_for
79
from jsonschema_path import SchemaPath
810

911
from openapi_core.validation.schemas._validators import (
1012
build_enforce_properties_required_validator,
1113
)
14+
from openapi_core.validation.schemas._validators import (
15+
build_forbid_unspecified_additional_properties_validator,
16+
)
1217
from openapi_core.validation.schemas.datatypes import FormatValidatorsDict
1318
from openapi_core.validation.schemas.validators import SchemaValidator
1419

1520

1621
class SchemaValidatorsFactory:
1722
def __init__(
1823
self,
19-
schema_validator_class: type[Validator],
20-
strict_schema_validator_class: Optional[type[Validator]] = None,
24+
schema_validator_cls: type[Validator],
2125
format_checker: Optional[FormatChecker] = None,
2226
):
23-
self.schema_validator_class = schema_validator_class
24-
self.strict_schema_validator_class = strict_schema_validator_class
27+
self.schema_validator_cls = schema_validator_cls
2528
if format_checker is None:
26-
format_checker = self.schema_validator_class.FORMAT_CHECKER
29+
format_checker = self.schema_validator_cls.FORMAT_CHECKER
2730
assert format_checker is not None
2831
self.format_checker = format_checker
2932

33+
def get_validator_cls(self, spec: SchemaPath, schema: SchemaPath) -> type[Validator]:
34+
return self.schema_validator_cls
35+
3036
def get_format_checker(
3137
self,
3238
format_validators: Optional[FormatValidatorsDict] = None,
@@ -57,34 +63,82 @@ def _add_validators(
5763

5864
def create(
5965
self,
66+
spec: SchemaPath,
6067
schema: SchemaPath,
6168
format_validators: Optional[FormatValidatorsDict] = None,
6269
extra_format_validators: Optional[FormatValidatorsDict] = None,
6370
forbid_unspecified_additional_properties: bool = False,
6471
enforce_properties_required: bool = False,
6572
) -> SchemaValidator:
66-
validator_class: type[Validator] = self.schema_validator_class
67-
if forbid_unspecified_additional_properties:
68-
if self.strict_schema_validator_class is None:
69-
raise ValueError(
70-
"Strict additional properties validation is not supported "
71-
"by this factory."
72-
)
73-
validator_class = self.strict_schema_validator_class
74-
73+
validator_cls: type[Validator] = self.get_validator_cls(spec, schema)
7574
if enforce_properties_required:
76-
validator_class = build_enforce_properties_required_validator(
77-
validator_class
75+
validator_cls = build_enforce_properties_required_validator(
76+
validator_cls
77+
)
78+
if forbid_unspecified_additional_properties:
79+
validator_cls = build_forbid_unspecified_additional_properties_validator(
80+
validator_cls
7881
)
7982

8083
format_checker = self.get_format_checker(
8184
format_validators, extra_format_validators
8285
)
8386
with schema.resolve() as resolved:
84-
jsonschema_validator = validator_class(
87+
jsonschema_validator = validator_cls(
8588
resolved.contents,
8689
_resolver=resolved.resolver,
8790
format_checker=format_checker,
8891
)
8992

9093
return SchemaValidator(schema, jsonschema_validator)
94+
95+
96+
class DialectSchemaValidatorsFactory(SchemaValidatorsFactory):
97+
def __init__(
98+
self,
99+
schema_validator_cls: type[Validator],
100+
default_jsonschema_dialect_id: str,
101+
format_checker: Optional[FormatChecker] = None,
102+
):
103+
super().__init__(schema_validator_cls, format_checker)
104+
self.default_jsonschema_dialect_id = default_jsonschema_dialect_id
105+
106+
self._validator_classes_by_dialect: dict[
107+
str, type[Validator] | None
108+
] = {}
109+
110+
def get_validator_cls(self, spec: SchemaPath, schema: SchemaPath) -> type[Validator]:
111+
dialect_id = self._get_dialect_id(spec, schema)
112+
113+
validator_cls = self._get_validator_class_for_dialect(dialect_id)
114+
if validator_cls is None:
115+
raise ValueError(f"Unknown JSON Schema dialect: {dialect_id!r}")
116+
117+
return validator_cls
118+
119+
def _get_dialect_id(self, spec: SchemaPath, schema: SchemaPath,) -> str:
120+
try:
121+
return (schema / "$schema").read_str()
122+
except KeyError:
123+
return self._get_default_jsonschema_dialect_id(spec)
124+
125+
def _get_default_jsonschema_dialect_id(self, spec: SchemaPath) -> str:
126+
return (spec / "jsonSchemaDialect").read_str(
127+
default=self.default_jsonschema_dialect_id
128+
)
129+
130+
def _get_validator_class_for_dialect(
131+
self, dialect_id: str
132+
) -> type[Validator] | None:
133+
if dialect_id in self._validator_classes_by_dialect:
134+
return self._validator_classes_by_dialect[dialect_id]
135+
136+
validator_cls = cast(
137+
type[Validator] | None,
138+
validator_for(
139+
{"$schema": dialect_id},
140+
default=cast(Any, None),
141+
),
142+
)
143+
self._validator_classes_by_dialect[dialect_id] = validator_cls
144+
return validator_cls

openapi_core/validation/validators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,15 @@ def _deserialise_media_type(
143143
schema_validator = None
144144
if schema is not None:
145145
schema_validator = self.schema_validators_factory.create(
146+
self.spec,
146147
schema,
147148
format_validators=self.format_validators,
148149
extra_format_validators=self.extra_format_validators,
149150
forbid_unspecified_additional_properties=self.forbid_unspecified_additional_properties,
150151
enforce_properties_required=self.enforce_properties_required,
151152
)
152153
deserializer = self.media_type_deserializers_factory.create(
154+
self.spec,
153155
mimetype,
154156
schema=schema,
155157
schema_validator=schema_validator,
@@ -169,12 +171,13 @@ def _deserialise_style(
169171
style, explode = get_style_and_explode(param_or_header)
170172
schema = param_or_header / "schema"
171173
deserializer = self.style_deserializers_factory.create(
172-
style, explode, schema, name=name
174+
self.spec, schema, style, explode, name=name
173175
)
174176
return deserializer.deserialize(location)
175177

176178
def _validate_schema(self, schema: SchemaPath, value: Any) -> None:
177179
validator = self.schema_validators_factory.create(
180+
self.spec,
178181
schema,
179182
format_validators=self.format_validators,
180183
extra_format_validators=self.extra_format_validators,

0 commit comments

Comments
 (0)