Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retryable http client and use in FGA module #356

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pkg/fga/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/google/go-querystring/query"
"github.com/workos/workos-go/v4/internal/workos"
"github.com/workos/workos-go/v4/pkg/common"
"github.com/workos/workos-go/v4/pkg/retryablehttp"
"github.com/workos/workos-go/v4/pkg/workos_errors"
)

Expand Down Expand Up @@ -42,7 +43,7 @@ type Client struct {

// The http.Client that is used to get FGA records from WorkOS.
// Defaults to http.Client.
HTTPClient *http.Client
HTTPClient *retryablehttp.HttpClient

// The endpoint to WorkOS API. Defaults to https://api.workos.com.
Endpoint string
Expand All @@ -55,7 +56,7 @@ type Client struct {

func (c *Client) init() {
if c.HTTPClient == nil {
c.HTTPClient = &http.Client{Timeout: 10 * time.Second}
c.HTTPClient = &retryablehttp.HttpClient{Client: http.Client{Timeout: 10 * time.Second}}
}

if c.Endpoint == "" {
Expand Down
27 changes: 14 additions & 13 deletions pkg/fga/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/stretchr/testify/require"
"github.com/workos/workos-go/v4/pkg/common"
"github.com/workos/workos-go/v4/pkg/retryablehttp"
)

func TestGetResource(t *testing.T) {
Expand Down Expand Up @@ -49,7 +50,7 @@ func TestGetResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resource, err := client.GetResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -130,7 +131,7 @@ func TestListResources(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resources, err := client.ListResources(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestListResourceTypes(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resourceTypes, err := client.ListResourceTypes(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -367,7 +368,7 @@ func TestBatchUpdateResourceTypes(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resourceTypes, err := client.BatchUpdateResourceTypes(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -488,7 +489,7 @@ func TestCreateResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resource, err := client.CreateResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -578,7 +579,7 @@ func TestUpdateResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resource, err := client.UpdateResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -665,7 +666,7 @@ func TestDeleteResource(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

err := client.DeleteResource(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -765,7 +766,7 @@ func TestListWarrants(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

resources, err := client.ListWarrants(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -907,7 +908,7 @@ func TestWriteWarrant(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

warrantResponse, err := client.WriteWarrant(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -973,7 +974,7 @@ func TestBatchWriteWarrants(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

warrantResponse, err := client.BatchWriteWarrants(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -1057,7 +1058,7 @@ func TestCheck(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

checkResult, err := client.Check(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -1157,7 +1158,7 @@ func TestCheckBatch(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

checkResults, err := client.CheckBatch(context.Background(), test.options)
if test.err {
Expand Down Expand Up @@ -1256,7 +1257,7 @@ func TestQuery(t *testing.T) {

client := test.client
client.Endpoint = server.URL
client.HTTPClient = server.Client()
client.HTTPClient = &retryablehttp.HttpClient{Client: *server.Client()}

queryResults, err := client.Query(context.Background(), test.options)
if test.err {
Expand Down
27 changes: 14 additions & 13 deletions pkg/fga/fga_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (

"github.com/stretchr/testify/require"
"github.com/workos/workos-go/v4/pkg/common"
"github.com/workos/workos-go/v4/pkg/retryablehttp"
)

func TestFGAGetResource(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(getResourceTestHandler))
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand All @@ -38,7 +39,7 @@ func TestFGAListResources(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -72,7 +73,7 @@ func TestFGACreateResource(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand All @@ -95,7 +96,7 @@ func TestFGAUpdateResource(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -124,7 +125,7 @@ func TestFGADeleteResource(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand All @@ -142,7 +143,7 @@ func TestFGAListResourceTypes(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -184,7 +185,7 @@ func TestFGABatchUpdateResourceTypes(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -235,7 +236,7 @@ func TestFGAListWarrants(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -279,7 +280,7 @@ func TestFGAWriteWarrant(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestFGABatchWriteWarrants(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -347,7 +348,7 @@ func TestFGACheck(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -375,7 +376,7 @@ func TestFGACheckBatch(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down Expand Up @@ -405,7 +406,7 @@ func TestFGAQuery(t *testing.T) {
defer server.Close()

DefaultClient = &Client{
HTTPClient: server.Client(),
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
Endpoint: server.URL,
}
SetAPIKey("test")
Expand Down
110 changes: 110 additions & 0 deletions pkg/retryablehttp/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package retryablehttp

import (
"io"
"math"
"math/rand"
"net/http"
"time"
)

const MaxRetryAttempts = 3
const MinimumDelay = 500
const MinimumDelayDuration = 250 * time.Millisecond
const MaximumDelayDuration = 5 * time.Second
const RandomizationFactor = 0.5
const BackoffMultiplier = 1.5

type HttpClient struct {
http.Client
}

func (client *HttpClient) Do(req *http.Request) (*http.Response, error) {
var res *http.Response
var err error
for retry := 0; ; {
// Reset the request body for each retry
if req.Body != nil {
body, err := req.GetBody()
if err != nil {
client.CloseIdleConnections()
return res, err
}
if c, ok := body.(io.ReadCloser); ok {
req.Body = c
} else {
req.Body = io.NopCloser(body)
}
}

res, err = client.Client.Do(req)
if err != nil {
break
}

shouldRetry := client.shouldRetry(req, res, err, retry)

if !shouldRetry {
break
}

sleepTime := client.sleepTime(retry)
retry++

timer := time.NewTimer(sleepTime)
select {
case <-req.Context().Done():
timer.Stop()
client.CloseIdleConnections()
return nil, req.Context().Err()
case <-timer.C:
}
}

if err != nil {
return nil, err
}

return res, nil
}

func (client *HttpClient) shouldRetry(req *http.Request, resp *http.Response, err error, retryAttempt int) bool {
if retryAttempt >= MaxRetryAttempts {
return false
}

if err != nil {
return true
}

if resp.StatusCode >= http.StatusInternalServerError {
return true
}

return false
}

// Calculates backoff time using exponential backoff with 50% jitter.
//
// Backoff times
// Retry Attempt | Sleep Time
// 1 | 500ms +/- 250ms
// 2 | 750ms +/- 375ms
// 3 | 1.125s +/- 562ms
func (client *HttpClient) sleepTime(retryAttempt int) time.Duration {
sleepTime := time.Duration(MinimumDelay*int64(math.Pow(BackoffMultiplier, float64(retryAttempt)))) * time.Millisecond

delta := RandomizationFactor * float64(sleepTime)
minSleep := float64(sleepTime) - delta
maxSleep := float64(sleepTime) + delta

sleepTime = time.Duration(minSleep + (rand.Float64() * (maxSleep - minSleep + 1)))

if sleepTime < MinimumDelayDuration {
sleepTime = MinimumDelayDuration
} else if sleepTime > MaximumDelayDuration {
sleepTime = MaximumDelayDuration
}

return sleepTime
}
Loading
Loading