Skip to content

Commit

Permalink
fix: warehouse endpoints through gateway using reverse proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed Jun 9, 2023
1 parent 0b5935b commit 4431a41
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 72 deletions.
27 changes: 9 additions & 18 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,21 +901,6 @@ func (gateway *HandleT) beaconBatchHandler(w http.ResponseWriter, r *http.Reques
gateway.beaconHandler(w, r, "batch")
}

func warehouseHandler(w http.ResponseWriter, r *http.Request) {
origin, err := url.Parse(misc.GetWarehouseURL())
if err != nil {
http.Error(w, err.Error(), 404)
return
}
// gateway.logger.LogRequest(r)
director := func(req *http.Request) {
req.URL.Scheme = "http"
req.URL.Host = origin.Host
}
proxy := &httputil.ReverseProxy{Director: director}
proxy.ServeHTTP(w, r)
}

// ProcessRequest throws a webRequest into the queue and waits for the response before returning
func (*RegularRequestHandler) ProcessRequest(gateway *HandleT, w *http.ResponseWriter, r *http.Request, reqType string, payload []byte, writeKey string) string {
done := make(chan string, 1)
Expand Down Expand Up @@ -1272,11 +1257,17 @@ func (gateway *HandleT) StartWebHandler(ctx context.Context) error {
r.Post("/screen", gateway.webScreenHandler)
r.Post("/track", gateway.webTrackHandler)
r.Post("/webhook", gateway.webhookHandler.RequestHandler)
r.Post("/warehouse", warehouseHandler)
r.Post("/warehouse/pending-events", gateway.whProxy.ServeHTTP)

r.Get("/webhook", gateway.webhookHandler.RequestHandler)
r.Get("/warehouse", warehouseHandler)

r.Route("/warehouse", func(r chi.Router) {
r.Post("/pending-events", gateway.whProxy.ServeHTTP)
r.Post("/trigger-upload", gateway.whProxy.ServeHTTP)
r.Post("/jobs", gateway.whProxy.ServeHTTP)
r.Post("/fetch-tables", gateway.whProxy.ServeHTTP)

r.Get("/jobs/status", gateway.whProxy.ServeHTTP)
})
})

srvMux.Get("/health", WithContentType("application/json; charset=utf-8", app.LivenessHandler(gateway.jobsDB)))
Expand Down
93 changes: 39 additions & 54 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"strconv"
"testing"
"time"

kitHelper "github.com/rudderlabs/rudder-go-kit/testhelper"

"github.com/golang/mock/gomock"
"github.com/google/uuid"
. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -341,16 +345,26 @@ var _ = Describe("Gateway", func() {
gateway *HandleT
statsStore *memstats.Store
whServer *httptest.Server
serverURL string
)

BeforeEach(func() {
whServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "OK")
}))
url, err := url.Parse(whServer.URL)
WHURL := whServer.URL
parsedURL, err := url.Parse(WHURL)
Expect(err).To(BeNil())
whPort := parsedURL.Port()
os.Setenv("RSERVER_WAREHOUSE_WEB_PORT", whPort)

serverPort, err := kitHelper.GetFreePort()
Expect(err).To(BeNil())
config.Set("Warehouse.webPort", url.Port())
serverURL = fmt.Sprintf("http://localhost:%d", serverPort)
os.Setenv("RSERVER_GATEWAY_WEB_PORT", strconv.Itoa(serverPort))

loadConfig()

gateway = &HandleT{}
err = gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService(), sourcedebugger.NewNoOpService())
Expand All @@ -376,11 +390,10 @@ var _ = Describe("Gateway", func() {

return validDataWithProperty
}
verifyEndpoing := func(endpoints []string, method string) {
verifyEndpoint := func(endpoints []string, method string) {
client := &http.Client{}
baseURL := "http://localhost:8080"
for _, ep := range endpoints {
url := baseURL + ep
url := fmt.Sprintf("%s%s", serverURL, ep)
var req *http.Request
var err error
if ep == "/beacon/v1/batch" {
Expand Down Expand Up @@ -409,17 +422,24 @@ var _ = Describe("Gateway", func() {
})
c.mockBackendConfig.EXPECT().WaitForConfig(gomock.Any()).AnyTimes()
var err error
wait := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
err = gateway.StartWebHandler(ctx)
Expect(err).To(BeNil())
close(wait)
}()
getEndpoing, postEndpoints, deleteEndpoints := getEndpointMethodMap()

verifyEndpoing(getEndpoing, http.MethodGet)
verifyEndpoing(postEndpoints, http.MethodPost)
verifyEndpoing(deleteEndpoints, http.MethodDelete)
Eventually(func() bool {
resp, _ := http.Get(fmt.Sprintf("%s/version", serverURL))
return resp.StatusCode == http.StatusOK
}, time.Second*10, time.Second).Should(BeTrue())

getEndpoint, postEndpoints, deleteEndpoints := getEndpointMethodMap()
verifyEndpoint(getEndpoint, http.MethodGet)
verifyEndpoint(postEndpoints, http.MethodPost)
verifyEndpoint(deleteEndpoints, http.MethodDelete)
cancel()
<-wait
})
})

Expand Down Expand Up @@ -976,43 +996,6 @@ var _ = Describe("Gateway", func() {
})
})

Context("Warehouse proxy", func() {
DescribeTable("forwarding requests to warehouse with different response codes",
func(url string, code int, payload string) {
gateway := &HandleT{}
whMock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.URL.String()).To(Equal(url))
Expect(r.Body)
Expect(r.Body).To(Not(BeNil()))
defer func() { _ = r.Body.Close() }()
reqBody, err := io.ReadAll(r.Body)
Expect(err).To(BeNil())
Expect(string(reqBody)).To(Equal(payload))
w.WriteHeader(code)
}))
GinkgoT().Setenv("WAREHOUSE_URL", whMock.URL)
GinkgoT().Setenv("RSERVER_WAREHOUSE_MODE", config.OffMode)
err := gateway.Setup(context.Background(), c.mockApp, c.mockBackendConfig, c.mockJobsDB, nil, c.mockVersionHandler, rsources.NewNoOpService(), sourcedebugger.NewNoOpService())
Expect(err).To(BeNil())

defer func() {
err := gateway.Shutdown()
Expect(err).To(BeNil())
whMock.Close()
}()

req := httptest.NewRequest("POST", "http://rudder-server"+url, bytes.NewBufferString(payload))
w := httptest.NewRecorder()
gateway.whProxy.ServeHTTP(w, req)
resp := w.Result()
Expect(resp.StatusCode).To(Equal(code))
},
Entry("successful request", "/v1/warehouse/pending-events", http.StatusOK, `{"source_id": "1", "task_run_id":"2"}`),
Entry("failed request", "/v1/warehouse/pending-events", http.StatusBadRequest, `{"source_id": "3", "task_run_id":"4"}`),
Entry("request with query parameters", "/v1/warehouse/pending-events?triggerUpload=true", http.StatusOK, `{"source_id": "5", "task_run_id":"6"}`),
)
})

Context("jobDataFromRequest", func() {
var (
gateway *HandleT
Expand Down Expand Up @@ -1150,19 +1133,19 @@ func expectHandlerResponse(handler http.HandlerFunc, req *http.Request, response
}

// return all endpoints as key and method as value
func getEndpointMethodMap() (getEndpoints, postEndpoints, deleteEndpoints []string) {
getEndpoints = []string{
func endpointsToVerify() ([]string, []string, []string) {
getEndpoints := []string{
"/version",
"/robots.txt",
"/pixel/v1/track",
"/pixel/v1/page",
"/v1/warehouse",
"/v1/webhook",
"/v1/job-status/123",
"/v1/job-status/123/failed-records",
"/v1/warehouse/jobs/status",
}

postEndpoints = []string{
postEndpoints := []string{
"/v1/batch",
"/v1/identify",
"/v1/track",
Expand All @@ -1174,16 +1157,18 @@ func getEndpointMethodMap() (getEndpoints, postEndpoints, deleteEndpoints []stri
"/v1/import",
"/v1/audiencelist",
"/v1/webhook",
"/v1/warehouse",
"/beacon/v1/batch",
"/internal/v1/extract",
"/v1/warehouse/pending-events",
"/v1/warehouse/trigger-upload",
"/v1/warehouse/jobs",
"/v1/warehouse/fetch-tables",
}

deleteEndpoints = []string{
deleteEndpoints := []string{
"/v1/job-status/1234",
}
return
return getEndpoints, postEndpoints, deleteEndpoints
}

func allHandlers(gateway *HandleT) map[string]http.HandlerFunc {
Expand Down

0 comments on commit 4431a41

Please sign in to comment.