Skip to content

Commit

Permalink
Use correctly allowed http methods for schema generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jekel committed Sep 17, 2023
1 parent 933f7b7 commit 5ca4fc4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
14 changes: 13 additions & 1 deletion drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,19 @@ def _get_api_endpoints(self, patterns, prefix):
return api_endpoints

def get_allowed_methods(self, callback):
methods = super().get_allowed_methods(callback)
if hasattr(callback, 'actions'):
actions = set(callback.actions)
http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names]
else:
# pass to constructor allowed method names to get valid ones
kwargs = {}
http_method_names = callback.initkwargs.get('http_method_names', [])
if http_method_names:
kwargs['http_method_names'] = http_method_names

methods = callback.cls(**kwargs).allowed_methods

return [
method for method in methods
if method not in ('OPTIONS', 'HEAD', 'TRACE', 'CONNECT')
Expand Down
21 changes: 21 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,27 @@ def get(self, request):
assert '#/components/schemas/X' in get_response_schema(operation)['$ref']


def test_schema_contains_only_allowed_methods(no_warnings):
class XSerializer(serializers.Serializer):
integer = serializers.IntegerField()

class X(models.Model):
integer = models.IntegerField()

class XAPIView(generics.ListCreateAPIView):
model = X
serializer_class = XSerializer

urlpatterns = [
path('api/x/', XAPIView.as_view()),
path('api/x1/', XAPIView.as_view(http_method_names=['post'])),
]
schema = generate_schema(None, patterns=urlpatterns)
assert sorted(schema['paths']['/api/x/'].keys()) == sorted(['get', 'post'])
assert list(schema['paths']['/api/x1/'].keys()) == ['post']
assert 'X' in schema['components']['schemas']


def test_auto_schema_and_extend_parameters(no_warnings):
class CustomAutoSchema(AutoSchema):
def get_override_parameters(self):
Expand Down

0 comments on commit 5ca4fc4

Please sign in to comment.