diff --git a/http/client.go b/http/client.go index d19c9fb9..27c2bf03 100644 --- a/http/client.go +++ b/http/client.go @@ -27,6 +27,7 @@ type client struct { httpClient *http.Client ua string apiPrefix string + headers map[string]string fallback cmds.Executor } @@ -40,6 +41,16 @@ func ClientWithUserAgent(ua string) ClientOpt { } } +// ClientWithHeader adds an HTTP header to the client. +func ClientWithHeader(key, value string) ClientOpt { + return func(c *client) { + if c.headers == nil { + c.headers = map[string]string{} + } + c.headers[key] = value + } +} + // ClientWithHTTPClient specifies a custom http.Client. Defaults to // http.DefaultClient. func ClientWithHTTPClient(hc *http.Client) ClientOpt { @@ -173,6 +184,10 @@ func (c *client) toHTTPRequest(req *cmds.Request) (*http.Request, error) { } httpReq.Header.Set(uaHeader, c.ua) + for key, val := range c.headers { + httpReq.Header.Set(key, val) + } + httpReq = httpReq.WithContext(req.Context) httpReq.Close = true diff --git a/http/client_test.go b/http/client_test.go index ef714d04..45d14cf3 100644 --- a/http/client_test.go +++ b/http/client_test.go @@ -91,3 +91,47 @@ func TestClientAPIPrefix(t *testing.T) { } } } + +func TestClientHeader(t *testing.T) { + type testcase struct { + host string + header string + value string + path []string + } + + tcs := []testcase{ + {header: "Authorization", value: "Bearer sdneijfnejvzfregfwe", path: []string{"version"}}, + {header: "Content-Type", value: "text/plain", path: []string{"version"}}, + } + + for _, tc := range tcs { + var called bool + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + t.Log(r) + + if token := r.Header.Get(tc.header); token != tc.value { + t.Errorf("expected authorization %q, got %q", tc.value, token) + } + + expPath := "/" + strings.Join(tc.path, "/") + if path := r.URL.Path; path != expPath { + t.Errorf("expected path %q, got %q", expPath, path) + } + + w.WriteHeader(http.StatusOK) + })) + testClient := s.Client() + tc.host = s.URL + r := &cmds.Request{Path: tc.path, Command: &cmds.Command{}, Root: &cmds.Command{}} + c := NewClient(tc.host, ClientWithHeader(tc.header, tc.value)).(*client) + c.httpClient = testClient + c.send(r) + + if !called { + t.Error("handler has not been called") + } + } +} diff --git a/http/config.go b/http/config.go index 69a02636..77d5aa65 100644 --- a/http/config.go +++ b/http/config.go @@ -85,6 +85,12 @@ func (cfg *ServerConfig) SetAllowCredentials(flag bool) { cfg.corsOpts.AllowCredentials = flag } +func (cfg *ServerConfig) AddAllowedHeaders(headers ...string) { + cfg.corsOptsRWMutex.Lock() + defer cfg.corsOptsRWMutex.Unlock() + cfg.corsOpts.AllowedHeaders = append(cfg.corsOpts.AllowedHeaders, headers...) +} + // allowOrigin just stops the request if the origin is not allowed. // the CORS middleware apparently does not do this for us... func allowOrigin(r *http.Request, cfg *ServerConfig) bool {