Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HeaderFunc to allow modifying headers before every request #298

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions client/clientimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,79 @@ func TestConnectWithHeader(t *testing.T) {
})
}

func TestConnectWithHeaderFunc(t *testing.T) {
testClients(t, func(t *testing.T, client OpAMPClient) {
// Start a server.
srv := internal.StartMockServer(t)
var conn atomic.Value
srv.OnConnect = func(r *http.Request) {
authHdr := r.Header.Get("Authorization")
assert.EqualValues(t, "Bearer 12345678", authHdr)
userAgentHdr := r.Header.Get("User-Agent")
assert.EqualValues(t, "custom-agent/1.0", userAgentHdr)
conn.Store(true)
}

hf := func(header http.Header) http.Header {
header.Set("Authorization", "Bearer 12345678")
header.Set("User-Agent", "custom-agent/1.0")
return header
}

// Start a client.
settings := types.StartSettings{
OpAMPServerURL: "ws://" + srv.Endpoint,
HeaderFunc: hf,
}
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool { return conn.Load() != nil })

// Shutdown the Server and the client.
srv.Close()
_ = client.Stop(context.Background())
})
}

func TestConnectWithHeaderAndHeaderFunc(t *testing.T) {
testClients(t, func(t *testing.T, client OpAMPClient) {
// Start a server.
srv := internal.StartMockServer(t)
var conn atomic.Value
srv.OnConnect = func(r *http.Request) {
authHdr := r.Header.Get("Authorization")
assert.EqualValues(t, "Bearer 12345678", authHdr)
userAgentHdr := r.Header.Get("User-Agent")
assert.EqualValues(t, "custom-agent/1.0", userAgentHdr)
conn.Store(true)
}

baseHeader := http.Header{}
baseHeader.Set("User-Agent", "custom-agent/1.0")

hf := func(header http.Header) http.Header {
header.Set("Authorization", "Bearer 12345678")
return header
}

// Start a client.
settings := types.StartSettings{
OpAMPServerURL: "ws://" + srv.Endpoint,
Header: baseHeader,
HeaderFunc: hf,
}
startClient(t, settings, client)

// Wait for connection to be established.
eventually(t, func() bool { return conn.Load() != nil })

// Shutdown the Server and the client.
srv.Close()
_ = client.Stop(context.Background())
})
}

func TestConnectWithTLS(t *testing.T) {
testClients(t, func(t *testing.T, client OpAMPClient) {
// Start a server.
Expand Down
2 changes: 1 addition & 1 deletion client/httpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (c *httpClient) Start(ctx context.Context, settings types.StartSettings) er
c.opAMPServerURL = settings.OpAMPServerURL

// Prepare Server connection settings.
c.sender.SetRequestHeader(settings.Header)
c.sender.SetRequestHeader(settings.Header, settings.HeaderFunc)

// Add TLS configuration into httpClient
c.sender.AddTLSConfig(settings.TLSConfig)
Expand Down
33 changes: 24 additions & 9 deletions client/internal/httpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type HTTPSender struct {
compressionEnabled bool

// Headers to send with all requests.
requestHeader http.Header
getHeader func() http.Header

// Processor to handle received messages.
receiveProcessor receivedProcessor
Expand All @@ -75,7 +75,7 @@ func NewHTTPSender(logger types.Logger) *HTTPSender {
pollingIntervalMs: defaultPollingIntervalMs,
}
// initialize the headers with no additional headers
h.SetRequestHeader(nil)
h.SetRequestHeader(nil, nil)
return h
}

Expand Down Expand Up @@ -121,12 +121,26 @@ func (h *HTTPSender) Run(

// SetRequestHeader sets additional HTTP headers to send with all future requests.
// Should not be called concurrently with any other method.
func (h *HTTPSender) SetRequestHeader(header http.Header) {
if header == nil {
header = http.Header{}
func (h *HTTPSender) SetRequestHeader(baseHeaders http.Header, headerFunc func(http.Header) http.Header) {
if baseHeaders == nil {
baseHeaders = http.Header{}
}

if headerFunc == nil {
headerFunc = func(h http.Header) http.Header {
return h
}
}

h.getHeader = func() http.Header {
requestHeader := headerFunc(baseHeaders.Clone())
requestHeader.Set(headerContentType, contentTypeProtobuf)
if h.compressionEnabled {
requestHeader.Set(headerContentEncoding, encodingTypeGZip)
}

return requestHeader
}
h.requestHeader = header
h.requestHeader.Set(headerContentType, contentTypeProtobuf)
}

// makeOneRequestRoundtrip sends a request and receives a response.
Expand Down Expand Up @@ -255,7 +269,7 @@ func (h *HTTPSender) prepareRequest(ctx context.Context) (*requestWrapper, error
return nil, err
}

req.Header = h.requestHeader
req.Header = h.getHeader()
return &req, nil
}

Expand Down Expand Up @@ -295,9 +309,10 @@ func (h *HTTPSender) SetPollingInterval(duration time.Duration) {
atomic.StoreInt64(&h.pollingIntervalMs, duration.Milliseconds())
}

// EnableCompression enables compression for the sender.
// Should not be called concurrently with Run.
func (h *HTTPSender) EnableCompression() {
h.compressionEnabled = true
h.requestHeader.Set(headerContentEncoding, encodingTypeGZip)
}

func (h *HTTPSender) AddTLSConfig(config *tls.Config) {
Expand Down
5 changes: 5 additions & 0 deletions client/types/startsettings.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ type StartSettings struct {
// Optional additional HTTP headers to send with all HTTP requests.
Header http.Header

// Optional function that can be used to modify the HTTP headers
// before each HTTP request.
// Can modify and return the argument or return the argument without modifying.
HeaderFunc func(http.Header) http.Header
BinaryFissionGames marked this conversation as resolved.
Show resolved Hide resolved

// Optional TLS config for HTTP connection.
TLSConfig *tls.Config

Expand Down
20 changes: 17 additions & 3 deletions client/wsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type wsClient struct {
url *url.URL

// HTTP request headers to use when connecting to OpAMP Server.
requestHeader http.Header
getHeader func() http.Header

// Websocket dialer and connection.
dialer websocket.Dialer
Expand Down Expand Up @@ -86,7 +86,21 @@ func (c *wsClient) Start(ctx context.Context, settings types.StartSettings) erro
}
c.dialer.TLSClientConfig = settings.TLSConfig

c.requestHeader = settings.Header
headerFunc := settings.HeaderFunc
if headerFunc == nil {
headerFunc = func(h http.Header) http.Header {
return h
}
}

baseHeader := settings.Header
if baseHeader == nil {
baseHeader = http.Header{}
}

c.getHeader = func() http.Header {
return headerFunc(baseHeader.Clone())
}

c.common.StartConnectAndRun(c.runUntilStopped)

Expand Down Expand Up @@ -142,7 +156,7 @@ func (c *wsClient) SendCustomMessage(message *protobufs.CustomMessage) (messageS
// by the Server.
func (c *wsClient) tryConnectOnce(ctx context.Context) (retryAfter sharedinternal.OptionalDuration, err error) {
var resp *http.Response
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader)
conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.getHeader())
if err != nil {
if c.common.Callbacks != nil && !c.common.IsStopping() {
c.common.Callbacks.OnConnectFailed(ctx, err)
Expand Down
Loading