Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added context timeout configuration for API requests #331

Merged
merged 2 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions client/nginx.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
defaultBackup = false
defaultDown = false
defaultWeight = 1
defaultTimeout = 10 * time.Second
)

// ErrUnsupportedVer means that client's API version is not supported by NGINX plus API.
Expand All @@ -46,6 +47,7 @@ type NginxClient struct {
apiEndpoint string
apiVersion int
checkAPI bool
ctxTimeout time.Duration
}

type Option func(*NginxClient)
Expand Down Expand Up @@ -546,13 +548,21 @@ func WithCheckAPI() Option {
}
}

// WithTimeout sets the timeout per request for the client.
func WithTimeout(duration time.Duration) Option {
return func(o *NginxClient) {
o.ctxTimeout = duration
}
}

// NewNginxClient creates a new NginxClient.
func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) {
c := &NginxClient{
httpClient: http.DefaultClient,
apiEndpoint: apiEndpoint,
apiVersion: APIVersion,
checkAPI: false,
ctxTimeout: defaultTimeout,
}

for _, opt := range opts {
Expand All @@ -567,8 +577,12 @@ func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) {
return nil, fmt.Errorf("API version %v is not supported by the client", c.apiVersion)
}

if c.ctxTimeout <= 0 {
return nil, fmt.Errorf("timeout has to be greater than 0 %v", c.ctxTimeout)
}

if c.checkAPI {
versions, err := getAPIVersions(c.httpClient, apiEndpoint)
versions, err := c.getAPIVersions(c.httpClient, apiEndpoint)
if err != nil {
return nil, fmt.Errorf("error accessing the API: %w", err)
}
Expand Down Expand Up @@ -596,8 +610,8 @@ func versionSupported(n int) bool {
return false
}

func getAPIVersions(httpClient *http.Client, endpoint string) (*versions, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
func (client *NginxClient) getAPIVersions(httpClient *http.Client, endpoint string) (*versions, error) {
ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
Expand Down Expand Up @@ -852,7 +866,7 @@ func (client *NginxClient) getIDOfHTTPServer(upstream string, name string) (int,
}

func (client *NginxClient) get(path string, data interface{}) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout)
defer cancel()

url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path)
Expand Down Expand Up @@ -886,7 +900,7 @@ func (client *NginxClient) get(path string, data interface{}) error {
}

func (client *NginxClient) post(path string, input interface{}) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout)
defer cancel()

url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path)
Expand Down Expand Up @@ -918,7 +932,7 @@ func (client *NginxClient) post(path string, input interface{}) error {
}

func (client *NginxClient) delete(path string, expectedStatusCode int) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout)
defer cancel()

path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path)
Expand All @@ -943,7 +957,7 @@ func (client *NginxClient) delete(path string, expectedStatusCode int) error {
}

func (client *NginxClient) patch(path string, input interface{}, expectedStatusCode int) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout)
defer cancel()

path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path)
Expand Down
22 changes: 22 additions & 0 deletions client/nginx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"reflect"
"strings"
"testing"
"time"
)

func TestDetermineUpdates(t *testing.T) {
Expand Down Expand Up @@ -578,6 +579,27 @@ func TestClientWithAPIVersion(t *testing.T) {
}
}

func TestClientWithTimeout(t *testing.T) {
t.Parallel()
// Test creating a new client with a supported API version on the client
client, err := NewNginxClient("http://api-url", WithTimeout(1*time.Second))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if client == nil {
t.Fatalf("client is nil")
}

// Test creating a new client with an invalid duration
client, err = NewNginxClient("http://api-url", WithTimeout(-1*time.Second))
if err == nil {
t.Fatalf("expected error, but got nil")
}
if client != nil {
t.Fatalf("expected client to be nil, but got %v", client)
}
}

func TestClientWithHTTPClient(t *testing.T) {
t.Parallel()
// Test creating a new client passing a custom HTTP client
Expand Down
Loading