Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support sigv4 signing #169

Merged
merged 14 commits into from
Oct 28, 2024
9 changes: 9 additions & 0 deletions pkg/remote/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"io"
"math"
"net/http"
Expand All @@ -22,6 +23,7 @@ type HTTPConfig struct {
Timeout time.Duration
TLSConfig *tls.Config
BasicAuth *BasicAuth
SigV4 *sigv4.Config
Headers http.Header
}

Expand Down Expand Up @@ -60,6 +62,13 @@ func NewWriteClient(endpoint string, cfg *HTTPConfig) (*WriteClient, error) {
TLSClientConfig: cfg.TLSConfig,
}
}
if cfg.SigV4 != nil {
tripper, err := sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
if err != nil {
return nil, err
}
wc.hc.Transport = tripper
}
return wc, nil
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/remotewrite/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -68,6 +69,15 @@ type Config struct {
TrendStats []string `json:"trendStats"`

StaleMarkers null.Bool `json:"staleMarkers"`

// SigV4Region is the AWS region where the workspace is.
SigV4Region null.String `json:"sigV4Region"`

// SigV4AccessKey is the AWS access key.
SigV4AccessKey null.String `json:"sigV4AccessKey"`

// SigV4SecretKey is the AWS secret key.
SigV4SecretKey null.String `json:"sigV4SecretKey"`
}

// NewConfig creates an Output's configuration.
Expand All @@ -81,6 +91,9 @@ func NewConfig() Config {
Headers: make(map[string]string),
TrendStats: defaultTrendStats,
StaleMarkers: null.BoolFrom(false),
SigV4Region: null.NewString("", false),
SigV4AccessKey: null.NewString("", false),
SigV4SecretKey: null.NewString("", false),
}
}

Expand Down Expand Up @@ -110,6 +123,14 @@ func (conf Config) RemoteConfig() (*remote.HTTPConfig, error) {
hc.TLSConfig.Certificates = []tls.Certificate{cert}
}

if conf.SigV4Region.Valid && conf.SigV4AccessKey.Valid && conf.SigV4SecretKey.Valid {
hc.SigV4 = &sigv4.Config{
Region: conf.SigV4Region.String,
AwsAccessKeyID: conf.SigV4AccessKey.String,
AwsSecretAccessKey: conf.SigV4SecretKey.String,
}
}

if len(conf.Headers) > 0 {
hc.Headers = make(http.Header)
for k, v := range conf.Headers {
Expand Down Expand Up @@ -149,6 +170,18 @@ func (conf Config) Apply(applied Config) Config {
conf.BearerToken = applied.BearerToken
}

if applied.SigV4Region.Valid {
conf.SigV4Region = applied.SigV4Region
}

if applied.SigV4AccessKey.Valid {
conf.SigV4AccessKey = applied.SigV4AccessKey
}

if applied.SigV4SecretKey.Valid {
conf.SigV4SecretKey = applied.SigV4SecretKey
}

if applied.PushInterval.Valid {
conf.PushInterval = applied.PushInterval
}
Expand Down Expand Up @@ -299,6 +332,18 @@ func parseEnvs(env map[string]string) (Config, error) {
}
}

if sigV4Region, sigV4RegionDefined := env["K6_PROMETHEUS_RW_SIGV4_REGION"]; sigV4RegionDefined {
c.SigV4Region = null.StringFrom(sigV4Region)
}

if sigV4AccessKey, sigV4AccessKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_ACCESS_KEY"]; sigV4AccessKeyDefined {
c.SigV4AccessKey = null.StringFrom(sigV4AccessKey)
}

if sigV4SecretKey, sigV4SecretKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_SECRET_KEY"]; sigV4SecretKeyDefined {
c.SigV4SecretKey = null.StringFrom(sigV4SecretKey)
}

if b, err := envBool(env, "K6_PROMETHEUS_RW_TREND_AS_NATIVE_HISTOGRAM"); err != nil {
return c, err
} else if b.Valid {
Expand Down
23 changes: 23 additions & 0 deletions pkg/sigv4/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package sigv4

const (
// Amazon Managed Service for Prometheus
awsServiceName = "aps"

signingAlgorithm = "AWS4-HMAC-SHA256"

authorizationHeaderKey = "Authorization"
amzDateKey = "X-Amz-Date"

// emptyStringSHA256 is the hex encoded sha256 value of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`

// timeFormat is the time format to be used in the X-Amz-Date header or query parameter
timeFormat = "20060102T150405Z"

// shortTimeFormat is the shorten time format used in the credential scope
shortTimeFormat = "20060102"

// contentSHAKey is the SHA256 of request body
contentSHAKey = "X-Amz-Content-Sha256"
)
244 changes: 244 additions & 0 deletions pkg/sigv4/sigv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
package sigv4

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)

type Signer interface {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
Sign(req *http.Request) error
}

type DefaultSigner struct {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
config *Config

// noEscape represents the characters that AWS doesn't escape
noEscape [256]bool
}

func NewDefaultSigner(config *Config) Signer {
ds := &DefaultSigner{
config: config,
noEscape: buildAwsNoEscape(),
}

return ds
}

func (d *DefaultSigner) Sign(req *http.Request) error {
now := time.Now().UTC()
iSO8601Date := now.Format(timeFormat)

credentialScope := buildCredentialScope(now, d.config.Region)

payloadHash, err := d.getPayloadHash(req)
if err != nil {
return err
}

req.Header.Set("Host", req.Host)
req.Header.Set(amzDateKey, iSO8601Date)
req.Header.Set(contentSHAKey, payloadHash)

signedHeadersStr, canonicalHeaderStr := buildCanonicalHeaders(req)

canonicalQueryString := getCanonicalQueryString(req.URL)
canonicalReq := buildCanonicalString(
req.Method,
getCanonicalURI(req.URL, d.noEscape),
canonicalQueryString,
canonicalHeaderStr,
signedHeadersStr,
payloadHash,
)

signature := sign(
deriveKey(d.config.AwsSecretAccessKey, d.config.Region),
buildStringToSign(iSO8601Date, credentialScope, canonicalReq),
)

authorizationHeader := fmt.Sprintf(
"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgorithm,
d.config.AwsAccessKeyID,
credentialScope,
signedHeadersStr,
signature,
)

req.URL.RawQuery = canonicalQueryString
req.Header.Set(authorizationHeaderKey, authorizationHeader)
return nil
}

func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
if req.Body == nil {
return emptyStringSHA256, nil
}

reqBody, err := io.ReadAll(req.Body)
if err != nil {
return "", err
}
reqBodyBuffer := bytes.NewReader(reqBody)

hash := sha256.New()
if _, err := io.Copy(hash, reqBodyBuffer); err != nil {
return "", err
}

payloadHash := hex.EncodeToString(hash.Sum(nil))

// ensuring that we keep the request body intact for next tripper
req.Body = io.NopCloser(bytes.NewReader(reqBody))

return payloadHash, nil
}

func buildCredentialScope(signingTime time.Time, region string) string {
return fmt.Sprintf(
"%s/%s/%s/aws4_request",
signingTime.UTC().Format(shortTimeFormat),
region,
awsServiceName,
)
}

func buildCanonicalString(method, uri, query, canonicalHeaders, signedHeaders, payloadHash string) string {
return strings.Join([]string{
method,
uri,
query,
canonicalHeaders,
signedHeaders,
payloadHash,
}, "\n")
}

var ignoredHeaders = map[string]struct{}{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
"Expect": struct{}{},
}

// buildCanonicalHeaders is mostly ported from https://github.com/aws/aws-sdk-go-v2/aws/signer/v4 buildCanonicalHeaders
func buildCanonicalHeaders(req *http.Request) (signedHeaders, canonicalHeadersStr string) {
const hostHeader, contentLengthHeader = "host", "content-length"
host, header, length := req.Host, req.Header, req.ContentLength

signed := make(http.Header)
headers := append([]string{}, hostHeader)
signed[hostHeader] = append(signed[hostHeader], host)

if length > 0 {
headers = append(headers, contentLengthHeader)
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
}

for k, v := range header {
if _, ok := ignoredHeaders[k]; ok {
continue
}

if strings.EqualFold(k, contentLengthHeader) {
// prevent signing already handled content-length header.
continue
}

lowerCaseKey := strings.ToLower(k)
if _, ok := signed[lowerCaseKey]; ok {
// include additional values
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
continue
}

headers = append(headers, lowerCaseKey)
signed[lowerCaseKey] = v
}

// aws requires headers to keys to be sorted
sort.Strings(headers)
signedHeaders = strings.Join(headers, ";")

var canonicalHeaders strings.Builder
for _, h := range headers {
if h == hostHeader {
canonicalHeaders.WriteString(fmt.Sprintf("%s:%s\n", hostHeader, stripExcessSpaces(host)))
continue
}

canonicalHeaders.WriteString(fmt.Sprintf("%s:", h))
values := signed[h]
for j, v := range values {
cleanedValue := strings.TrimSpace(stripExcessSpaces(v))
canonicalHeaders.WriteString(cleanedValue)
if j < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
canonicalHeaders.WriteRune('\n')
}
canonicalHeadersStr = canonicalHeaders.String()
return signedHeaders, canonicalHeadersStr
}

func getCanonicalURI(u *url.URL, noEscape [256]bool) string {
return escapePath(getURIPath(u), noEscape)
}

func getCanonicalQueryString(u *url.URL) string {
query := u.Query()

// Sort Each Query Key's Values
for key := range query {
sort.Strings(query[key])
}

var rawQuery strings.Builder
rawQuery.WriteString(strings.Replace(query.Encode(), "+", "%20", -1))
return rawQuery.String()
}

func buildStringToSign(amzDate, credentialScope, canonicalRequestString string) string {
hash := sha256.New()
hash.Write([]byte(canonicalRequestString))
return strings.Join([]string{
signingAlgorithm,
amzDate,
credentialScope,
hex.EncodeToString(hash.Sum(nil)),
}, "\n")
}

func deriveKey(secretKey, region string) string {
signingDate := time.Now().UTC().Format(shortTimeFormat)
hmacDate := hmacSHA256([]byte("AWS4"+secretKey), signingDate)
hmacRegion := hmacSHA256(hmacDate, region)
hmacService := hmacSHA256(hmacRegion, awsServiceName)
signingKey := hmacSHA256(hmacService, "aws4_request")
return string(signingKey)
}

func hmacSHA256(key []byte, data string) []byte {
h := hmac.New(sha256.New, key)
h.Write([]byte(data))
return h.Sum(nil)
}

func sign(signingKey string, strToSign string) string {
h := hmac.New(sha256.New, []byte(signingKey))
h.Write([]byte(strToSign))
sig := hex.EncodeToString(h.Sum(nil))
return sig
}
Loading