Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
rhnvrm committed Jul 9, 2024
2 parents 046c5f5 + a9c81a0 commit 4257c79
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 38 deletions.
103 changes: 82 additions & 21 deletions simples3.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"encoding/hex"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"io/ioutil"
Expand All @@ -23,7 +22,12 @@ import (
)

const (
securityCredentialsURL = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
imdsTokenHeader = "X-aws-ec2-metadata-token"
imdsTokenTtlHeader = "X-aws-ec2-metadata-token-ttl-seconds"
metadataBaseURL = "http://169.254.169.254/latest"
securityCredentialsURI = "/meta-data/iam/security-credentials/"
imdsTokenURI = "/api/token"
defaultIMDSTokenTTL = "60"

// AMZMetaPrefix to prefix metadata key.
AMZMetaPrefix = "x-amz-meta-"
Expand Down Expand Up @@ -89,12 +93,13 @@ type UploadInput struct {
// UploadResponse receives the following XML
// in case of success, since we set a 201 response from S3.
// Sample response:
// <PostResponse>
// <Location>https://s3.amazonaws.com/link-to-the-file</Location>
// <Bucket>s3-bucket</Bucket>
// <Key>development/8614bd40-691b-4668-9241-3b342c6cf429/image.jpg</Key>
// <ETag>"32-bit-tag"</ETag>
// </PostResponse>
//
// <PostResponse>
// <Location>https://s3.amazonaws.com/link-to-the-file</Location>
// <Bucket>s3-bucket</Bucket>
// <Key>development/8614bd40-691b-4668-9241-3b342c6cf429/image.jpg</Key>
// <ETag>"32-bit-tag"</ETag>
// </PostResponse>
type UploadResponse struct {
Location string `xml:"Location"`
Bucket string `xml:"Bucket"`
Expand Down Expand Up @@ -146,56 +151,112 @@ func NewUsingIAM(region string) (*S3, error) {
&http.Client{
// Set a timeout of 3 seconds for AWS IAM Calls.
Timeout: time.Second * 3, //nolint:gomnd
}, securityCredentialsURL, region)
}, metadataBaseURL, region)
}

// fetchIMDSToken retrieves an IMDSv2 token from the
// EC2 instance metadata service. It returns a token and boolean,
// only if IMDSv2 is enabled in the EC2 instance metadata
// configuration, otherwise returns an error.
func fetchIMDSToken(cl *http.Client, baseURL string) (string, bool, error) {
req, err := http.NewRequest(http.MethodPut, baseURL+imdsTokenURI, nil)
if err != nil {
return "", false, err
}

// Set the token TTL to 60 seconds.
req.Header.Set(imdsTokenTtlHeader, defaultIMDSTokenTTL)

resp, err := cl.Do(req)
if err != nil {
return "", false, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", false, fmt.Errorf("failed to request IMDSv2 token: %s", resp.Status)
}

token, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", false, err
}

return string(token), true, nil
}

// fetchIAMData fetches the IAM data from the given URL.
// In case of a normal AWS setup, baseURL would be securityCredentialsURL.
// In case of a normal AWS setup, baseURL would be metadataBaseURL.
// You can use this method, to manually fetch IAM data from a custom
// endpoint and pass it to SetIAMData.
func fetchIAMData(cl *http.Client, baseURL string) (IAMResponse, error) {
resp, err := cl.Get(baseURL)
token, useIMDSv2, err := fetchIMDSToken(cl, baseURL)
if err != nil {
return IAMResponse{}, fmt.Errorf("error fetching IMDSv2 token: %w", err)
}

url := baseURL + securityCredentialsURI

req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return IAMResponse{}, fmt.Errorf("error creating imdsv2 token request: %w", err)
}

if useIMDSv2 {
req.Header.Set(imdsTokenHeader, token)
}

resp, err := cl.Do(req)
if err != nil {
return IAMResponse{}, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return IAMResponse{}, fmt.Errorf("Error fetching IAM data: %s", resp.Status)
return IAMResponse{}, fmt.Errorf("error fetching IAM data: %s", resp.Status)
}

role, err := ioutil.ReadAll(resp.Body)
if err != nil {
return IAMResponse{}, err
}

resp, err = http.Get(baseURL + "/" + string(role))
req, err = http.NewRequest(http.MethodGet, url+string(role), nil)
if err != nil {
return IAMResponse{}, err
return IAMResponse{}, fmt.Errorf("error creating role request: %w", err)
}
if useIMDSv2 {
req.Header.Set(imdsTokenHeader, token)
}

resp, err = cl.Do(req)
if err != nil {
return IAMResponse{}, fmt.Errorf("error fetching role data: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return IAMResponse{}, errors.New(http.StatusText(resp.StatusCode))
return IAMResponse{}, fmt.Errorf("error fetching role data, got non 200 code: %s", resp.Status)
}

var jResp IAMResponse
jsonString, err := ioutil.ReadAll(resp.Body)
if err != nil {
return IAMResponse{}, err
return IAMResponse{}, fmt.Errorf("error reading role data: %w", err)
}

if err := json.Unmarshal(jsonString, &jResp); err != nil {
return IAMResponse{}, err
return IAMResponse{}, fmt.Errorf("error unmarshalling role data: %w (%s)", err, jsonString)
}

return jResp, nil
}

func newUsingIAM(cl *http.Client, credUrl, region string) (*S3, error) {
func newUsingIAM(cl *http.Client, baseURL, region string) (*S3, error) {
// Get the IAM role
iamResp, err := fetchIAMData(cl, credUrl)
iamResp, err := fetchIAMData(cl, baseURL)
if err != nil {
return nil, fmt.Errorf("Error fetching IAM data: %w", err)
return nil, fmt.Errorf("error fetching IAM data: %w", err)
}

return &S3{
Expand All @@ -209,7 +270,7 @@ func newUsingIAM(cl *http.Client, credUrl, region string) (*S3, error) {
}

// setIAMData sets the IAM data on the S3 instance.
func (s3 *S3) setIAMData(iamResp IAMResponse) {
func (s3 *S3) SetIAMData(iamResp IAMResponse) {
s3.AccessKey = iamResp.AccessKeyID
s3.SecretKey = iamResp.SecretAccessKey
s3.Token = iamResp.Token
Expand Down
81 changes: 64 additions & 17 deletions simples3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,49 +400,96 @@ func TestS3_NewUsingIAM(t *testing.T) {
"Type" : "AWS-HMAC","AccessKeyId" : "abc",
"SecretAccessKey" : "abc","Token" : "abc",
"Expiration" : "2018-12-24T16:24:59Z"}`
respIMDSToken = `AQAEAJWopi8yvjKYXyWJbzESE0cms-OoTnptJzS3M9g5iNcl06UEkQ==`
)

tsFail := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
}))
defer tsFail.Close()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("Expected 'GET' request, got '%s'", r.Method)
}
if r.URL.EscapedPath() == "/" {
w.WriteHeader(http.StatusOK)
io.WriteString(w, iam)
}
if r.URL.EscapedPath() == "/"+iam {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, resp)
genServerHandlerFunc := func(failIMDS bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
if !failIMDS {
// check if token is present
if r.Header.Get(imdsTokenHeader) == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
}

url := securityCredentialsURI
if r.URL.EscapedPath() == url {
w.WriteHeader(http.StatusOK)
io.WriteString(w, iam)
}
if r.URL.EscapedPath() == url+iam {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, resp)
}
case "PUT":
if failIMDS {
w.WriteHeader(http.StatusNotFound)
return
}

if r.URL.EscapedPath() == imdsTokenURI {
if r.Header.Get(imdsTokenTtlHeader) != "60" {
w.WriteHeader(http.StatusBadRequest)
return
}

w.WriteHeader(http.StatusOK)
io.WriteString(w, respIMDSToken)
}
default:
t.Errorf("Expected 'GET' or 'PUT' request, got '%s'", r.Method)
}
}
}))
}

ts := httptest.NewServer(http.HandlerFunc(genServerHandlerFunc(false)))
defer ts.Close()

tsFailIMDS := httptest.NewServer(http.HandlerFunc(genServerHandlerFunc(true)))
defer tsFailIMDS.Close()

cl := &http.Client{Timeout: 1 * time.Second}

// Test for timeout.
_, err := newUsingIAM(&http.Client{Timeout: 1 * time.Second}, tsFail.URL, "abc")
_, err := newUsingIAM(cl, tsFail.URL, "abc")
if err == nil {
t.Errorf("Expected error, got nil")
} else {
var timeoutError net.Error
if errors.As(err, &timeoutError); !timeoutError.Timeout() {

if errors.As(err, &timeoutError) && !timeoutError.Timeout() {
t.Errorf("newUsingIAM() timeout check. got error = %v", err)
}
}

// Test for successful IAM fetch.
s3, err := newUsingIAM(http.DefaultClient, ts.URL, "abc")
s3, err := newUsingIAM(cl, ts.URL, "abc")
if err != nil {
t.Errorf("newUsingIAM() error = %v", err)
}

if s3.AccessKey != "abc" && s3.SecretKey != "abc" && s3.Region != "abc" {
if s3 == nil {
t.Errorf("newUsingIAM() got = %v", s3)
}

if s3.AccessKey != "abc" || s3.SecretKey != "abc" || s3.Region != "abc" {
t.Errorf("S3.FileDelete() got = %v", s3)
}

// Test for failed IMDS token fetch.
_, err = newUsingIAM(cl, tsFailIMDS.URL, "abc")
if err == nil {
t.Errorf("Expected error, got nil")
}
}

func TestCustomEndpoint(t *testing.T) {
Expand Down

0 comments on commit 4257c79

Please sign in to comment.