Skip to content

Commit

Permalink
- Ported code from aws-sdk-go for buildingCanonicalHeaders
Browse files Browse the repository at this point in the history
- Ported code from aws-sdk-go for proper query string handling
- Refactored code to reduce methods
- Added validations for tripper
  • Loading branch information
obanby committed Sep 26, 2024
1 parent 5f2f038 commit 4ac35b2
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 109 deletions.
6 changes: 5 additions & 1 deletion pkg/remote/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ func NewWriteClient(endpoint string, cfg *HTTPConfig) (*WriteClient, error) {
}
}
if cfg.SigV4 != nil {
wc.hc.Transport = sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
tripper, err := sigv4.NewRoundTripper(cfg.SigV4, wc.hc.Transport)
if err != nil {
return nil, err
}
wc.hc.Transport = tripper
}
return wc, nil
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/sigv4/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package sigv4

const (
awsServiceName = "aps"
signingAlgorithm = "AWS4-HMAC-SHA256"

authorizationHeaderKey = "Authorization"
amzDateKey = "X-Amz-Date"

// emptyStringSHA256 is the hex encoded sha256 value of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`

// timeFormat is the time format to be used in the X-Amz-Date header or query parameter
timeFormat = "20060102T150405Z"

// shortTimeFormat is the shorten time format used in the credential scope
shortTimeFormat = "20060102"

// contentSHAKey is the SHA256 of request body
contentSHAKey = "X-Amz-Content-Sha256"
)
248 changes: 151 additions & 97 deletions pkg/sigv4/sigv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,84 @@ import (
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)

const signingAlgo = "AWS4-HMAC-SHA256"
const awsServiceName = "aps"

type Signer interface {
Sign(req *http.Request) error
}

type DefaultSigner struct {
iSO8601Date string
canonicalHeaders string
signedHeaders string
credentialScope string
config *Config
payloadHash string
config *Config
}

func NewDefaultSigner(config *Config) Signer {
return &DefaultSigner{
config: config,
// initialize noEscape array. This way we can avoid using init() functions
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}

return &DefaultSigner{config: config}
}

func (d *DefaultSigner) Sign(req *http.Request) error {
now := time.Now().UTC()
d.iSO8601Date = now.Format("20060102T150405Z")
d.credentialScope = fmt.Sprintf(
"%s/%s/%s/aws4_request",
now.UTC().Format("20060102"),
d.config.Region,
awsServiceName,
)
iSO8601Date := now.Format(timeFormat)

credentialScope := buildCredentialScope(now, d.config.Region)

payloadHash, err := d.getPayloadHash(req)
if err != nil {
return err
}

d.payloadHash = payloadHash
d.addRequiredHeaders(req)
d.canonicalHeaders, d.signedHeaders = d.getCanonicalAndSignedHeaders(req)
req.Header.Set("Host", req.Host)
req.Header.Set(amzDateKey, iSO8601Date)
req.Header.Set(contentSHAKey, payloadHash)

canonicalReq := d.createCanonicalRequest(req)
stringToSign, err := d.createStringToSign(canonicalReq)
if err != nil {
return err
}
_, signedHeadersStr, canonicalHeaderStr := buildCanonicalHeaders(req)

canonicalQueryString := getCanonicalQueryString(req.URL)
canonicalReq := buildCanonicalString(
req.Method,
getCanonicalURI(req.URL),
canonicalQueryString,
canonicalHeaderStr,
signedHeadersStr,
payloadHash,
)

signature := sign(
deriveKey(d.config.AwsSecretAccessKey, d.config.Region),
buildStringToSign(iSO8601Date, credentialScope, canonicalReq),
)

signature := d.sign(d.createSigningKey(), stringToSign)
authorizationHeader := fmt.Sprintf(
"%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgo,
signingAlgorithm,
d.config.AwsAccessKeyID,
d.credentialScope,
d.signedHeaders,
credentialScope,
signedHeadersStr,
signature,
)
req.Header.Set("Authorization", authorizationHeader)

req.URL.RawQuery = canonicalQueryString
req.Header.Set(authorizationHeaderKey, authorizationHeader)
return nil
}

func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
if req.Body == nil {
return hex.EncodeToString(sha256.New().Sum(nil)), nil
return emptyStringSHA256, nil
}

reqBody, err := io.ReadAll(req.Body)
Expand All @@ -98,101 +109,144 @@ func (d *DefaultSigner) getPayloadHash(req *http.Request) (string, error) {
return payloadHash, nil
}

func (d *DefaultSigner) addRequiredHeaders(req *http.Request) {
req.Header.Set("Host", req.Host)
req.Header.Set("x-amz-date", d.iSO8601Date)
req.Header.Set("x-amz-content-sha256", d.payloadHash)
func buildCredentialScope(signingTime time.Time, region string) string {
return fmt.Sprintf(
"%s/%s/%s/aws4_request",
signingTime.UTC().Format(shortTimeFormat),
region,
awsServiceName,
)
}

func (d *DefaultSigner) getCanonicalAndSignedHeaders(req *http.Request) (string, string) {
var headers []string
var signedHeaders []string
func buildCanonicalString(method, uri, query, canonicalHeaders, signedHeaders, payloadHash string) string {
return strings.Join([]string{
method,
uri,
query,
canonicalHeaders,
signedHeaders,
payloadHash,
}, "\n")
}

for key, value := range req.Header {
lowercaseKey := strings.ToLower(key)
encodedValue := strings.TrimSpace(strings.Join(value, ","))
headers = append(headers, lowercaseKey+":"+encodedValue)
signedHeaders = append(signedHeaders, lowercaseKey)
}
var ignoredHeaders = map[string]struct{}{
"Authorization": struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
"Expect": struct{}{},
}

sort.Strings(headers)
sort.Strings(signedHeaders)
func buildCanonicalHeaders(req *http.Request) (signed http.Header, signedHeaders, canonicalHeadersStr string) {
host, header, length := req.Host, req.Header, req.ContentLength

canonicalHeaders := strings.Join(headers, "\n") + "\n"
canonicalSignedHeaders := strings.Join(signedHeaders, ";")
return canonicalHeaders, canonicalSignedHeaders
}
signed = make(http.Header)

func (d *DefaultSigner) createCanonicalRequest(req *http.Request) string {
return strings.Join([]string{
req.Method,
d.getCanonicalURI(req.URL),
d.getCanonicalQueryString(req.URL),
d.canonicalHeaders,
d.signedHeaders,
d.payloadHash,
}, "\n")
}
var headers []string
const hostHeader = "host"
headers = append(headers, hostHeader)
signed[hostHeader] = append(signed[hostHeader], host)

const contentLengthHeader = "content-length"
if length > 0 {
headers = append(headers, contentLengthHeader)
signed[contentLengthHeader] = append(signed[contentLengthHeader], strconv.FormatInt(length, 10))
}

for k, v := range header {
if _, ok := ignoredHeaders[k]; ok {
continue // ignored header
}
if strings.EqualFold(k, contentLengthHeader) {
// prevent signing already handled content-length header.
continue
}

func (d *DefaultSigner) getCanonicalURI(u *url.URL) string {
if u.Path == "" {
return "/"
lowerCaseKey := strings.ToLower(k)
if _, ok := signed[lowerCaseKey]; ok {
// include additional values
signed[lowerCaseKey] = append(signed[lowerCaseKey], v...)
continue
}

headers = append(headers, lowerCaseKey)
signed[lowerCaseKey] = v
}
sort.Strings(headers)

// The spec requires not to encode `/`
segments := strings.Split(u.Path, "/")
for i, segment := range segments {
segments[i] = url.PathEscape(segment)
signedHeaders = strings.Join(headers, ";")

var canonicalHeaders strings.Builder
n := len(headers)
const colon = ':'
for i := 0; i < n; i++ {
if headers[i] == hostHeader {
canonicalHeaders.WriteString(hostHeader)
canonicalHeaders.WriteRune(colon)
canonicalHeaders.WriteString(stripExcessSpaces(host))
} else {
canonicalHeaders.WriteString(headers[i])
canonicalHeaders.WriteRune(colon)
// Trim out leading, trailing, and dedup inner spaces from signed header values.
values := signed[headers[i]]
for j, v := range values {
cleanedValue := strings.TrimSpace(stripExcessSpaces(v))
canonicalHeaders.WriteString(cleanedValue)
if j < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
}
canonicalHeaders.WriteRune('\n')
}
canonicalHeadersStr = canonicalHeaders.String()

return strings.Join(segments, "/")
return signed, signedHeaders, canonicalHeadersStr
}

func (d *DefaultSigner) getCanonicalQueryString(u *url.URL) string {
queryParams := u.Query()
var queryPairs []string
func getCanonicalURI(u *url.URL) string {
return escapePath(getURIPath(u), false)
}

for key, values := range queryParams {
for _, value := range values {
queryPairs = append(queryPairs, url.QueryEscape(key)+"="+url.QueryEscape(value))
}
}
func getCanonicalQueryString(u *url.URL) string {
query := u.Query()

sort.Strings(queryPairs)
// Sort Each Query Key's Values
for key := range query {
sort.Strings(query[key])
}

return strings.Join(queryPairs, "&")
var rawQuery strings.Builder
rawQuery.WriteString(strings.Replace(query.Encode(), "+", "%20", -1))
return rawQuery.String()
}

func (d *DefaultSigner) createStringToSign(canonicalRequest string) (string, error) {
func buildStringToSign(amzDate, credentialScope, canonicalRequestString string) string {
hash := sha256.New()
if _, err := hash.Write([]byte(canonicalRequest)); err != nil {
return "", err
}
return fmt.Sprintf(
"%s\n%s\n%s\n%s",
signingAlgo,
d.iSO8601Date,
d.credentialScope,
hash.Write([]byte(canonicalRequestString))
return strings.Join([]string{
signingAlgorithm,
amzDate,
credentialScope,
hex.EncodeToString(hash.Sum(nil)),
), nil
}, "\n")
}

func (d *DefaultSigner) createSigningKey() string {
signingDate := time.Now().UTC().Format("20060102")
dateKey := d.hmacSHA256([]byte("AWS4"+d.config.AwsSecretAccessKey), signingDate)
dateRegionKey := d.hmacSHA256(dateKey, d.config.Region)
dateRegionServiceKey := d.hmacSHA256(dateRegionKey, awsServiceName)
signingKey := d.hmacSHA256(dateRegionServiceKey, "aws4_request")
func deriveKey(secretKey, region string) string {
signingDate := time.Now().UTC().Format(shortTimeFormat)
hmacDate := hmacSHA256([]byte("AWS4"+secretKey), signingDate)
hmacRegion := hmacSHA256(hmacDate, region)
hmacService := hmacSHA256(hmacRegion, awsServiceName)
signingKey := hmacSHA256(hmacService, "aws4_request")
return string(signingKey)
}

func (d *DefaultSigner) hmacSHA256(key []byte, data string) []byte {
func hmacSHA256(key []byte, data string) []byte {
h := hmac.New(sha256.New, key)
h.Write([]byte(data))
return h.Sum(nil)
}

func (d *DefaultSigner) sign(signingKey string, strToSign string) string {
func sign(signingKey string, strToSign string) string {
h := hmac.New(sha256.New, []byte(signingKey))
h.Write([]byte(strToSign))
sig := hex.EncodeToString(h.Sum(nil))
Expand Down
11 changes: 9 additions & 2 deletions pkg/sigv4/tripper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sigv4

import (
"errors"
"net/http"
)

Expand All @@ -16,15 +17,21 @@ type Config struct {
AwsAccessKeyID string
}

func NewRoundTripper(config *Config, next http.RoundTripper) *Tripper {
func NewRoundTripper(config *Config, next http.RoundTripper) (*Tripper, error) {
if config == nil {
return nil, errors.New("can't initialize a sigv4 round tripper with nil config")
}

if next == nil {
next = http.DefaultTransport
}
return &Tripper{

tripper := &Tripper{
config: config,
next: next,
signer: NewDefaultSigner(config),
}
return tripper, nil
}

func (c *Tripper) RoundTrip(req *http.Request) (*http.Response, error) {
Expand Down
Loading

0 comments on commit 4ac35b2

Please sign in to comment.