diff --git a/config/http_config.go b/config/http_config.go index 07de306b..17a9f2eb 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -17,11 +17,13 @@ package config import ( "bytes" + "context" "crypto/sha256" "crypto/tls" "crypto/x509" "fmt" "io/ioutil" + "net" "net/http" "net/url" "strings" @@ -194,6 +196,24 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error { return unmarshal((*plain)(a)) } +// DialContextFunc defines the signature of the DialContext() function implemented +// by net.Dialer. +type DialContextFunc func(context.Context, string, string) (net.Conn, error) + +type httpClientOptions struct { + dialContextFunc DialContextFunc +} + +// HTTPClientOption defines an option that can be applied to the HTTP client. +type HTTPClientOption func(options *httpClientOptions) + +// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`. +func WithDialContextFunc(fn DialContextFunc) HTTPClientOption { + return func(opts *httpClientOptions) { + opts.dialContextFunc = fn + } +} + // NewClient returns a http.Client using the specified http.RoundTripper. func newClient(rt http.RoundTripper) *http.Client { return &http.Client{Transport: rt} @@ -201,8 +221,8 @@ func newClient(rt http.RoundTripper) *http.Client { // NewClientFromConfig returns a new HTTP client configured for the // given config.HTTPClientConfig. The name is used as go-conntrack metric label. -func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (*http.Client, error) { - rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2) +func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool, optFuncs ...HTTPClientOption) (*http.Client, error) { + rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2, optFuncs...) if err != nil { return nil, err } @@ -217,7 +237,25 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, e // NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the // given config.HTTPClientConfig. The name is used as go-conntrack metric label. -func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (http.RoundTripper, error) { +func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { + opts := &httpClientOptions{} + for _, f := range optFuncs { + f(opts) + } + + var dialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + if opts.dialContextFunc != nil { + dialContext = conntrack.NewDialContextFunc( + conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)), + conntrack.DialWithTracing(), + conntrack.DialWithName(name)) + } else { + dialContext = conntrack.NewDialContextFunc( + conntrack.DialWithTracing(), + conntrack.DialWithName(name)) + } + newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) { // The only timeout we care about is the configured scrape timeout. // It is applied on request. So we leave out any timings here. @@ -233,10 +271,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAli IdleConnTimeout: 5 * time.Minute, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - DialContext: conntrack.NewDialContextFunc( - conntrack.DialWithTracing(), - conntrack.DialWithName(name), - ), + DialContext: dialContext, } if enableHTTP2 { // HTTP/2 support is golang has many problematic cornercases where diff --git a/config/http_config_test.go b/config/http_config_test.go index cf8ae7db..61406c77 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -16,10 +16,13 @@ package config import ( + "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "os" @@ -50,6 +53,7 @@ const ( MissingKey = "missing/secret.key" ExpectedMessage = "I'm here to serve you!!!" + ExpectedError = "expected error" AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo" AuthorizationCredentialsFile = "testdata/bearer.token" AuthorizationType = "APIKEY" @@ -413,6 +417,23 @@ func TestNewClientFromInvalidConfig(t *testing.T) { } } +func TestCustomDialContextFunc(t *testing.T) { + dialFn := func(_ context.Context, _, _ string) (net.Conn, error) { + return nil, errors.New(ExpectedError) + } + + cfg := HTTPClientConfig{} + client, err := NewClientFromConfig(cfg, "test", false, true, WithDialContextFunc(dialFn)) + if err != nil { + t.Fatalf("Can't create a client from this config: %+v", cfg) + } + + _, err = client.Get("http://localhost") + if err == nil || !strings.Contains(err.Error(), ExpectedError) { + t.Errorf("Expected error %q but got %q", ExpectedError, err) + } +} + func TestMissingBearerAuthFile(t *testing.T) { cfg := HTTPClientConfig{ BearerTokenFile: MissingBearerTokenFile,