diff --git a/pkg/remote/client.go b/pkg/remote/client.go index eca4d8a..2cf4902 100644 --- a/pkg/remote/client.go +++ b/pkg/remote/client.go @@ -12,6 +12,8 @@ import ( "net/url" "time" + "github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4" + prompb "buf.build/gen/go/prometheus/prometheus/protocolbuffers/go" "github.com/klauspost/compress/snappy" "google.golang.org/protobuf/proto" @@ -22,6 +24,7 @@ type HTTPConfig struct { Timeout time.Duration TLSConfig *tls.Config BasicAuth *BasicAuth + SigV4 *sigv4.Config Headers http.Header } @@ -60,6 +63,13 @@ func NewWriteClient(endpoint string, cfg *HTTPConfig) (*WriteClient, error) { TLSClientConfig: cfg.TLSConfig, } } + if cfg.SigV4 != nil { + tripper, err := sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport) + if err != nil { + return nil, err + } + wc.hc.Transport = tripper + } return wc, nil } diff --git a/pkg/remotewrite/config.go b/pkg/remotewrite/config.go index 03837f9..72a03aa 100644 --- a/pkg/remotewrite/config.go +++ b/pkg/remotewrite/config.go @@ -3,12 +3,15 @@ package remotewrite import ( "crypto/tls" "encoding/json" + "errors" "fmt" "net/http" "strconv" "strings" "time" + "github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4" + "github.com/grafana/xk6-output-prometheus-remote/pkg/remote" "go.k6.io/k6/lib/types" "gopkg.in/guregu/null.v3" @@ -68,6 +71,15 @@ type Config struct { TrendStats []string `json:"trendStats"` StaleMarkers null.Bool `json:"staleMarkers"` + + // SigV4Region is the AWS region where the workspace is. + SigV4Region null.String `json:"sigV4Region"` + + // SigV4AccessKey is the AWS access key. + SigV4AccessKey null.String `json:"sigV4AccessKey"` + + // SigV4SecretKey is the AWS secret key. + SigV4SecretKey null.String `json:"sigV4SecretKey"` } // NewConfig creates an Output's configuration. @@ -81,6 +93,9 @@ func NewConfig() Config { Headers: make(map[string]string), TrendStats: defaultTrendStats, StaleMarkers: null.BoolFrom(false), + SigV4Region: null.NewString("", false), + SigV4AccessKey: null.NewString("", false), + SigV4SecretKey: null.NewString("", false), } } @@ -110,6 +125,22 @@ func (conf Config) RemoteConfig() (*remote.HTTPConfig, error) { hc.TLSConfig.Certificates = []tls.Certificate{cert} } + if isSigV4PartiallyConfigured(conf.SigV4Region, conf.SigV4AccessKey, conf.SigV4SecretKey) { + return nil, errors.New( + "sigv4 seems to be partially configured. All of " + + "K6_PROMETHEUS_RW_SIGV4_REGION, K6_PROMETHEUS_RW_SIGV4_ACCESS_KEY, K6_PROMETHEUS_RW_SIGV4_SECRET_KEY " + + "must all be set. Unset all to bypass sigv4", + ) + } + + if conf.SigV4Region.Valid && conf.SigV4AccessKey.Valid && conf.SigV4SecretKey.Valid { + hc.SigV4 = &sigv4.Config{ + Region: conf.SigV4Region.String, + AwsAccessKeyID: conf.SigV4AccessKey.String, + AwsSecretAccessKey: conf.SigV4SecretKey.String, + } + } + if len(conf.Headers) > 0 { hc.Headers = make(http.Header) for k, v := range conf.Headers { @@ -149,6 +180,18 @@ func (conf Config) Apply(applied Config) Config { conf.BearerToken = applied.BearerToken } + if applied.SigV4Region.Valid { + conf.SigV4Region = applied.SigV4Region + } + + if applied.SigV4AccessKey.Valid { + conf.SigV4AccessKey = applied.SigV4AccessKey + } + + if applied.SigV4SecretKey.Valid { + conf.SigV4SecretKey = applied.SigV4SecretKey + } + if applied.PushInterval.Valid { conf.PushInterval = applied.PushInterval } @@ -299,6 +342,18 @@ func parseEnvs(env map[string]string) (Config, error) { } } + if sigV4Region, sigV4RegionDefined := env["K6_PROMETHEUS_RW_SIGV4_REGION"]; sigV4RegionDefined { + c.SigV4Region = null.StringFrom(sigV4Region) + } + + if sigV4AccessKey, sigV4AccessKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_ACCESS_KEY"]; sigV4AccessKeyDefined { + c.SigV4AccessKey = null.StringFrom(sigV4AccessKey) + } + + if sigV4SecretKey, sigV4SecretKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_SECRET_KEY"]; sigV4SecretKeyDefined { + c.SigV4SecretKey = null.StringFrom(sigV4SecretKey) + } + if b, err := envBool(env, "K6_PROMETHEUS_RW_TREND_AS_NATIVE_HISTOGRAM"); err != nil { return c, err } else if b.Valid { @@ -384,3 +439,12 @@ func parseArg(text string) (Config, error) { return c, nil } + +func isSigV4PartiallyConfigured(region, accessKey, secretKey null.String) bool { + hasRegion := region.Valid && len(strings.TrimSpace(region.String)) != 0 + hasAccessID := accessKey.Valid && len(strings.TrimSpace(accessKey.String)) != 0 + hasSecretAccessKey := secretKey.Valid && len(strings.TrimSpace(secretKey.String)) != 0 + // either they are all set, or all not set. False if partial + isComplete := (hasRegion && hasAccessID && hasSecretAccessKey) || (!hasRegion && !hasAccessID && !hasSecretAccessKey) + return !isComplete +} diff --git a/pkg/sigv4/const.go b/pkg/sigv4/const.go new file mode 100644 index 0000000..c8cc0dc --- /dev/null +++ b/pkg/sigv4/const.go @@ -0,0 +1,23 @@ +package sigv4 + +const ( + // Amazon Managed Service for Prometheus + awsServiceName = "aps" + + signingAlgorithm = "AWS4-HMAC-SHA256" + + authorizationHeaderKey = "Authorization" + amzDateKey = "X-Amz-Date" + + // emptyStringSHA256 is the hex encoded sha256 value of an empty string + emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` + + // timeFormat is the time format to be used in the X-Amz-Date header or query parameter + timeFormat = "20060102T150405Z" + + // shortTimeFormat is the shorten time format used in the credential scope + shortTimeFormat = "20060102" + + // contentSHAKey is the SHA256 of request body + contentSHAKey = "X-Amz-Content-Sha256" +) diff --git a/pkg/sigv4/sigv4.go b/pkg/sigv4/sigv4.go new file mode 100644 index 0000000..24ec4ea --- /dev/null +++ b/pkg/sigv4/sigv4.go @@ -0,0 +1,249 @@ +// Package sigv4 is responsible to for aws sigv4 signing of requests +package sigv4 + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +type signer interface { + sign(req *http.Request) error +} + +type defaultSigner struct { + config *Config + + // noEscape represents the characters that AWS doesn't escape + noEscape [256]bool + + ignoredHeaders map[string]struct{} +} + +func newDefaultSigner(config *Config) signer { + ds := &defaultSigner{ + config: config, + noEscape: buildAwsNoEscape(), + ignoredHeaders: map[string]struct{}{ + "Authorization": {}, + "User-Agent": {}, + "X-Amzn-Trace-Id": {}, + "Expect": {}, + }, + } + + return ds +} + +func (d *defaultSigner) sign(req *http.Request) error { + now := time.Now().UTC() + iSO8601Date := now.Format(timeFormat) + + credentialScope := buildCredentialScope(now, d.config.Region) + + payloadHash, err := d.getPayloadHash(req) + if err != nil { + return err + } + + req.Header.Set("Host", req.Host) + req.Header.Set(amzDateKey, iSO8601Date) + req.Header.Set(contentSHAKey, payloadHash) + + signedHeadersStr, canonicalHeaderStr := buildCanonicalHeaders(req, d.ignoredHeaders) + + canonicalQueryString := getCanonicalQueryString(req.URL) + canonicalReq := buildCanonicalString( + req.Method, + getCanonicalURI(req.URL, d.noEscape), + canonicalQueryString, + canonicalHeaderStr, + signedHeadersStr, + payloadHash, + ) + + signature := sign( + deriveKey(d.config.AwsSecretAccessKey, d.config.Region), + buildStringToSign(iSO8601Date, credentialScope, canonicalReq), + ) + + authorizationHeader := fmt.Sprintf( + "%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", + signingAlgorithm, + d.config.AwsAccessKeyID, + credentialScope, + signedHeadersStr, + signature, + ) + + req.URL.RawQuery = canonicalQueryString + req.Header.Set(authorizationHeaderKey, authorizationHeader) + return nil +} + +func (d *defaultSigner) getPayloadHash(req *http.Request) (string, error) { + if req.Body == nil { + return emptyStringSHA256, nil + } + + reqBody, err := io.ReadAll(req.Body) + if err != nil { + return "", err + } + reqBodyBuffer := bytes.NewReader(reqBody) + + hash := sha256.New() + if _, err := io.Copy(hash, reqBodyBuffer); err != nil { + return "", err + } + + payloadHash := hex.EncodeToString(hash.Sum(nil)) + + // ensuring that we keep the request body intact for next tripper + req.Body = io.NopCloser(bytes.NewReader(reqBody)) + + return payloadHash, nil +} + +func buildCredentialScope(signingTime time.Time, region string) string { + return fmt.Sprintf( + "%s/%s/%s/aws4_request", + signingTime.UTC().Format(shortTimeFormat), + region, + awsServiceName, + ) +} + +func buildCanonicalString(method, uri, query, canonicalHeaders, signedHeaders, payloadHash string) string { + return strings.Join([]string{ + method, + uri, + query, + canonicalHeaders, + signedHeaders, + payloadHash, + }, "\n") +} + +// buildCanonicalHeaders is mostly ported from https://github.com/aws/aws-sdk-go-v2/aws/signer/v4 buildCanonicalHeaders +func buildCanonicalHeaders( + req *http.Request, + ignoredHeaders map[string]struct{}, +) (signedHeaders, canonicalHeadersStr string) { + const hostHeader, contentLengthHeader = "host", "content-length" + host, header, length := req.Host, req.Header, req.ContentLength + + signed := make(http.Header) + headers := append([]string{}, hostHeader) + signed[hostHeader] = append(signed[hostHeader], host) + + if length > 0 { + headers = append(headers, contentLengthHeader) + signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10)) + } + + for k, v := range header { + if _, ok := ignoredHeaders[k]; ok { + continue + } + + if strings.EqualFold(k, contentLengthHeader) { + // prevent signing already handled content-length header. + continue + } + + lowerCaseKey := strings.ToLower(k) + if _, ok := signed[lowerCaseKey]; ok { + // include additional values + signed[lowerCaseKey] = append(signed[lowerCaseKey], v...) + continue + } + + headers = append(headers, lowerCaseKey) + signed[lowerCaseKey] = v + } + + // aws requires headers to keys to be sorted + sort.Strings(headers) + signedHeaders = strings.Join(headers, ";") + + var canonicalHeaders strings.Builder + for _, h := range headers { + if h == hostHeader { + canonicalHeaders.WriteString(fmt.Sprintf("%s:%s\n", hostHeader, stripExcessSpaces(host))) + continue + } + + canonicalHeaders.WriteString(fmt.Sprintf("%s:", h)) + values := signed[h] + for j, v := range values { + cleanedValue := strings.TrimSpace(stripExcessSpaces(v)) + canonicalHeaders.WriteString(cleanedValue) + if j < len(values)-1 { + canonicalHeaders.WriteRune(',') + } + } + canonicalHeaders.WriteRune('\n') + } + canonicalHeadersStr = canonicalHeaders.String() + return signedHeaders, canonicalHeadersStr +} + +func getCanonicalURI(u *url.URL, noEscape [256]bool) string { + return escapePath(getURIPath(u), noEscape) +} + +func getCanonicalQueryString(u *url.URL) string { + query := u.Query() + + // Sort Each Query Key's Values + for key := range query { + sort.Strings(query[key]) + } + + var rawQuery strings.Builder + rawQuery.WriteString(strings.ReplaceAll(query.Encode(), "+", "%20")) + return rawQuery.String() +} + +func buildStringToSign(amzDate, credentialScope, canonicalRequestString string) string { + hash := sha256.New() + hash.Write([]byte(canonicalRequestString)) + return strings.Join([]string{ + signingAlgorithm, + amzDate, + credentialScope, + hex.EncodeToString(hash.Sum(nil)), + }, "\n") +} + +func deriveKey(secretKey, region string) string { + signingDate := time.Now().UTC().Format(shortTimeFormat) + hmacDate := hmacSHA256([]byte("AWS4"+secretKey), signingDate) + hmacRegion := hmacSHA256(hmacDate, region) + hmacService := hmacSHA256(hmacRegion, awsServiceName) + signingKey := hmacSHA256(hmacService, "aws4_request") + return string(signingKey) +} + +func hmacSHA256(key []byte, data string) []byte { + h := hmac.New(sha256.New, key) + h.Write([]byte(data)) + return h.Sum(nil) +} + +func sign(signingKey string, strToSign string) string { + h := hmac.New(sha256.New, []byte(signingKey)) + h.Write([]byte(strToSign)) + sig := hex.EncodeToString(h.Sum(nil)) + return sig +} diff --git a/pkg/sigv4/sigv4_test.go b/pkg/sigv4/sigv4_test.go new file mode 100644 index 0000000..38938e7 --- /dev/null +++ b/pkg/sigv4/sigv4_test.go @@ -0,0 +1,57 @@ +package sigv4 + +import ( + "context" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBuildCanonicalHeaders(t *testing.T) { + t.Parallel() + + serviceName := "mockAPI" + region := "mock-region" + endpoint := "https://" + serviceName + "." + region + ".example.com" + + now := time.Now().UTC() + iSO8601Date := now.Format(timeFormat) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, endpoint, nil) + if err != nil { + t.Fatalf("failed to create request, %v", err) + } + + req.Header.Set("Host", req.Host) + req.Header.Set(amzDateKey, iSO8601Date) + req.Header.Set("InnerSpace", " inner space ") + req.Header.Set("LeadingSpace", " leading-space") + req.Header.Add("MultipleSpace", "no-space") + req.Header.Add("MultipleSpace", "\ttab-space") + req.Header.Add("MultipleSpace", "trailing-space ") + req.Header.Set("NoSpace", "no-space") + req.Header.Set("TabSpace", "\ttab-space\t") + req.Header.Set("TrailingSpace", "trailing-space ") + req.Header.Set("WrappedSpace", " wrapped-space ") + + wantSignedHeader := "host;innerspace;leadingspace;multiplespace;nospace;tabspace;trailingspace;wrappedspace;x-amz-date" + wantCanonicalHeader := strings.Join([]string{ + "host:mockAPI.mock-region.example.com", + "innerspace:inner space", + "leadingspace:leading-space", + "multiplespace:no-space,tab-space,trailing-space", + "nospace:no-space", + "tabspace:tab-space", + "trailingspace:trailing-space", + "wrappedspace:wrapped-space", + "x-amz-date:" + iSO8601Date, + "", + }, "\n") + + gotSignedHeaders, gotCanonicalHeader := buildCanonicalHeaders(req, nil) + assert.Equal(t, wantSignedHeader, gotSignedHeaders) + assert.Equal(t, wantCanonicalHeader, gotCanonicalHeader) +} diff --git a/pkg/sigv4/tripper.go b/pkg/sigv4/tripper.go new file mode 100644 index 0000000..2b75747 --- /dev/null +++ b/pkg/sigv4/tripper.go @@ -0,0 +1,60 @@ +package sigv4 + +import ( + "errors" + "net/http" + "strings" +) + +// Tripper signs each request with sigv4 +type Tripper struct { + config *Config + signer signer + next http.RoundTripper +} + +// Config holds aws access configurations +type Config struct { + Region string + AwsAccessKeyID string + AwsSecretAccessKey string +} + +func (c *Config) validate() error { + if c == nil { + return errors.New("config should not be nil") + } + hasRegion := len(strings.TrimSpace(c.Region)) != 0 + hasAccessID := len(strings.TrimSpace(c.AwsAccessKeyID)) != 0 + hasSecretAccessKey := len(strings.TrimSpace(c.AwsSecretAccessKey)) != 0 + if !hasRegion || !hasAccessID || !hasSecretAccessKey { + return errors.New("sigV4 config `Region`, `AwsAccessKeyID`, `AwsSecretAccessKey` must all be set") + } + return nil +} + +// NewRoundTripper creates a new sigv4 round tripper +func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) { + if err := config.validate(); err != nil { + return nil, err + } + + if next == nil { + next = http.DefaultTransport + } + + tripper := &Tripper{ + config: config, + next: next, + signer: newDefaultSigner(config), + } + return tripper, nil +} + +// RoundTrip implements the tripper interface for sigv4 signing of requests +func (c *Tripper) RoundTrip(req *http.Request) (*http.Response, error) { + if err := c.signer.sign(req); err != nil { + return nil, err + } + return c.next.RoundTrip(req) +} diff --git a/pkg/sigv4/tripper_test.go b/pkg/sigv4/tripper_test.go new file mode 100644 index 0000000..a6e5c03 --- /dev/null +++ b/pkg/sigv4/tripper_test.go @@ -0,0 +1,99 @@ +package sigv4 + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTripper_request_includes_required_headers(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if required headers are present + authorization := r.Header.Get(authorizationHeaderKey) + amzDate := r.Header.Get(amzDateKey) + contentSHA256 := r.Header.Get(contentSHAKey) + + // Respond to the request + w.WriteHeader(http.StatusOK) + + assert.NotEmptyf(t, authorization, "%s header should be present", authorizationHeaderKey) + assert.NotEmptyf(t, amzDate, "%s header should be present", amzDateKey) + assert.NotEmpty(t, contentSHA256, "%s header should be present", contentSHAKey) + })) + defer server.Close() + + client := http.Client{} + tripper, err := NewRoundTripper(&Config{ + Region: "us-east1", + AwsSecretAccessKey: "xyz", + AwsAccessKeyID: "abc", + }, http.DefaultTransport) + if err != nil { + t.Fatal(err) + } + client.Transport = tripper + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL, nil) + if err != nil { + t.Fatal(err) + } + + response, _ := client.Do(req) + _ = response.Body.Close() +} + +func TestConfig_Validation(t *testing.T) { + t.Parallel() + + testCases := []struct { + shouldError bool + arg *Config + }{ + { + shouldError: false, + arg: &Config{ + Region: "us-east1", + AwsAccessKeyID: "someAccessKey", + AwsSecretAccessKey: "someSecretKey", + }, + }, + { + shouldError: true, + arg: nil, + }, + { + shouldError: true, + arg: &Config{ + Region: "us-east1", + }, + }, + { + shouldError: true, + arg: &Config{ + Region: "us-east1", + AwsAccessKeyID: "someAccessKeyId", + }, + }, + { + shouldError: true, + arg: &Config{ + AwsAccessKeyID: "SomeAccessKey", + AwsSecretAccessKey: "SomeSecretKey", + }, + }, + } + + for _, tc := range testCases { + got := tc.arg.validate() + if tc.shouldError { + assert.Error(t, got) + continue + } + assert.NoError(t, got) + } +} diff --git a/pkg/sigv4/util_test.go b/pkg/sigv4/util_test.go new file mode 100644 index 0000000..ceb19f8 --- /dev/null +++ b/pkg/sigv4/util_test.go @@ -0,0 +1,159 @@ +package sigv4 + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStripExcessSpaces(t *testing.T) { + t.Parallel() + + testcases := []struct { + arg string + want string + }{ + { + arg: `AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`, + want: `AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`, + }, + { + arg: "a b c d", + want: "a b c d", + }, + { + arg: " abc def ghi jk ", + want: "abc def ghi jk", + }, + { + arg: " 123 456 789 101112 ", + want: "123 456 789 101112", + }, + { + arg: "12 3 1abc123", + want: "12 3 1abc123", + }, + { + arg: "aaa \t bb", + want: "aaa bb", + }, + } + + for _, tc := range testcases { + assert.Equal(t, tc.want, stripExcessSpaces(tc.arg)) + } +} + +func TestGetUriPath(t *testing.T) { + t.Parallel() + + testcases := map[string]struct { + arg string + want string + }{ + "schema and port": { + arg: "https://localhost:9000", + want: "/", + }, + "schema and no port": { + arg: "https://localhost", + want: "/", + }, + "no schema": { + arg: "localhost:9000", + want: "/", + }, + "no schema + path": { + arg: "localhost:9000/abc123", + want: "/abc123", + }, + "no schema, with separator": { + arg: "//localhost:9000", + want: "/", + }, + "no scheme, no port, with separator": { + arg: "//localhost", + want: "/", + }, + "no scheme, with separator, with path": { + arg: "//localhost:9000/abc123", + want: "/abc123", + }, + "no scheme, no port, with separator, with path": { + arg: "//localhost/abc123", + want: "/abc123", + }, + "no schema, query string": { + arg: "localhost:9000/abc123?efg=456", + want: "/abc123", + }, + } + for name, tc := range testcases { + u, err := url.Parse(tc.arg) + if err != nil { + t.Fatal(err) + } + + got := getURIPath(u) + if tc.want != got { + t.Fatalf("test %v failed, want %v got %v \n", name, tc.want, got) + } + } +} + +func TestGetUriPath_invalid_url_noescape(t *testing.T) { + t.Parallel() + + arg := &url.URL{ + Opaque: "//example.org/bucket/key-._~,!@#$%^&*()", + } + + want := "/bucket/key-._~,!@#$%^&*()" + got := getURIPath(arg) + assert.Equal(t, want, got) +} + +func TestEscapePath(t *testing.T) { + t.Parallel() + + testcases := []struct { + arg string + want string + }{ + { + arg: "/", + want: "/", + }, + { + arg: "/abc", + want: "/abc", + }, + { + arg: "/abc129", + want: "/abc129", + }, + { + arg: "/abc-def", + want: "/abc-def", + }, + { + arg: "/abc.xyz~123-456", + want: "/abc.xyz~123-456", + }, + { + arg: "/abc def-ghi", + want: "/abc%20def-ghi", + }, + { + arg: "abc!def ghi", + want: "abc%21def%20ghi", + }, + } + + noEscape := buildAwsNoEscape() + + for _, tc := range testcases { + assert.Equal(t, tc.want, escapePath(tc.arg, noEscape)) + } +} diff --git a/pkg/sigv4/utils.go b/pkg/sigv4/utils.go new file mode 100644 index 0000000..123a5e0 --- /dev/null +++ b/pkg/sigv4/utils.go @@ -0,0 +1,105 @@ +package sigv4 + +import ( + "bytes" + "fmt" + "net/url" + "strings" +) + +func buildAwsNoEscape() [256]bool { + var noEscape [256]bool + + for i := 0; i < len(noEscape); i++ { + // AWS expects every character except these to be escaped + noEscape[i] = (i >= 'A' && i <= 'Z') || + (i >= 'a' && i <= 'z') || + (i >= '0' && i <= '9') || + i == '-' || + i == '.' || + i == '_' || + i == '~' || + i == '/' + } + return noEscape +} + +// escapePath escapes part of a URL path in Amazon style. +// except for the noEscape provided. +// inspired by github.com/aws/smithy-go/encoding/httpbinding EscapePath method +func escapePath(path string, noEscape [256]bool) string { + var buf bytes.Buffer + for i := 0; i < len(path); i++ { + c := path[i] + if noEscape[c] { + buf.WriteByte(c) + continue + } + fmt.Fprintf(&buf, "%%%02X", c) + } + return buf.String() +} + +// stripExcessSpaces will remove the leading and trailing spaces, and side-by-side spaces are converted +// into a single space. +func stripExcessSpaces(str string) string { + if !strings.Contains(str, " ") && !strings.Contains(str, "\t") { + return str + } + + builder := strings.Builder{} + lastFoundSpace := -1 + const space = ' ' + str = strings.TrimSpace(str) + for i := 0; i < len(str); i++ { + if str[i] == space || str[i] == '\t' { + lastFoundSpace = i + continue + } + + if lastFoundSpace > 0 && builder.Len() != 0 { + builder.WriteByte(space) + } + builder.WriteByte(str[i]) + lastFoundSpace = -1 + } + return builder.String() +} + +// getURIPath returns the escaped URI component from the provided URL. +// Ported from inspired by github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 GetURIPath +func getURIPath(u *url.URL) string { + var uriPath string + + opaque := u.Opaque + if len(opaque) == 0 { + uriPath = u.EscapedPath() + } + + if len(opaque) == 0 && len(uriPath) == 0 { + return "/" + } + + const schemeSep, pathSep, queryStart = "//", "/", "?" + + // Cutout the scheme separator if present. + if strings.Index(opaque, schemeSep) == 0 { + opaque = opaque[len(schemeSep):] + } + + // Cut off the query string if present. + if idx := strings.Index(opaque, queryStart); idx >= 0 { + opaque = opaque[:idx] + } + + // capture URI path starting with first path separator. + if idx := strings.Index(opaque, pathSep); idx >= 0 { + uriPath = opaque[idx:] + } + + if len(uriPath) == 0 { + uriPath = "/" + } + + return uriPath +}