Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support sigv4 signing #169

Merged
merged 14 commits into from
Oct 28, 2024
5 changes: 5 additions & 0 deletions pkg/remote/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"io"
"math"
"net/http"
Expand All @@ -22,6 +23,7 @@ type HTTPConfig struct {
Timeout time.Duration
TLSConfig *tls.Config
BasicAuth *BasicAuth
SigV4 *sigv4.Config
Headers http.Header
}

Expand Down Expand Up @@ -60,6 +62,9 @@ func NewWriteClient(endpoint string, cfg *HTTPConfig) (*WriteClient, error) {
TLSClientConfig: cfg.TLSConfig,
}
}
if cfg.SigV4 != nil {
wc.hc.Transport = sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
}
return wc, nil
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/remotewrite/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -68,6 +69,15 @@ type Config struct {
TrendStats []string `json:"trendStats"`

StaleMarkers null.Bool `json:"staleMarkers"`

// Sigv4Region is the AWS region where the workspace is.
Sigv4Region null.String `json:"sigV4Region"`

// Sigv4AccessKey is the AWS access key.
Sigv4AccessKey null.String `json:"sigV4AccessKey"`

// Sigv4SecretKey is the AWS secret key.
Sigv4SecretKey null.String `json:"sigV4SecretKey"`
}

// NewConfig creates an Output's configuration.
Expand All @@ -81,6 +91,9 @@ func NewConfig() Config {
Headers: make(map[string]string),
TrendStats: defaultTrendStats,
StaleMarkers: null.BoolFrom(false),
Sigv4Region: null.NewString("", false),
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
Sigv4AccessKey: null.NewString("", false),
Sigv4SecretKey: null.NewString("", false),
}
}

Expand Down Expand Up @@ -110,6 +123,14 @@ func (conf Config) RemoteConfig() (*remote.HTTPConfig, error) {
hc.TLSConfig.Certificates = []tls.Certificate{cert}
}

if conf.Sigv4Region.Valid && conf.Sigv4AccessKey.Valid && conf.Sigv4SecretKey.Valid {
hc.SigV4 = &sigv4.Config{
Region: conf.Sigv4Region.String,
AwsAccessKeyID: conf.Sigv4AccessKey.String,
AwsSecretAccessKey: conf.Sigv4SecretKey.String,
}
}

if len(conf.Headers) > 0 {
hc.Headers = make(http.Header)
for k, v := range conf.Headers {
Expand Down Expand Up @@ -149,6 +170,18 @@ func (conf Config) Apply(applied Config) Config {
conf.BearerToken = applied.BearerToken
}

if applied.Sigv4Region.Valid {
conf.Sigv4Region = applied.Sigv4Region
}

if applied.Sigv4AccessKey.Valid {
conf.Sigv4AccessKey = applied.Sigv4AccessKey
}

if applied.Sigv4SecretKey.Valid {
conf.Sigv4SecretKey = applied.Sigv4SecretKey
}

if applied.PushInterval.Valid {
conf.PushInterval = applied.PushInterval
}
Expand Down Expand Up @@ -299,6 +332,18 @@ func parseEnvs(env map[string]string) (Config, error) {
}
}

if sigv4Region, sigv4RegionDefined := env["K6_PROMETHEUS_RW_SIGV4_REGION"]; sigv4RegionDefined {
c.Sigv4Region = null.StringFrom(sigv4Region)
}

if sigv4AccessKey, sigv4AccessKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_ACCESS_KEY"]; sigv4AccessKeyDefined {
c.Sigv4AccessKey = null.StringFrom(sigv4AccessKey)
}

if sigv4SecretKey, sigv4SecretKeyDefined := env["K6_PROMETHEUS_RW_SIGV4_SECRET_KEY"]; sigv4SecretKeyDefined {
c.Sigv4SecretKey = null.StringFrom(sigv4SecretKey)
}

if b, err := envBool(env, "K6_PROMETHEUS_RW_TREND_AS_NATIVE_HISTOGRAM"); err != nil {
return c, err
} else if b.Valid {
Expand Down
200 changes: 200 additions & 0 deletions pkg/sigv4/sigv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package sigv4

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
)

const signingAlgo = "AWS4-HMAC-SHA256"
const awsServiceName = "aps"

type Signer interface {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
Sign(req *http.Request) error
}

type DefaultSigner struct {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
iSO8601Date string
canonicalHeaders string
signedHeaders string
credentialScope string
config *Config
payloadHash string
}

func NewDefaultSigner(config *Config) Signer {
return &DefaultSigner{
config: config,
}
}

func (d *DefaultSigner) Sign(req *http.Request) error {
now := time.Now().UTC()
d.iSO8601Date = now.Format("20060102T150405Z")
d.credentialScope = fmt.Sprintf(
"%s/%s/%s/aws4_request",
now.UTC().Format("20060102"),
d.config.Region,
awsServiceName,
)

payloadHash, err := d.getPayloadHash(req)
if err != nil {
return err
}

d.payloadHash = payloadHash
d.addRequiredHeaders(req)
d.canonicalHeaders, d.signedHeaders = d.getCanonicalAndSignedHeaders(req)
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved

canonicalReq := d.createCanonicalRequest(req)
stringToSign, err := d.createStringToSign(canonicalReq)
if err != nil {
return err
}

signature := d.sign(d.createSigningKey(), stringToSign)
authorizationHeader := fmt.Sprintf(
"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgo,
d.config.AwsAccessKeyID,
d.credentialScope,
d.signedHeaders,
signature,
)
req.Header.Set("Authorization", authorizationHeader)
return nil
}

func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
if req.Body == nil {
return hex.EncodeToString(sha256.New().Sum(nil)), nil
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
}

reqBody, err := io.ReadAll(req.Body)
if err != nil {
return "", err
}
reqBodyBuffer := bytes.NewReader(reqBody)

hash := sha256.New()
if _, err := io.Copy(hash, reqBodyBuffer); err != nil {
return "", err
}

payloadHash := hex.EncodeToString(hash.Sum(nil))

// ensuring that we keep the request body intact for next tripper
req.Body = io.NopCloser(bytes.NewReader(reqBody))

return payloadHash, nil
}

func (d *DefaultSigner) addRequiredHeaders(req *http.Request) {
req.Header.Set("Host", req.Host)
req.Header.Set("x-amz-date", d.iSO8601Date)
req.Header.Set("x-amz-content-sha256", d.payloadHash)
}

func (d *DefaultSigner) getCanonicalAndSignedHeaders(req *http.Request) (string, string) {
var headers []string
var signedHeaders []string

for key, value := range req.Header {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
lowercaseKey := strings.ToLower(key)
encodedValue := strings.TrimSpace(strings.Join(value, ","))
headers = append(headers, lowercaseKey+":"+encodedValue)
signedHeaders = append(signedHeaders, lowercaseKey)
}

sort.Strings(headers)
sort.Strings(signedHeaders)

canonicalHeaders := strings.Join(headers, "\n") + "\n"
canonicalSignedHeaders := strings.Join(signedHeaders, ";")
return canonicalHeaders, canonicalSignedHeaders
}

func (d *DefaultSigner) createCanonicalRequest(req *http.Request) string {
return strings.Join([]string{
req.Method,
d.getCanonicalURI(req.URL),
d.getCanonicalQueryString(req.URL),
d.canonicalHeaders,
d.signedHeaders,
d.payloadHash,
}, "\n")
}

func (d *DefaultSigner) getCanonicalURI(u *url.URL) string {
if u.Path == "" {
return "/"
}

// The spec requires not to encode `/`
segments := strings.Split(u.Path, "/")
for i, segment := range segments {
segments[i] = url.PathEscape(segment)
}

return strings.Join(segments, "/")
}

func (d *DefaultSigner) getCanonicalQueryString(u *url.URL) string {
queryParams := u.Query()
var queryPairs []string

for key, values := range queryParams {
for _, value := range values {
queryPairs = append(queryPairs, url.QueryEscape(key)+"="+url.QueryEscape(value))
}
}

sort.Strings(queryPairs)

return strings.Join(queryPairs, "&")
}

func (d *DefaultSigner) createStringToSign(canonicalRequest string) (string, error) {
hash := sha256.New()
if _, err := hash.Write([]byte(canonicalRequest)); err != nil {
return "", err
}
return fmt.Sprintf(
"%s\n%s\n%s\n%s",
signingAlgo,
d.iSO8601Date,
d.credentialScope,
hex.EncodeToString(hash.Sum(nil)),
), nil
}

func (d *DefaultSigner) createSigningKey() string {
signingDate := time.Now().UTC().Format("20060102")
dateKey := d.hmacSHA256([]byte("AWS4"+d.config.AwsSecretAccessKey), signingDate)
dateRegionKey := d.hmacSHA256(dateKey, d.config.Region)
dateRegionServiceKey := d.hmacSHA256(dateRegionKey, awsServiceName)
signingKey := d.hmacSHA256(dateRegionServiceKey, "aws4_request")
return string(signingKey)
}

func (d *DefaultSigner) hmacSHA256(key []byte, data string) []byte {
h := hmac.New(sha256.New, key)
h.Write([]byte(data))
return h.Sum(nil)
}

func (d *DefaultSigner) sign(signingKey string, strToSign string) string {
h := hmac.New(sha256.New, []byte(signingKey))
h.Write([]byte(strToSign))
sig := hex.EncodeToString(h.Sum(nil))
return sig
}
35 changes: 35 additions & 0 deletions pkg/sigv4/tripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sigv4

import (
"net/http"
)

type Tripper struct {
config *Config
signer Signer
next http.RoundTripper
}

type Config struct {
Region string
AwsSecretAccessKey string
AwsAccessKeyID string
}

func NewRoundTripper(config *Config, next http.RoundTripper) *Tripper {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
if next == nil {
next = http.DefaultTransport
}
return &Tripper{
config: config,
next: next,
signer: NewDefaultSigner(config),
}
}

func (c *Tripper) RoundTrip(req *http.Request) (*http.Response, error) {
if err := c.signer.Sign(req); err != nil {
return nil, err
}
return c.next.RoundTrip(req)
}
35 changes: 35 additions & 0 deletions pkg/sigv4/tripper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sigv4_test

import (
"github.com/grafana/xk6-output-prometheus-remote/pkg/sigv4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)

func TestTripper_request_includes_required_headers(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
olegbespalov marked this conversation as resolved.
Show resolved Hide resolved
// Check if required headers are present
authorization := r.Header.Get("Authorization")
amzDate := r.Header.Get("x-amz-date")
contentSHA256 := r.Header.Get("x-amz-content-sha256")

// Respond to the request
w.WriteHeader(http.StatusOK)

assert.NotEmpty(t, authorization, "Authorization header should be present")
assert.NotEmpty(t, amzDate, "x-amz-date header should be present")
assert.NotEmpty(t, contentSHA256, "x-amz-content-sha256 header should be present")
}))
defer server.Close()
client := http.Client{}
client.Transport = sigv4.NewRoundTripper(&sigv4.Config{}, http.DefaultTransport)

req, err := http.NewRequest("POST", server.URL, nil)
if err != nil {
t.Fatal(err)
}

client.Do(req)
}