diff --git a/client/influxdb.go b/client/influxdb.go index fe66e934818..7740fb972ec 100644 --- a/client/influxdb.go +++ b/client/influxdb.go @@ -23,11 +23,13 @@ type Query struct { // URL: The URL of the server connecting to. // Username/Password are optional. They will be passed via basic auth if provided. // UserAgent: If not provided, will default "InfluxDBClient", +//Timeout: If not provided, will default to 0 (no timeout) type Config struct { URL url.URL Username string Password string UserAgent string + Timeout time.Duration } // Client is used to make calls to the server. @@ -45,7 +47,7 @@ func NewClient(c Config) (*Client, error) { url: c.URL, username: c.Username, password: c.Password, - httpClient: http.DefaultClient, + httpClient: &http.Client{Timeout: c.Timeout}, userAgent: c.UserAgent, } if client.userAgent == "" { diff --git a/client/influxdb_test.go b/client/influxdb_test.go index 857bb1dee96..901f0f44be4 100644 --- a/client/influxdb_test.go +++ b/client/influxdb_test.go @@ -427,3 +427,35 @@ func TestBatchPoints_Normal(t *testing.T) { t.Errorf("failed to unmarshal nanosecond data: %s", err.Error()) } } + +func TestClient_Timeout(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + var data influxdb.Response + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(data) + })) + defer ts.Close() + + u, _ := url.Parse(ts.URL) + config := client.Config{URL: *u, Timeout: 500 * time.Millisecond} + c, err := client.NewClient(config) + if err != nil { + t.Fatalf("unexpected error. expected %v, actual %v", nil, err) + } + + query := client.Query{} + _, err = c.Query(query) + if err == nil { + t.Fatalf("unexpected success. expected timeout error") + } else if !strings.Contains(err.Error(), "use of closed network connection") { + t.Fatalf("unexpected error. expected 'use of closed network connection' error, got %v", err) + } + + confignotimeout := client.Config{URL: *u} + cnotimeout, err := client.NewClient(confignotimeout) + _, err = cnotimeout.Query(query) + if err != nil { + t.Fatalf("unexpected error. expected %v, actual %v", nil, err) + } +}