Skip to content

Commit

Permalink
- Added attributions for ported code
Browse files Browse the repository at this point in the history
- Added config validations for tripper
- Added comment for aps
- Removed noEscape to signer internal state
  • Loading branch information
obanby committed Sep 27, 2024
1 parent c25240c commit 7e10b47
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
4 changes: 3 additions & 1 deletion pkg/sigv4/const.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package sigv4

const (
awsServiceName = "aps"
// Amazon Managed Service for Prometheus
awsServiceName = "aps"

signingAlgorithm = "AWS4-HMAC-SHA256"

authorizationHeaderKey = "Authorization"
Expand Down
22 changes: 15 additions & 7 deletions pkg/sigv4/sigv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@ type Signer interface {

type DefaultSigner struct {
config *Config

// noEscape represents the characters that AWS doesn't escape
noEscape [256]bool
}

func NewDefaultSigner(config *Config) Signer {
// initialize noEscape array. This way we can avoid using init() functions
for i := 0; i < len(noEscape); i++ {
ds := &DefaultSigner{
config: config,
noEscape: [256]bool{},
}

for i := 0; i < len(ds.noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
ds.noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
Expand All @@ -36,7 +43,7 @@ func NewDefaultSigner(config *Config) Signer {
i == '~'
}

return &DefaultSigner{config: config}
return ds
}

func (d *DefaultSigner) Sign(req *http.Request) error {
Expand All @@ -59,7 +66,7 @@ func (d *DefaultSigner) Sign(req *http.Request) error {
canonicalQueryString := getCanonicalQueryString(req.URL)
canonicalReq := buildCanonicalString(
req.Method,
getCanonicalURI(req.URL),
getCanonicalURI(req.URL, d.noEscape),
canonicalQueryString,
canonicalHeaderStr,
signedHeadersStr,
Expand Down Expand Up @@ -136,6 +143,7 @@ var ignoredHeaders = map[string]struct{}{
"Expect": struct{}{},
}

// buildCanonicalHeaders is mostly ported from https://github.com/aws/aws-sdk-go-v2/aws/signer/v4 buildCanonicalHeaders
func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
host, header, length := req.Host, req.Header, req.ContentLength

Expand Down Expand Up @@ -203,8 +211,8 @@ func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders
return signed, signedHeaders, canonicalHeadersStr
}

func getCanonicalURI(u *url.URL) string {
return escapePath(getURIPath(u), false)
func getCanonicalURI(u *url.URL, noEscape [256]bool) string {
return escapePath(getURIPath(u), false, noEscape)
}

func getCanonicalQueryString(u *url.URL) string {
Expand Down
15 changes: 14 additions & 1 deletion pkg/sigv4/tripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sigv4
import (
"errors"
"net/http"
"strings"
)

type Tripper struct {
Expand All @@ -22,10 +23,22 @@ func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) {
return nil, errors.New("can't initialize a sigv4 round tripper with nil config")
}

if len(strings.TrimSpace(config.Region)) == 0 {
return nil, errors.New("sigV4 config `Region` must be set")
}

if len(strings.TrimSpace(config.AwsSecretAccessKey)) == 0 {
return nil, errors.New("sigV4 config `AwsSecretAccessKey` must be set")
}

if len(strings.TrimSpace(config.AwsAccessKeyID)) == 0 {
return nil, errors.New("sigV4 config `AwsAccessKeyID` must be set")
}

if next == nil {
next = http.DefaultTransport
}

tripper := &Tripper{
config: config,
next: next,
Expand Down
7 changes: 4 additions & 3 deletions pkg/sigv4/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ import (
"strings"
)

var noEscape [256]bool

// escapePath escapes part of a URL path in Amazon style.
func escapePath(path string, encodeSep bool) string {
// inspired by github.com/aws/smithy-go/encoding/httpbinding EscapePath method
func escapePath(path string, encodeSep bool, noEscape [256]bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
Expand All @@ -25,6 +24,7 @@ func escapePath(path string, encodeSep bool) string {

// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
// Ported from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 StripExcessSpaces
func stripExcessSpaces(str string) string {
const doubleSpace = " "

Expand Down Expand Up @@ -65,6 +65,7 @@ func stripExcessSpaces(str string) string {
}

// getURIPath returns the escaped URI component from the provided URL.
// Ported from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 GetURIPath
func getURIPath(u *url.URL) string {
var uriPath string

Expand Down

0 comments on commit 7e10b47

Please sign in to comment.