-
Notifications
You must be signed in to change notification settings - Fork 1
/
limit_operation_amount.go
77 lines (59 loc) · 1.91 KB
/
limit_operation_amount.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package graphql
import (
"context"
gql "github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/ast"
)
// Default limits for LimitOperationAmountMiddleware. They are compared by
// isAboveThreshold against the per-name counts of top-level fields.
const (
	// allOperationsDefaultThreshold is the default value compared against the
	// number of entries in the per-name count map of one operation.
	allOperationsDefaultThreshold = 10

	// sameOperationsDefaultThreshold is the default maximum count allowed for any
	// single top-level field name in one operation.
	sameOperationsDefaultThreshold = 2
)
// Compile-time check that the function returned by LimitOperationAmountMiddleware
// can be used as a gqlgen OperationMiddleware. The initializer is evaluated at
// package initialization, so the constructor is invoked once with a nil cfg.
var _ = gql.OperationMiddleware(LimitOperationAmountMiddleware(nil))
// LimitOperationAmountMiddleware returns a gqlgen operation middleware that rejects
// an operation whose top-level selection set exceeds the configured limits.
//
// The top-level selections are counted per field name by
// countTopLevelGraphQLOperations and checked by isAboveThreshold using:
//   - SameOperationLimit: per-name limit (default sameOperationsDefaultThreshold).
//   - TotalOperationLimit: overall limit (default allOperationsDefaultThreshold).
//
// cfg may be nil. Both config keys are injected as optional, so an unset key
// arrives as the zero value. A nil cfg, or a limit that is zero or negative,
// therefore falls back to the corresponding default instead of rejecting every
// request.
//
// A rejected operation does not reach next. The client receives exactly one
// response carrying the error "request not allowed".
func LimitOperationAmountMiddleware(
	cfg *struct {
		SameOperationLimit  int `inject:"config:graphql.security.limitOperationAmount.sameOperationLimit,optional"`
		TotalOperationLimit int `inject:"config:graphql.security.limitOperationAmount.totalOperationLimit,optional"`
	},
) func(ctx context.Context, next gql.OperationHandler) gql.ResponseHandler {
	// The limits do not change per request, so resolve them once.
	sameOperationLimit := sameOperationsDefaultThreshold
	totalOperationLimit := allOperationsDefaultThreshold
	if cfg != nil {
		if cfg.SameOperationLimit > 0 {
			sameOperationLimit = cfg.SameOperationLimit
		}
		if cfg.TotalOperationLimit > 0 {
			totalOperationLimit = cfg.TotalOperationLimit
		}
	}

	return func(ctx context.Context, next gql.OperationHandler) gql.ResponseHandler {
		opCtx := gql.GetOperationContext(ctx)
		if opCtx.Operation == nil {
			// Nothing to inspect; let the rest of the chain handle it.
			return next(ctx)
		}

		occurrences := countTopLevelGraphQLOperations(opCtx.Operation.SelectionSet)
		if isAboveThreshold(sameOperationLimit, totalOperationLimit, occurrences) {
			// A ResponseHandler is called repeatedly until it returns nil (this is how
			// subscription transports stream results). OneShot yields the error once
			// and then nil, instead of emitting the same error forever.
			return gql.OneShot(gql.ErrorResponse(ctx, "request not allowed"))
		}

		return next(ctx)
	}
}
// countTopLevelGraphQLOperations returns how often each field name is selected at
// the top level of the given selection set, keyed by field name.
//
// Fields are keyed by their schema name (Field.Name), not their alias, so
// `a: foo b: foo` counts as two selections of "foo".
//
// Selections inside inline fragments (`... on Query { foo }`) and named fragment
// spreads (`...QueryFields`) are counted too. Otherwise a client could bypass the
// limits by wrapping top-level fields in a fragment.
//
// Each named fragment is expanded at most once. This bounds the work by the size
// of the document and guards against fragment cycles. A spread whose Definition
// has not been resolved (nil) is skipped.
func countTopLevelGraphQLOperations(definition []ast.Selection) map[string]int {
	occurrences := make(map[string]int)
	expanded := make(map[string]bool)

	var walk func(selections []ast.Selection)
	walk = func(selections []ast.Selection) {
		for _, selection := range selections {
			switch sel := selection.(type) {
			case *ast.Field:
				occurrences[sel.Name]++
			case *ast.InlineFragment:
				walk(sel.SelectionSet)
			case *ast.FragmentSpread:
				if sel.Definition != nil && !expanded[sel.Name] {
					expanded[sel.Name] = true
					walk(sel.Definition.SelectionSet)
				}
			}
		}
	}

	walk(definition)
	return occurrences
}
// isAboveThreshold reports whether the per-name counts in operations breach a limit.
//
// It returns true when either condition holds:
//   - the number of distinct names, len(operations), exceeds totalOperationLimit;
//   - any single name's count exceeds sameOperationLimit.
//
// An empty or nil map never breaches unless a limit is negative.
//
// NOTE(review): totalOperationLimit bounds the number of distinct names, not the
// sum of the counts. Up to sameOperationLimit*totalOperationLimit selections can
// pass. Confirm this is the intended meaning of "total".
func isAboveThreshold(sameOperationLimit, totalOperationLimit int, operations map[string]int) bool {
	breached := len(operations) > totalOperationLimit
	for _, count := range operations {
		if breached {
			break
		}
		breached = count > sameOperationLimit
	}
	return breached
}