diff --git a/internal/retryafter.go b/internal/retryafter.go index 784473b..593cd2c 100644 --- a/internal/retryafter.go +++ b/internal/retryafter.go @@ -18,26 +18,32 @@ type OptionalDuration struct { } func parseDelaySeconds(s string) (time.Duration, error) { + // Verify duration parsed properly n, err := strconv.Atoi(s) + if err != nil { + return 0, errCouldNotParseRetryAfterHeader + } - // Verify duration parsed properly and bigger than 0 - if err == nil && n > 0 { + // If n > 0 return n seconds, otherwise return 0 + if n > 0 { duration := time.Duration(n) * time.Second return duration, nil } - return 0, errCouldNotParseRetryAfterHeader + return 0, nil } func parseHTTPDate(s string) (time.Duration, error) { + // Verify duration parsed properly t, err := http.ParseTime(s) + if err != nil { + return 0, errCouldNotParseRetryAfterHeader + } - // Verify duration parsed properly and bigger than 0 - if err == nil { - if duration := time.Until(t); duration > 0 { - return duration, nil - } + // If the date is in the future return that duration, otherwise return 0 + if duration := time.Until(t); duration > 0 { + return duration, nil } - return 0, errCouldNotParseRetryAfterHeader + return 0, nil } // ExtractRetryAfterHeader extracts Retry-After response header if the status diff --git a/internal/retryafter_test.go b/internal/retryafter_test.go index f2e36e9..123ada1 100644 --- a/internal/retryafter_test.go +++ b/internal/retryafter_test.go @@ -50,9 +50,10 @@ func TestExtractRetryAfterHeaderDelaySeconds(t *testing.T) { resp.StatusCode = http.StatusBadGateway assertUndefinedDuration(t, ExtractRetryAfterHeader(resp)) - // Verify no duration is created for n < 0 + // Verify a zero duration is created for n < 0 + resp.StatusCode = http.StatusTooManyRequests resp.Header.Set(retryAfterHTTPHeader, strconv.Itoa(-1)) - assertUndefinedDuration(t, ExtractRetryAfterHeader(resp)) + assertDuration(t, ExtractRetryAfterHeader(resp), 0) } func TestExtractRetryAfterHeaderHttpDate(t *testing.T) { @@ -81,7 +82,11 @@ func TestExtractRetryAfterHeaderHttpDate(t *testing.T) { resp.Header.Set(retryAfterHTTPHeader, retryAfter.Format(time.RFC1123)) assertUndefinedDuration(t, ExtractRetryAfterHeader(resp)) - // Verify no duration is created for n < 0 + // Verify a zero duration is created for n = 0 + resp.Header.Set(retryAfterHTTPHeader, now.UTC().Format(http.TimeFormat)) + assertDuration(t, ExtractRetryAfterHeader(resp), 0) + + // Verify a zero duration is created for n < 0 resp.Header.Set(retryAfterHTTPHeader, now.Add(-1*time.Second).UTC().Format(http.TimeFormat)) - assertUndefinedDuration(t, ExtractRetryAfterHeader(resp)) + assertDuration(t, ExtractRetryAfterHeader(resp), 0) }