diff --git a/src/server/server_impl.go b/src/server/server_impl.go index 5fa5ed1c..9a1239a9 100644 --- a/src/server/server_impl.go +++ b/src/server/server_impl.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "expvar" "fmt" "io" @@ -53,15 +54,15 @@ func (server *server) AddDebugHttpEndpoint(path string, help string, handler htt server.debugListener.endpoints[path] = help } -// add an http/1 handler at the /json endpoint which allows this ratelimit service to work with +// create an http/1 handler at the /json endpoint which allows this ratelimit service to work with // clients that cannot use the gRPC interface (e.g. lua) // example usage from cURL with domain "dummy" and descriptor "perday": // echo '{"domain": "dummy", "descriptors": [{"entries": [{"key": "perday"}]}]}' | curl -vvvXPOST --data @/dev/stdin localhost:8080/json -func (server *server) AddJsonHandler(svc pb.RateLimitServiceServer) { +func NewJsonHandler(svc pb.RateLimitServiceServer) func(http.ResponseWriter, *http.Request) { // Default options include enums as strings and no identation. m := &jsonpb.Marshaler{} - handler := func(writer http.ResponseWriter, request *http.Request) { + return func(writer http.ResponseWriter, request *http.Request) { var req pb.RateLimitRequest if err := jsonpb.Unmarshal(request.Body, &req); err != nil { @@ -79,21 +80,26 @@ func (server *server) AddJsonHandler(svc pb.RateLimitServiceServer) { logger.Debugf("resp:%s", resp) - writer.Header().Set("Content-Type", "application/json") - if resp.OverallCode == pb.RateLimitResponse_OVER_LIMIT { - writer.WriteHeader(http.StatusTooManyRequests) - } else if resp.OverallCode == pb.RateLimitResponse_UNKNOWN { - writer.WriteHeader(http.StatusInternalServerError) - } - - err = m.Marshal(writer, resp) + buf := bytes.NewBuffer(nil) + err = m.Marshal(buf, resp) if err != nil { - writer.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(writer, "Internal error marshaling proto3 to json: %v", err) + logger.Errorf("error marshaling proto3 to json: %s", err.Error()) + http.Error(writer, "error marshaling proto3 to json: "+err.Error(), http.StatusInternalServerError) + return } + writer.Header().Set("Content-Type", "application/json") + if resp == nil || resp.OverallCode == pb.RateLimitResponse_UNKNOWN { + writer.WriteHeader(http.StatusInternalServerError) + } else if resp.OverallCode == pb.RateLimitResponse_OVER_LIMIT { + writer.WriteHeader(http.StatusTooManyRequests) + } + writer.Write(buf.Bytes()) } - server.router.HandleFunc("/json", handler) +} + +func (server *server) AddJsonHandler(svc pb.RateLimitServiceServer) { + server.router.HandleFunc("/json", NewJsonHandler(svc)) } func (server *server) GrpcServer() *grpc.Server { diff --git a/test/mocks/mocks.go b/test/mocks/mocks.go index 490d2097..703865af 100644 --- a/test/mocks/mocks.go +++ b/test/mocks/mocks.go @@ -5,3 +5,4 @@ package mocks //go:generate go run github.com/golang/mock/mockgen -destination ./config/config.go github.com/envoyproxy/ratelimit/src/config RateLimitConfig,RateLimitConfigLoader //go:generate go run github.com/golang/mock/mockgen -destination ./redis/redis.go github.com/envoyproxy/ratelimit/src/redis Client //go:generate go run github.com/golang/mock/mockgen -destination ./limiter/limiter.go github.com/envoyproxy/ratelimit/src/limiter RateLimitCache,TimeSource,JitterRandSource +//go:generate go run github.com/golang/mock/mockgen -destination ./rls/rls.go github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2 RateLimitServiceServer diff --git a/test/mocks/rls/rls.go b/test/mocks/rls/rls.go new file mode 100644 index 00000000..77cd49ae --- /dev/null +++ b/test/mocks/rls/rls.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2 (interfaces: RateLimitServiceServer) + +// Package mock_v2 is a generated GoMock package. +package mock_v2 + +import ( + context "context" + v2 "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockRateLimitServiceServer is a mock of RateLimitServiceServer interface +type MockRateLimitServiceServer struct { + ctrl *gomock.Controller + recorder *MockRateLimitServiceServerMockRecorder +} + +// MockRateLimitServiceServerMockRecorder is the mock recorder for MockRateLimitServiceServer +type MockRateLimitServiceServerMockRecorder struct { + mock *MockRateLimitServiceServer +} + +// NewMockRateLimitServiceServer creates a new mock instance +func NewMockRateLimitServiceServer(ctrl *gomock.Controller) *MockRateLimitServiceServer { + mock := &MockRateLimitServiceServer{ctrl: ctrl} + mock.recorder = &MockRateLimitServiceServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockRateLimitServiceServer) EXPECT() *MockRateLimitServiceServerMockRecorder { + return m.recorder +} + +// ShouldRateLimit mocks base method +func (m *MockRateLimitServiceServer) ShouldRateLimit(arg0 context.Context, arg1 *v2.RateLimitRequest) (*v2.RateLimitResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShouldRateLimit", arg0, arg1) + ret0, _ := ret[0].(*v2.RateLimitResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ShouldRateLimit indicates an expected call of ShouldRateLimit +func (mr *MockRateLimitServiceServerMockRecorder) ShouldRateLimit(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShouldRateLimit", reflect.TypeOf((*MockRateLimitServiceServer)(nil).ShouldRateLimit), arg0, arg1) +} diff --git a/test/server/server_impl_test.go b/test/server/server_impl_test.go new file mode 100644 index 00000000..de058ecb --- /dev/null +++ b/test/server/server_impl_test.go @@ -0,0 +1,86 @@ +package server_test + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v2" + + "github.com/envoyproxy/ratelimit/src/server" + mock_v2 "github.com/envoyproxy/ratelimit/test/mocks/rls" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func assertHttpResponse(t *testing.T, + handler http.HandlerFunc, + requestBody string, + expectedStatusCode int, + expectedContentType string, + expectedResponseBody string) { + + t.Helper() + assert := assert.New(t) + + req := httptest.NewRequest("METHOD_NOT_CHECKED", "/path_not_checked", strings.NewReader(requestBody)) + w := httptest.NewRecorder() + handler(w, req) + + resp := w.Result() + actualBody, _ := ioutil.ReadAll(resp.Body) + assert.Equal(expectedContentType, resp.Header.Get("Content-Type")) + assert.Equal(expectedStatusCode, resp.StatusCode) + assert.Equal(expectedResponseBody, string(actualBody)) +} + +func TestJsonHandler(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + rls := mock_v2.NewMockRateLimitServiceServer(controller) + handler := server.NewJsonHandler(rls) + + // Missing request body + assertHttpResponse(t, handler, "", 400, "text/plain; charset=utf-8", "EOF\n") + + // Request body is not valid json + assertHttpResponse(t, handler, "}", 400, "text/plain; charset=utf-8", "invalid character '}' looking for beginning of value\n") + + // Unknown response code + rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ + Domain: "foo", + }).Return(&pb.RateLimitResponse{}, nil) + assertHttpResponse(t, handler, `{"domain": "foo"}`, 500, "application/json", "{}") + + // ratelimit service error + rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ + Domain: "foo", + }).Return(nil, fmt.Errorf("some error")) + assertHttpResponse(t, handler, `{"domain": "foo"}`, 400, "text/plain; charset=utf-8", "some error\n") + + // json unmarshaling error + rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ + Domain: "foo", + }).Return(nil, nil) + assertHttpResponse(t, handler, `{"domain": "foo"}`, 500, "text/plain; charset=utf-8", "error marshaling proto3 to json: Marshal called with nil\n") + + // successful request, not rate limited + rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ + Domain: "foo", + }).Return(&pb.RateLimitResponse{ + OverallCode: pb.RateLimitResponse_OK, + }, nil) + assertHttpResponse(t, handler, `{"domain": "foo"}`, 200, "application/json", `{"overallCode":"OK"}`) + + // successful request, rate limited + rls.EXPECT().ShouldRateLimit(nil, &pb.RateLimitRequest{ + Domain: "foo", + }).Return(&pb.RateLimitResponse{ + OverallCode: pb.RateLimitResponse_OVER_LIMIT, + }, nil) + assertHttpResponse(t, handler, `{"domain": "foo"}`, 429, "application/json", `{"overallCode":"OVER_LIMIT"}`) +}