Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework presign function to allow signature of clients with dedicated endpoint #28

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions presigned.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha256"
"encoding/hex"
"net/url"
"path"
"sort"
"strconv"
"strings"
Expand All @@ -24,8 +25,6 @@ type PresignedInput struct {
Timestamp time.Time
ExtraHeaders map[string]string
ExpirySeconds int
Protocol string
Endpoint string
}

// GeneratePresignedURL creates a Presigned URL that can be used
Expand All @@ -35,8 +34,9 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string {
var (
nowTime = nowTime()

protocol = defaultProtocol
endpoint = defaultPresignedHost
protocol = defaultProtocol
hostname = defaultPresignedHost
path_prefix = ""
)
if !in.Timestamp.IsZero() {
nowTime = in.Timestamp.UTC()
Expand All @@ -52,28 +52,30 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string {
b.Reset()

// Set the protocol as default if not provided.
if in.Protocol != "" {
protocol = in.Protocol
}
if in.Endpoint != "" {
endpoint = in.Endpoint
if endpoint, _ := url.Parse(s3.Endpoint); endpoint.Host != "" {
protocol = endpoint.Scheme + "://"
hostname = endpoint.Host
path_prefix = path.Join("/", endpoint.Path, in.Bucket)
} else {
host := bytes.Buffer{}
host.WriteString(in.Bucket)
host.WriteRune('.')
host.WriteString(hostname)
hostname = host.String()
}

// Add host to Headers
signedHeaders := map[string][]byte{}
for k, v := range in.ExtraHeaders {
signedHeaders[k] = []byte(v)
}
host := bytes.Buffer{}
host.WriteString(in.Bucket)
host.WriteRune('.')
host.WriteString(endpoint)
signedHeaders["host"] = host.Bytes()
signedHeaders["host"] = []byte(hostname)

// Start Canonical Request Formation
h := sha256.New() // We write the canonical request directly to the SHA256 hash.
h.Write([]byte(in.Method)) // HTTP Verb
h.Write(newLine)
h.Write([]byte(path_prefix))
h.Write([]byte{'/'})
h.Write([]byte(in.ObjectKey)) // CanonicalURL
h.Write(newLine)
Expand Down Expand Up @@ -193,10 +195,14 @@ func (s3 *S3) GeneratePresignedURL(in PresignedInput) string {
b.Reset()

// Start Generating URL
b.WriteString(protocol)
b.WriteString(in.Bucket)
b.WriteRune('.')
b.WriteString(endpoint)
if s3.Endpoint != "" {
b.WriteString(s3.Endpoint)
b.WriteRune('/')
b.WriteString(in.Bucket)
} else {
b.WriteString(protocol)
b.WriteString(hostname)
}
b.WriteRune('/')
b.WriteString(in.ObjectKey)
b.WriteRune('?')
Expand Down
8 changes: 4 additions & 4 deletions presigned_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ func TestS3_GeneratePresignedURL_Personal(t *testing.T) {
os.Getenv("AWS_S3_ACCESS_KEY"),
os.Getenv("AWS_S3_SECRET_KEY"),
)
s.Endpoint = os.Getenv("AWS_S3_ENDPOINT")
dontwant := ""
if got := s.GeneratePresignedURL(PresignedInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
ObjectKey: "test1.txt",
Method: "GET",
Timestamp: nowTime(),
Expand All @@ -81,10 +81,10 @@ func TestS3_GeneratePresignedURL_ExtraHeader(t *testing.T) {
os.Getenv("AWS_S3_ACCESS_KEY"),
os.Getenv("AWS_S3_SECRET_KEY"),
)
s.Endpoint = os.Getenv("AWS_S3_ENDPOINT")
dontwant := ""
if got := s.GeneratePresignedURL(PresignedInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
ObjectKey: "test2.txt",
Method: "GET",
Timestamp: nowTime(),
Expand All @@ -105,10 +105,10 @@ func TestS3_GeneratePresignedURL_PUT(t *testing.T) {
os.Getenv("AWS_S3_ACCESS_KEY"),
os.Getenv("AWS_S3_SECRET_KEY"),
)
s.Endpoint = os.Getenv("AWS_S3_ENDPOINT")
dontwant := ""
if got := s.GeneratePresignedURL(PresignedInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
ObjectKey: "test2.txt",
Method: "PUT",
Timestamp: nowTime(),
Expand All @@ -126,13 +126,13 @@ func BenchmarkS3_GeneratePresigned(b *testing.B) {
os.Getenv("AWS_S3_ACCESS_KEY"),
os.Getenv("AWS_S3_SECRET_KEY"),
)
s.Endpoint = os.Getenv("AWS_S3_ENDPOINT")

b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
s.GeneratePresignedURL(PresignedInput{
Bucket: os.Getenv("AWS_S3_BUCKET"),
Endpoint: os.Getenv("AWS_S3_ENDPOINT"),
ObjectKey: "test.txt",
Method: "GET",
Timestamp: nowTime(),
Expand Down
Loading