Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support sigv4 signing #169

Merged
merged 14 commits into from
Oct 28, 2024
9 changes: 9 additions & 0 deletions pkg/remote/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"io"
"math"
"net/http"
Expand All @@ -22,6 +23,7 @@ type HTTPConfig struct {
Timeout time.Duration
TLSConfig *tls.Config
BasicAuth *BasicAuth
SigV4 *sigv4.Config
Headers http.Header
}

Expand Down Expand Up @@ -60,6 +62,13 @@ func NewWriteClient(endpoint string, cfg *HTTPConfig) (*WriteClient, error) {
TLSClientConfig: cfg.TLSConfig,
}
}
if cfg.SigV4 != nil {
tripper, err := sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
if err != nil {
return nil, err
}
wc.hc.Transport = tripper
}
return wc, nil
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/remotewrite/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -68,6 +69,15 @@ type Config struct {
TrendStats []string `json:"trendStats"`

StaleMarkers null.Bool `json:"staleMarkers"`

// Sigv4Region is the AWS region where the workspace is.
Sigv4Region null.String `json:"sigV4Region"`

// Sigv4AccessKey is the AWS access key.
Sigv4AccessKey null.String `json:"sigV4AccessKey"`

// Sigv4SecretKey is the AWS secret key.
Sigv4SecretKey null.String `json:"sigV4SecretKey"`
}

// NewConfig creates an Output's configuration.
Expand All @@ -81,6 +91,9 @@ func NewConfig() Config {
Headers: make(map[string]string),
TrendStats: defaultTrendStats,
StaleMarkers: null.BoolFrom(false),
Sigv4Region: null.NewString("", false),
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
Sigv4AccessKey: null.NewString("", false),
Sigv4SecretKey: null.NewString("", false),
}
}

Expand Down Expand Up @@ -110,6 +123,14 @@ func (conf Config) RemoteConfig() (*remote.HTTPConfig, error) {
hc.TLSConfig.Certificates = []tls.Certificate{cert}
}

if conf.Sigv4Region.Valid && conf.Sigv4AccessKey.Valid && conf.Sigv4SecretKey.Valid {
hc.SigV4 = &sigv4.Config{
Region: conf.Sigv4Region.String,
AwsAccessKeyID: conf.Sigv4AccessKey.String,
AwsSecretAccessKey: conf.Sigv4SecretKey.String,
}
}

if len(conf.Headers) > 0 {
hc.Headers = make(http.Header)
for k, v := range conf.Headers {
Expand Down Expand Up @@ -149,6 +170,18 @@ func (conf Config) Apply(applied Config) Config {
conf.BearerToken = applied.BearerToken
}

if applied.Sigv4Region.Valid {
conf.Sigv4Region = applied.Sigv4Region
}

if applied.Sigv4AccessKey.Valid {
conf.Sigv4AccessKey = applied.Sigv4AccessKey
}

if applied.Sigv4SecretKey.Valid {
conf.Sigv4SecretKey = applied.Sigv4SecretKey
}

if applied.PushInterval.Valid {
conf.PushInterval = applied.PushInterval
}
Expand Down Expand Up @@ -299,6 +332,18 @@ func parseEnvs(env map[string]string) (Config, error) {
}
}

if sigv4Region, sigv4RegionDefined := env["K6_PROMETHEUS_RW_SIGV4_REGION"]; sigv4RegionDefined {
c.Sigv4Region = null.StringFrom(sigv4Region)
}

if sigv4AccessKey, sigv4AccessKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_ACCESS_KEY"]; sigv4AccessKeyDefined {
c.Sigv4AccessKey = null.StringFrom(sigv4AccessKey)
}

if sigv4SecretKey, sigv4SecretKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_SECRET_KEY"]; sigv4SecretKeyDefined {
c.Sigv4SecretKey = null.StringFrom(sigv4SecretKey)
}

if b, err := envBool(env, "K6_PROMETHEUS_RW_TREND_AS_NATIVE_HISTOGRAM"); err != nil {
return c, err
} else if b.Valid {
Expand Down
21 changes: 21 additions & 0 deletions pkg/sigv4/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package sigv4

const (
awsServiceName = "aps"
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
signingAlgorithm = "AWS4-HMAC-SHA256"

authorizationHeaderKey = "Authorization"
amzDateKey = "X-Amz-Date"

// emptyStringSHA256 is the hex encoded sha256 value of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`

// timeFormat is the time format to be used in the X-Amz-Date header or query parameter
timeFormat = "20060102T150405Z"

// shortTimeFormat is the shorten time format used in the credential scope
shortTimeFormat = "20060102"

// contentSHAKey is the SHA256 of request body
contentSHAKey = "X-Amz-Content-Sha256"
)
254 changes: 254 additions & 0 deletions pkg/sigv4/sigv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
package sigv4

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)

type Signer interface {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
Sign(req *http.Request) error
}

type DefaultSigner struct {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
config *Config
}

func NewDefaultSigner(config *Config) Signer {
// initialize noEscape array. This way we can avoid using init() functions
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}

return &DefaultSigner{config: config}
}

func (d *DefaultSigner) Sign(req *http.Request) error {
now := time.Now().UTC()
iSO8601Date := now.Format(timeFormat)

credentialScope := buildCredentialScope(now, d.config.Region)

payloadHash, err := d.getPayloadHash(req)
if err != nil {
return err
}

req.Header.Set("Host", req.Host)
req.Header.Set(amzDateKey, iSO8601Date)
req.Header.Set(contentSHAKey, payloadHash)

_, signedHeadersStr, canonicalHeaderStr := buildCanonicalHeaders(req)

canonicalQueryString := getCanonicalQueryString(req.URL)
canonicalReq := buildCanonicalString(
req.Method,
getCanonicalURI(req.URL),
canonicalQueryString,
canonicalHeaderStr,
signedHeadersStr,
payloadHash,
)

signature := sign(
deriveKey(d.config.AwsSecretAccessKey, d.config.Region),
buildStringToSign(iSO8601Date, credentialScope, canonicalReq),
)

authorizationHeader := fmt.Sprintf(
"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgorithm,
d.config.AwsAccessKeyID,
credentialScope,
signedHeadersStr,
signature,
)

req.URL.RawQuery = canonicalQueryString
req.Header.Set(authorizationHeaderKey, authorizationHeader)
return nil
}

func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
if req.Body == nil {
return emptyStringSHA256, nil
}

reqBody, err := io.ReadAll(req.Body)
if err != nil {
return "", err
}
reqBodyBuffer := bytes.NewReader(reqBody)

hash := sha256.New()
if _, err := io.Copy(hash, reqBodyBuffer); err != nil {
return "", err
}

payloadHash := hex.EncodeToString(hash.Sum(nil))

// ensuring that we keep the request body intact for next tripper
req.Body = io.NopCloser(bytes.NewReader(reqBody))

return payloadHash, nil
}

func buildCredentialScope(signingTime time.Time, region string) string {
return fmt.Sprintf(
"%s/%s/%s/aws4_request",
signingTime.UTC().Format(shortTimeFormat),
region,
awsServiceName,
)
}

func buildCanonicalString(method, uri, query, canonicalHeaders, signedHeaders, payloadHash string) string {
return strings.Join([]string{
method,
uri,
query,
canonicalHeaders,
signedHeaders,
payloadHash,
}, "\n")
}

var ignoredHeaders = map[string]struct{}{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
"Expect": struct{}{},
}

func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
host, header, length := req.Host, req.Header, req.ContentLength

signed = make(http.Header)

var headers []string
const hostHeader = "host"
headers = append(headers, hostHeader)
signed[hostHeader] = append(signed[hostHeader], host)

const contentLengthHeader = "content-length"
if length > 0 {
headers = append(headers, contentLengthHeader)
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
}

for k, v := range header {
if _, ok := ignoredHeaders[k]; ok {
continue // ignored header
}
if strings.EqualFold(k, contentLengthHeader) {
// prevent signing already handled content-length header.
continue
}

lowerCaseKey := strings.ToLower(k)
if _, ok := signed[lowerCaseKey]; ok {
// include additional values
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
continue
}

headers = append(headers, lowerCaseKey)
signed[lowerCaseKey] = v
}
sort.Strings(headers)

signedHeaders = strings.Join(headers, ";")

var canonicalHeaders strings.Builder
n := len(headers)
const colon = ':'
for i := 0; i < n; i++ {
if headers[i] == hostHeader {
canonicalHeaders.WriteString(hostHeader)
canonicalHeaders.WriteRune(colon)
canonicalHeaders.WriteString(stripExcessSpaces(host))
} else {
canonicalHeaders.WriteString(headers[i])
canonicalHeaders.WriteRune(colon)
// Trim out leading, trailing, and dedup inner spaces from signed header values.
values := signed[headers[i]]
for j, v := range values {
cleanedValue := strings.TrimSpace(stripExcessSpaces(v))
canonicalHeaders.WriteString(cleanedValue)
if j < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
}
canonicalHeaders.WriteRune('\n')
}
canonicalHeadersStr = canonicalHeaders.String()

return signed, signedHeaders, canonicalHeadersStr
}

func getCanonicalURI(u *url.URL) string {
return escapePath(getURIPath(u), false)
}

func getCanonicalQueryString(u *url.URL) string {
query := u.Query()

// Sort Each Query Key's Values
for key := range query {
sort.Strings(query[key])
}

var rawQuery strings.Builder
rawQuery.WriteString(strings.Replace(query.Encode(), "+", "%20", -1))
return rawQuery.String()
}

func buildStringToSign(amzDate, credentialScope, canonicalRequestString string) string {
hash := sha256.New()
hash.Write([]byte(canonicalRequestString))
return strings.Join([]string{
signingAlgorithm,
amzDate,
credentialScope,
hex.EncodeToString(hash.Sum(nil)),
}, "\n")
}

func deriveKey(secretKey, region string) string {
signingDate := time.Now().UTC().Format(shortTimeFormat)
hmacDate := hmacSHA256([]byte("AWS4"+secretKey), signingDate)
hmacRegion := hmacSHA256(hmacDate, region)
hmacService := hmacSHA256(hmacRegion, awsServiceName)
signingKey := hmacSHA256(hmacService, "aws4_request")
return string(signingKey)
}

func hmacSHA256(key []byte, data string) []byte {
h := hmac.New(sha256.New, key)
h.Write([]byte(data))
return h.Sum(nil)
}

func sign(signingKey string, strToSign string) string {
h := hmac.New(sha256.New, []byte(signingKey))
h.Write([]byte(strToSign))
sig := hex.EncodeToString(h.Sum(nil))
return sig
}
Loading