From 7e10b479978cbd988e3f526118c1c18e17469dbb Mon Sep 17 00:00:00 2001 From: obanby Date: Fri, 27 Sep 2024 11:05:57 -0400 Subject: [PATCH] - Added attributions for ported code - Added config validations for tripper - Added comment for aps - Removed noEscape to signer internal state --- pkg/sigv4/const.go | 4 +++- pkg/sigv4/sigv4.go | 22 +++++++++++++++------- pkg/sigv4/tripper.go | 15 ++++++++++++++- pkg/sigv4/utils.go | 7 ++++--- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/pkg/sigv4/const.go b/pkg/sigv4/const.go index 4afd291..c8cc0dc 100644 --- a/pkg/sigv4/const.go +++ b/pkg/sigv4/const.go @@ -1,7 +1,9 @@ package sigv4 const ( - awsServiceName = "aps" + // Amazon Managed Service for Prometheus + awsServiceName = "aps" + signingAlgorithm = "AWS4-HMAC-SHA256" authorizationHeaderKey = "Authorization" diff --git a/pkg/sigv4/sigv4.go b/pkg/sigv4/sigv4.go index ce5cd84..c7623ed 100644 --- a/pkg/sigv4/sigv4.go +++ b/pkg/sigv4/sigv4.go @@ -21,13 +21,20 @@ type Signer interface { type DefaultSigner struct { config *Config + + // noEscape represents the characters that AWS doesn't escape + noEscape [256]bool } func NewDefaultSigner(config *Config) Signer { - // initialize noEscape array. This way we can avoid using init() functions - for i := 0; i < len(noEscape); i++ { + ds := &DefaultSigner{ + config: config, + noEscape: [256]bool{}, + } + + for i := 0; i < len(ds.noEscape); i++ { // AWS expects every character except these to be escaped - noEscape[i] = (i >= 'A' && i <= 'Z') || + ds.noEscape[i] = (i >= 'A' && i <= 'Z') || (i >= 'a' && i <= 'z') || (i >= '0' && i <= '9') || i == '-' || @@ -36,7 +43,7 @@ func NewDefaultSigner(config *Config) Signer { i == '~' } - return &DefaultSigner{config: config} + return ds } func (d *DefaultSigner) Sign(req *http.Request) error { @@ -59,7 +66,7 @@ func (d *DefaultSigner) Sign(req *http.Request) error { canonicalQueryString := getCanonicalQueryString(req.URL) canonicalReq := buildCanonicalString( req.Method, - getCanonicalURI(req.URL), + getCanonicalURI(req.URL, d.noEscape), canonicalQueryString, canonicalHeaderStr, signedHeadersStr, @@ -136,6 +143,7 @@ var ignoredHeaders = map[string]struct{}{ "Expect": struct{}{}, } +// buildCanonicalHeaders is mostly ported from https://github.com/aws/aws-sdk-go-v2/aws/signer/v4 buildCanonicalHeaders func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders, canonicalHeadersStr string) { host, header, length := req.Host, req.Header, req.ContentLength @@ -203,8 +211,8 @@ func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders return signed, signedHeaders, canonicalHeadersStr } -func getCanonicalURI(u *url.URL) string { - return escapePath(getURIPath(u), false) +func getCanonicalURI(u *url.URL, noEscape [256]bool) string { + return escapePath(getURIPath(u), false, noEscape) } func getCanonicalQueryString(u *url.URL) string { diff --git a/pkg/sigv4/tripper.go b/pkg/sigv4/tripper.go index 31fd6a3..ddc6e75 100644 --- a/pkg/sigv4/tripper.go +++ b/pkg/sigv4/tripper.go @@ -3,6 +3,7 @@ package sigv4 import ( "errors" "net/http" + "strings" ) type Tripper struct { @@ -22,10 +23,22 @@ func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) { return nil, errors.New("can't initialize a sigv4 round tripper with nil config") } + if len(strings.TrimSpace(config.Region)) == 0 { + return nil, errors.New("sigV4 config `Region` must be set") + } + + if len(strings.TrimSpace(config.AwsSecretAccessKey)) == 0 { + return nil, errors.New("sigV4 config `AwsSecretAccessKey` must be set") + } + + if len(strings.TrimSpace(config.AwsAccessKeyID)) == 0 { + return nil, errors.New("sigV4 config `AwsAccessKeyID` must be set") + } + if next == nil { next = http.DefaultTransport } - + tripper := &Tripper{ config: config, next: next, diff --git a/pkg/sigv4/utils.go b/pkg/sigv4/utils.go index db3bf71..6d149ca 100644 --- a/pkg/sigv4/utils.go +++ b/pkg/sigv4/utils.go @@ -7,10 +7,9 @@ import ( "strings" ) -var noEscape [256]bool - // escapePath escapes part of a URL path in Amazon style. -func escapePath(path string, encodeSep bool) string { +// inspired by github.com/aws/smithy-go/encoding/httpbinding EscapePath method +func escapePath(path string, encodeSep bool, noEscape [256]bool) string { var buf bytes.Buffer for i := 0; i < len(path); i++ { c := path[i] @@ -25,6 +24,7 @@ func escapePath(path string, encodeSep bool) string { // stripExcessSpaces will rewrite the passed in slice's string values to not // contain multiple side-by-side spaces. +// Ported from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 StripExcessSpaces func stripExcessSpaces(str string) string { const doubleSpace = " " @@ -65,6 +65,7 @@ func stripExcessSpaces(str string) string { } // getURIPath returns the escaped URI component from the provided URL. +// Ported from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 GetURIPath func getURIPath(u *url.URL) string { var uriPath string