diff --git a/cmd/api/src/api/middleware/middleware.go b/cmd/api/src/api/middleware/middleware.go index 2d59afb4a..1b33a910e 100644 --- a/cmd/api/src/api/middleware/middleware.go +++ b/cmd/api/src/api/middleware/middleware.go @@ -130,10 +130,6 @@ func ContextMiddleware(next http.Handler) http.Handler { requestCtx, cancel := context.WithTimeout(request.Context(), requestedWaitDuration.Value) defer cancel() // Insert the bh context - var ipAddress string - if ipAddress, err = parseUserIP(request); err != nil { - log.Errorf("requestIP not set: %v", err) - } requestCtx = ctx.Set(requestCtx, &ctx.Context{ StartTime: startTime, @@ -143,7 +139,7 @@ func ContextMiddleware(next http.Handler) http.Handler { Scheme: getScheme(request), Host: request.Host, }, - RequestIP: ipAddress, + RequestIP: parseUserIP(request), }) // Route the request with the embedded context @@ -152,16 +148,23 @@ func ContextMiddleware(next http.Handler) http.Handler { }) } -func parseUserIP(r *http.Request) (string, error) { +func parseUserIP(r *http.Request) string { + res := "" if ipAddress := r.Header.Get("X-Forwarded-For"); ipAddress != "" { - return strings.Split(ipAddress, ",")[0], nil - } else if parsedUrl, err := url.Parse(r.RemoteAddr); err != nil { - return "", fmt.Errorf("error parsing IP address from RemoteAddr: %s", err) + res += "X-Forwarded-For: " + ipAddress + "; " + } else { + log.Errorf("No data found in X-Forwarded-For") + } + + if parsedUrl, err := url.Parse(r.RemoteAddr); err != nil { + log.Errorf("Error parsing IP address from RemoteAddr: %s", err) } else if hostName := parsedUrl.Hostname(); hostName == "" { - return "", fmt.Errorf("hostname not found in URL: %s", parsedUrl.String()) + log.Errorf("Hostname not found in URL: %s", parsedUrl.String()) } else { - return parsedUrl.Hostname(), nil + res += "Remote Address: " + parsedUrl.Hostname() } + + return res } func ParseHeaderValues(values string) map[string]string { diff --git a/cmd/api/src/api/middleware/middleware_internal_test.go b/cmd/api/src/api/middleware/middleware_internal_test.go index 575752940..941ad84bd 100644 --- a/cmd/api/src/api/middleware/middleware_internal_test.go +++ b/cmd/api/src/api/middleware/middleware_internal_test.go @@ -80,35 +80,51 @@ func TestRequestWaitDuration(t *testing.T) { require.True(t, requestedWaitDuration.UserSet) } -func TestParseUserIP_XForwardedFor(t *testing.T) { +func TestParseUserIP_XForwardedForMissing(t *testing.T) { req, err := http.NewRequest("GET", "/teapot", nil) require.Nil(t, err) - ip1 := "192.168.1.1:8080" - ip2 := "192.168.1.2" - ip3 := "192.168.1.3" - req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ",")) + req.RemoteAddr = "http://www.google.com/0.0.0.0:3000" - ip, err := parseUserIP(req) - require.Nil(t, err) - require.Equal(t, ip1, ip) + res := parseUserIP(req) + require.NotContains(t, res, "X-Forwarded-For") + require.Contains(t, res, "Remote Address") } func TestParseUserIP_RemoteAddrError(t *testing.T) { req, err := http.NewRequest("GET", "/teapot", nil) require.Nil(t, err) + + ip1 := "192.168.1.1:8080" + ip2 := "192.168.1.2" + ip3 := "192.168.1.3" + req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ",")) req.RemoteAddr = "0.0.0.0:3000" - _, err = parseUserIP(req) - require.Contains(t, err.Error(), "error parsing IP address") + res := parseUserIP(req) + require.Contains(t, res, "X-Forwarded-For") + require.Contains(t, res, ip1) + require.Contains(t, res, ip2) + require.NotContains(t, res, "Remote Address") } -func TestParseUserIP_HostnameError(t *testing.T) { +func TestParseUserIP_Success(t *testing.T) { req, err := http.NewRequest("GET", "/teapot", nil) require.Nil(t, err) - _, err = parseUserIP(req) - require.Contains(t, err.Error(), "hostname") + ip1 := "192.168.1.1:8080" + ip2 := "192.168.1.2" + ip3 := "192.168.1.3" + req.Header.Set("X-Forwarded-For", strings.Join([]string{ip1, ip2, ip3}, ",")) + + req.RemoteAddr = "http://www.google.com/0.0.0.0:3000" + + res := parseUserIP(req) + require.Contains(t, res, "X-Forwarded-For") + require.Contains(t, res, ip1) + require.Contains(t, res, ip2) + require.Contains(t, res, ip3) + require.Contains(t, res, "Remote Address") } func TestParsePreferHeaderWait(t *testing.T) {