Skip to content

Commit

Permalink
Add possibility to use proxy also for upstream, not just providers
Browse files Browse the repository at this point in the history
  • Loading branch information
p53 authored Apr 5, 2024
1 parent 77a991f commit 442a95a
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 14 deletions.
2 changes: 2 additions & 0 deletions pkg/keycloak/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ type Config struct {
OpenIDProviderRetryCount int `env:"OPENID_PROVIDER_RETRY_COUNT" json:"openid-provider-retry-count" usage:"number of retries for retrieving openid configuration" yaml:"openid-provider-retry-count"`
// OpenIDProviderHeaders
OpenIDProviderHeaders map[string]string `json:"openid-provider-headers" usage:"http headers sent to idp provider" yaml:"openid-provider-headers"`
// UpstreamProxy proxy for upstream communication
UpstreamProxy string `env:"UPSTREAM_PROXY" json:"upstream-proxy" usage:"proxy for communication with upstream" yaml:"upstream-proxy"`
// BaseURI is prepended to all the generated URIs
BaseURI string `env:"BASE_URI" json:"base-uri" usage:"common prefix for all URIs" yaml:"base-uri"`
// OAuthURI is the uri for the oauth endpoints for the proxy
Expand Down
8 changes: 7 additions & 1 deletion pkg/keycloak/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,6 @@ func (r *OauthProxy) createUpstreamProxy(upstream *url.URL) error {
// and for refreshed cookies (htts://github.com/louketo/louketo-proxy/pulls/456])
proxy.KeepDestinationHeaders = true
proxy.Logger = httplog.New(io.Discard, "", 0)
proxy.KeepDestinationHeaders = true
r.Upstream = proxy

// update the tls configuration of the reverse proxy
Expand All @@ -1269,8 +1268,15 @@ func (r *OauthProxy) createUpstreamProxy(upstream *url.URL) error {
return apperrors.ErrAssertionFailed
}

var upstreamProxyFunc func(*http.Request) (*url.URL, error)
if r.Config.UpstreamProxy != "" {
upstreamProxyFunc = func(req *http.Request) (*url.URL, error) {
return url.Parse(r.Config.UpstreamProxy)
}
}
upstreamProxy.Tr = &http.Transport{
Dial: dialer,
Proxy: upstreamProxyFunc,
DisableKeepAlives: !r.Config.UpstreamKeepalives,
ExpectContinueTimeout: r.Config.UpstreamExpectContinueTimeout,
ResponseHeaderTimeout: r.Config.UpstreamResponseHeaderTimeout,
Expand Down
4 changes: 4 additions & 0 deletions pkg/testsuite/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ const (
FakeCertFilePrefix = "/gateadmin_crt_"
FakePrivFilePrefix = "/gateadmin_priv_"
FakeCaFilePrefix = "/gateadmin_ca_"
TestProxyHeaderKey = "X-GoProxy"
TestProxyHeaderVal = "yxorPoG-X"
)

var ErrCreateFakeProxy = errors.New("failed to create fake proxy service")
var ErrRunHTTPServer = errors.New("failed to run http server")
var ErrShutHTTPServer = errors.New("failed to shutdown http server")
20 changes: 20 additions & 0 deletions pkg/testsuite/fake_upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package testsuite
import (
"encoding/json"
"io"
"net"
"net/http"
"strings"

"github.com/elazarl/goproxy"
"golang.org/x/net/websocket"
)

Expand Down Expand Up @@ -69,3 +71,21 @@ func (f *FakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Reque
_, _ = wrt.Write(content)
}
}

func createTestProxy() (*http.Server, net.Listener, error) {
proxy := goproxy.NewProxyHttpServer()
proxy.OnRequest().DoFunc(
func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) {
r.Header.Set(TestProxyHeaderKey, TestProxyHeaderVal)
return r, nil
},
)
proxyHTTPServer := &http.Server{
Handler: proxy,
}
ln, err := net.Listen("tcp", randomLocalHost)
if err != nil {
return nil, nil, err
}
return proxyHTTPServer, ln, nil
}
215 changes: 202 additions & 13 deletions pkg/testsuite/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ limitations under the License.
package testsuite

import (
"context"
"crypto/tls"
"errors"
"fmt"
"math/rand"
"net/http"
Expand Down Expand Up @@ -211,7 +213,37 @@ func TestAuthTokenHeader(t *testing.T) {
}

func TestForwardingProxy(t *testing.T) {
server := httptest.NewServer(&FakeUpstreamService{})
errChan := make(chan error)
upProxy, lstn, err := createTestProxy()
upstreamProxyURL := fmt.Sprintf("http://%s", lstn.Addr().String())
if err != nil {
t.Fatal(err)
}

go func() {
errChan <- upProxy.Serve(lstn)
}()

fakeUpstream := httptest.NewServer(&FakeUpstreamService{})
upstreamConfig := newFakeKeycloakConfig()
upstreamConfig.EnableUma = true
upstreamConfig.NoRedirects = true
upstreamConfig.EnableDefaultDeny = true
upstreamConfig.ClientID = ValidUsername
upstreamConfig.ClientSecret = ValidPassword
upstreamConfig.PatRetryCount = 5
upstreamConfig.PatRetryInterval = 2 * time.Second
upstreamConfig.Upstream = fakeUpstream.URL
// in newFakeProxy we are creating fakeauth server so, we will
// have two different fakeauth servers for upstream and forwarding,
// so we need to skip issuer check, but responses will be same
// so it is ok for this testing
upstreamConfig.SkipAccessTokenIssuerCheck = true

upstreamProxy := newFakeProxy(
upstreamConfig,
&fakeAuthConfig{Expiration: 900 * time.Millisecond},
)

testCases := []struct {
Name string
Expand All @@ -232,7 +264,7 @@ func TestForwardingProxy(t *testing.T) {
},
ExecutionSettings: []fakeRequest{
{
URL: server.URL + FakeTestURL,
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
Expand All @@ -253,15 +285,15 @@ func TestForwardingProxy(t *testing.T) {
},
ExecutionSettings: []fakeRequest{
{
URL: server.URL + FakeTestURL,
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
OnResponse: delay,
ExpectedContentContains: "Bearer ey",
},
{
URL: server.URL + FakeTestURL,
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
Expand All @@ -282,11 +314,21 @@ func TestForwardingProxy(t *testing.T) {
},
ExecutionSettings: []fakeRequest{
{
URL: server.URL + FakeTestURL,
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
ExpectedContentContains: "Bearer ey",
Method: "POST",
FormValues: map[string]string{
"Name": "Whatever",
},
ExpectedContent: func(body string, testNum int) {
assert.Contains(t, body, FakeTestURL)
assert.Contains(t, body, "method")
assert.Contains(t, body, "Whatever")
assert.NotContains(t, body, TestProxyHeaderVal)
},
},
},
},
Expand All @@ -303,38 +345,89 @@ func TestForwardingProxy(t *testing.T) {
},
ExecutionSettings: []fakeRequest{
{
URL: server.URL + FakeTestURL,
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
OnResponse: delay,
ExpectedContentContains: "Bearer ey",
},
{
URL: server.URL + FakeTestURL,
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
ExpectedContentContains: "Bearer ey",
},
},
},
{
// forwardingProxy -> middleProxy -> our backend upstreamProxy
Name: "TestClientCredentialsGrantWithMiddleProxy",
ProxySettings: func(conf *config.Config) {
conf.EnableForwarding = true
conf.ForwardingDomains = []string{}
conf.ClientID = ValidUsername
conf.ClientSecret = ValidPassword
conf.ForwardingGrantType = configcore.GrantTypeClientCreds
conf.PatRetryCount = 5
conf.PatRetryInterval = 2 * time.Second
conf.UpstreamProxy = upstreamProxyURL
},
ExecutionSettings: []fakeRequest{
{
URL: upstreamProxy.getServiceURL() + FakeTestURL,
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
ExpectedContentContains: "Bearer ey",
Method: "POST",
FormValues: map[string]string{
"Name": "Whatever",
},
ExpectedContent: func(body string, testNum int) {
assert.Contains(t, body, FakeTestURL)
assert.Contains(t, body, "method")
assert.Contains(t, body, "Whatever")
assert.Contains(t, body, TestProxyHeaderVal)
},
},
},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(
testCase.Name,
func(t *testing.T) {
c := newFakeKeycloakConfig()
c.Upstream = server.URL
testCase.ProxySettings(c)
p := newFakeProxy(c, &fakeAuthConfig{Expiration: 900 * time.Millisecond})
forwardingConfig := newFakeKeycloakConfig()

testCase.ProxySettings(forwardingConfig)
forwardingProxy := newFakeProxy(
forwardingConfig,
&fakeAuthConfig{},
)

<-time.After(time.Duration(100) * time.Millisecond)
p.RunTests(t, testCase.ExecutionSettings)
forwardingProxy.RunTests(t, testCase.ExecutionSettings)
},
)
}

select {
case err = <-errChan:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Fatal(errors.Join(ErrRunHTTPServer, err))
}
default:
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = upProxy.Shutdown(ctx)
if err != nil {
t.Fatal(errors.Join(ErrShutHTTPServer, err))
}
}
}

func TestUmaForwardingProxy(t *testing.T) {
Expand Down Expand Up @@ -447,7 +540,6 @@ func TestUmaForwardingProxy(t *testing.T) {
testCase.Name,
func(t *testing.T) {
forwardingConfig := newFakeKeycloakConfig()
forwardingConfig.Upstream = upstreamProxy.getServiceURL()

testCase.ProxySettings(forwardingConfig)
forwardingProxy := newFakeProxy(
Expand Down Expand Up @@ -2055,3 +2147,100 @@ func TestCustomHTTPMethod(t *testing.T) {
)
}
}

func TestUpstreamProxy(t *testing.T) {
errChan := make(chan error)
upstream := httptest.NewServer(&FakeUpstreamService{})
upstreamProxy, lstn, err := createTestProxy()
upstreamProxyURL := fmt.Sprintf("http://%s", lstn.Addr().String())
if err != nil {
t.Fatal(err)
}

go func() {
errChan <- upstreamProxy.Serve(lstn)
}()

testCases := []struct {
Name string
ProxySettings func(c *config.Config)
ExecutionSettings []fakeRequest
}{
{
Name: "TestUpstreamProxy",
ProxySettings: func(c *config.Config) {
c.UpstreamProxy = upstreamProxyURL
c.Upstream = upstream.URL
},
ExecutionSettings: []fakeRequest{
{
URI: "/test",
Method: "POST",
FormValues: map[string]string{
"Name": "Whatever",
},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
ExpectedContentContains: "gzip",
ExpectedContent: func(body string, testNum int) {
assert.Contains(t, body, FakeTestURL)
assert.Contains(t, body, "method")
assert.Contains(t, body, "Whatever")
assert.Contains(t, body, TestProxyHeaderVal)
},
},
},
},
{
Name: "TestNoUpstreamProxy",
ProxySettings: func(c *config.Config) {
c.Upstream = upstream.URL
},
ExecutionSettings: []fakeRequest{
{
URI: FakeTestURL,
Method: "POST",
FormValues: map[string]string{
"Name": "Whatever",
},
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
ExpectedContentContains: "gzip",
ExpectedContent: func(body string, testNum int) {
assert.Contains(t, body, FakeTestURL)
assert.Contains(t, body, "method")
assert.Contains(t, body, "Whatever")
assert.NotContains(t, body, TestProxyHeaderVal)
},
},
},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(
testCase.Name,
func(t *testing.T) {
c := newFakeKeycloakConfig()
testCase.ProxySettings(c)
p := newFakeProxy(c, &fakeAuthConfig{})
p.RunTests(t, testCase.ExecutionSettings)
},
)
}

select {
case err = <-errChan:
if err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Fatal(errors.Join(ErrRunHTTPServer, err))
}
default:
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = upstreamProxy.Shutdown(ctx)
if err != nil {
t.Fatal(errors.Join(ErrShutHTTPServer, err))
}
}
}

0 comments on commit 442a95a

Please sign in to comment.