Skip to content

Commit

Permalink
bug(aws): handle ECR repositories in different regions
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Conner <kev.conner@getupcloud.com>
  • Loading branch information
knrc committed Apr 19, 2024
1 parent 9873cf3 commit dff6701
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 47 deletions.
15 changes: 9 additions & 6 deletions pkg/fanal/image/registry/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,31 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type Registry struct {
type RegistryClient struct {
domain string
}

type Registry struct {
}

const (
azureURL = "azurecr.io"
scope = "https://management.azure.com/.default"
scheme = "https"
)

func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) error {
func (r *Registry) CheckOptions(domain string, _ types.RegistryOptions) (intf.RegistryClient, error) {
if !strings.HasSuffix(domain, azureURL) {
return xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern)
return nil, xerrors.Errorf("Azure registry: %w", types.InvalidURLPattern)
}
r.domain = domain
return nil
return &RegistryClient{domain: domain}, nil
}

func (r *Registry) GetCredential(ctx context.Context) (string, string, error) {
func (r *RegistryClient) GetCredential(ctx context.Context) (string, string, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return "", "", xerrors.Errorf("unable to generate acr credential error: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/fanal/image/registry/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestRegistry_CheckOptions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := azure.Registry{}
err := r.CheckOptions(tt.domain, types.RegistryOptions{})
_, err := r.CheckOptions(tt.domain, types.RegistryOptions{})
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
} else {
Expand Down
53 changes: 40 additions & 13 deletions pkg/fanal/image/registry/ecr/ecr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ecr
import (
"context"
"encoding/base64"
"regexp"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -11,47 +12,73 @@ import (
"github.com/aws/aws-sdk-go-v2/service/ecr"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
)

const ecrURL = "amazonaws.com"

type ecrAPI interface {
GetAuthorizationToken(ctx context.Context, params *ecr.GetAuthorizationTokenInput, optFns ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error)
}

type ECR struct {
}

type ECRClient struct {
Client ecrAPI
}

func getSession(option types.RegistryOptions) (aws.Config, error) {
func getSession(domain, region string, option types.RegistryOptions) (aws.Config, error) {
// create custom credential information if option is valid
if option.AWSSecretKey != "" && option.AWSAccessKey != "" && option.AWSRegion != "" {
if region != option.AWSRegion {
log.Warnf("The region from AWS_REGION (%s) is being overridden. The region from domain (%s) was used.", option.AWSRegion, domain)
}
return config.LoadDefaultConfig(
context.TODO(),
config.WithRegion(option.AWSRegion),
config.WithRegion(region),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(option.AWSAccessKey, option.AWSSecretKey, option.AWSSessionToken)),
)
}
return config.LoadDefaultConfig(context.TODO())
return config.LoadDefaultConfig(context.TODO(), config.WithRegion(region))
}

func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) error {
if !strings.HasSuffix(domain, ecrURL) {
return xerrors.Errorf("ECR : %w", types.InvalidURLPattern)
func (e *ECR) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) {
region := determineRegion(domain)
if region == "" {
return nil, xerrors.Errorf("ECR : %w", types.InvalidURLPattern)
}

cfg, err := getSession(option)
cfg, err := getSession(domain, region, option)
if err != nil {
return err
return nil, err
}

svc := ecr.NewFromConfig(cfg)
e.Client = svc
return nil
return &ECRClient{Client: svc}, nil
}

// Endpoints take the form
// <registry-id>.dkr.ecr.<region>.amazonaws.com
// <registry-id>.dkr.ecr-fips.<region>.amazonaws.com
// <registry-id>.dkr.ecr.<region>.amazonaws.com.cn
// <registry-id>.dkr.ecr.<region>.sc2s.sgov.gov
// <registry-id>.dkr.ecr.<region>.c2s.ic.gov
// see
// - https://docs.aws.amazon.com/general/latest/gr/ecr.html
// - https://docs.amazonaws.cn/en_us/aws/latest/userguide/endpoints-arns.html
// - https://github.com/boto/botocore/blob/1.34.51/botocore/data/endpoints.json
var ecrEndpointMatch = regexp.MustCompile(`^[^.]+\.dkr\.ecr(?:-fips)?\.([^.]+)\.(?:amazonaws\.com(?:\.cn)?|sc2s\.sgov\.gov|c2s\.ic\.gov)$`)

func determineRegion(domain string) string {
matches := ecrEndpointMatch.FindStringSubmatch(domain)
if matches != nil {
return matches[1]
}
return ""
}

func (e *ECR) GetCredential(ctx context.Context) (username, password string, err error) {
func (e *ECRClient) GetCredential(ctx context.Context) (username, password string, err error) {
input := &ecr.GetAuthorizationTokenInput{}
result, err := e.Client.GetAuthorizationToken(ctx, input)
if err != nil {
Expand Down
68 changes: 63 additions & 5 deletions pkg/fanal/image/registry/ecr/ecr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,91 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ecr"
awstypes "github.com/aws/aws-sdk-go-v2/service/ecr/types"
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type testECRClient interface {
Options() ecr.Options
}

func TestCheckOptions(t *testing.T) {
var tests = map[string]struct {
domain string
wantErr error
domain string
expectedRegion string
wantErr error
}{
"InvalidURL": {
domain: "alpine:3.9",
wantErr: types.InvalidURLPattern,
},
"NoOption": {
domain: "xxx.ecr.ap-northeast-1.amazonaws.com",
domain: "xxx.dkr.ecr.ap-northeast-1.amazonaws.com",
expectedRegion: "ap-northeast-1",
},
"region-1": {
domain: "xxx.dkr.ecr.region-1.amazonaws.com",
expectedRegion: "region-1",
},
"region-2": {
domain: "xxx.dkr.ecr.region-2.amazonaws.com",
expectedRegion: "region-2",
},
"fips-region-1": {
domain: "xxx.dkr.ecr-fips.fips-region.amazonaws.com",
expectedRegion: "fips-region",
},
"cn-region-1": {
domain: "xxx.dkr.ecr.region-1.amazonaws.com.cn",
expectedRegion: "region-1",
},
"cn-region-2": {
domain: "xxx.dkr.ecr.region-2.amazonaws.com.cn",
expectedRegion: "region-2",
},
"sc2s-region-1": {
domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov.gov",
expectedRegion: "sc2s-region",
},
"c2s-region-1": {
domain: "xxx.dkr.ecr.c2s-region.c2s.ic.gov",
expectedRegion: "c2s-region",
},
"invalid-ecr": {
domain: "xxx.dkrecr.region-1.amazonaws.com",
wantErr: types.InvalidURLPattern,
},
"invalid-fips": {
domain: "xxx.dkr.ecrfips.fips-region.amazonaws.com",
wantErr: types.InvalidURLPattern,
},
"invalid-cn": {
domain: "xxx.dkr.ecr.region-2.amazonaws.cn",
wantErr: types.InvalidURLPattern,
},
"invalid-sc2s": {
domain: "xxx.dkr.ecr.sc2s-region.sc2s.sgov",
wantErr: types.InvalidURLPattern,
},
"invalid-cs2": {
domain: "xxx.dkr.ecr.c2s-region.c2s.ic",
wantErr: types.InvalidURLPattern,
},
}

for testname, v := range tests {
a := &ECR{}
err := a.CheckOptions(v.domain, types.RegistryOptions{})
ecrClient, err := a.CheckOptions(v.domain, types.RegistryOptions{})
if err != nil {
if !errors.Is(err, v.wantErr) {
t.Errorf("[%s]\nexpected error based on %v\nactual : %v", testname, v.wantErr, err)
}
continue
}

client := (ecrClient.(*ECRClient)).Client.(testECRClient)
require.Equal(t, v.expectedRegion, client.Options().Region)
}
}

Expand Down Expand Up @@ -82,7 +140,7 @@ func TestECRGetCredential(t *testing.T) {
}

for i, c := range cases {
e := ECR{
e := ECRClient{
Client: mockedECR{Resp: c.Resp},
}
username, password, err := e.GetCredential(context.Background())
Expand Down
18 changes: 11 additions & 7 deletions pkg/fanal/image/registry/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,36 @@ import (
"github.com/GoogleCloudPlatform/docker-credential-gcr/store"
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type Registry struct {
type GoogleRegistryClient struct {
Store store.GCRCredStore
domain string
}

type Registry struct {
}

// Google container registry
const gcrURL = "gcr.io"

// Google artifact registry
const garURL = "docker.pkg.dev"

func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) error {
func (g *Registry) CheckOptions(domain string, option types.RegistryOptions) (intf.RegistryClient, error) {
if !strings.HasSuffix(domain, gcrURL) && !strings.HasSuffix(domain, garURL) {
return xerrors.Errorf("Google registry: %w", types.InvalidURLPattern)
return nil, xerrors.Errorf("Google registry: %w", types.InvalidURLPattern)
}
g.domain = domain
client := GoogleRegistryClient{domain: domain}
if option.GCPCredPath != "" {
g.Store = store.NewGCRCredStore(option.GCPCredPath)
client.Store = store.NewGCRCredStore(option.GCPCredPath)
}
return nil
return &client, nil
}

func (g *Registry) GetCredential(_ context.Context) (username, password string, err error) {
func (g *GoogleRegistryClient) GetCredential(_ context.Context) (username, password string, err error) {
var credStore store.GCRCredStore
if g.Store == nil {
credStore, err = store.DefaultGCRCredStore()
Expand Down
12 changes: 6 additions & 6 deletions pkg/fanal/image/registry/google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestCheckOptions(t *testing.T) {
var tests = map[string]struct {
domain string
opt types.RegistryOptions
gcr *Registry
grc *GoogleRegistryClient
wantErr error
}{
"InvalidURL": {
Expand All @@ -23,12 +23,12 @@ func TestCheckOptions(t *testing.T) {
},
"NoOption": {
domain: "gcr.io",
gcr: &Registry{domain: "gcr.io"},
grc: &GoogleRegistryClient{domain: "gcr.io"},
},
"CredOption": {
domain: "gcr.io",
opt: types.RegistryOptions{GCPCredPath: "/path/to/file.json"},
gcr: &Registry{
grc: &GoogleRegistryClient{
domain: "gcr.io",
Store: store.NewGCRCredStore("/path/to/file.json"),
},
Expand All @@ -37,7 +37,7 @@ func TestCheckOptions(t *testing.T) {

for testname, v := range tests {
g := &Registry{}
err := g.CheckOptions(v.domain, v.opt)
grc, err := g.CheckOptions(v.domain, v.opt)
if v.wantErr != nil {
if err == nil {
t.Errorf("%s : expected error but no error", testname)
Expand All @@ -48,8 +48,8 @@ func TestCheckOptions(t *testing.T) {
}
continue
}
if !reflect.DeepEqual(v.gcr, g) {
t.Errorf("[%s]\nexpected : %v\nactual : %v", testname, v.gcr, g)
if !reflect.DeepEqual(v.grc, grc) {
t.Errorf("[%s]\nexpected : %v\nactual : %v", testname, v.grc, grc)
}
}
}
15 changes: 15 additions & 0 deletions pkg/fanal/image/registry/intf/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package intf

import (
"context"

"github.com/aquasecurity/trivy/pkg/fanal/types"
)

type RegistryClient interface {
GetCredential(ctx context.Context) (string, string, error)
}

type Registry interface {
CheckOptions(domain string, option types.RegistryOptions) (RegistryClient, error)
}
14 changes: 5 additions & 9 deletions pkg/fanal/image/registry/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
"github.com/aquasecurity/trivy/pkg/fanal/image/registry/azure"
"github.com/aquasecurity/trivy/pkg/fanal/image/registry/ecr"
"github.com/aquasecurity/trivy/pkg/fanal/image/registry/google"
"github.com/aquasecurity/trivy/pkg/fanal/image/registry/intf"
"github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
)

var (
registries []Registry
registries []intf.Registry
)

func init() {
Expand All @@ -22,23 +23,18 @@ func init() {
RegisterRegistry(&azure.Registry{})
}

type Registry interface {
CheckOptions(domain string, option types.RegistryOptions) error
GetCredential(ctx context.Context) (string, string, error)
}

func RegisterRegistry(registry Registry) {
func RegisterRegistry(registry intf.Registry) {
registries = append(registries, registry)
}

func GetToken(ctx context.Context, domain string, opt types.RegistryOptions) (auth authn.Basic) {
// check registry which particular to get credential
for _, registry := range registries {
err := registry.CheckOptions(domain, opt)
client, err := registry.CheckOptions(domain, opt)
if err != nil {
continue
}
username, password, err := registry.GetCredential(ctx)
username, password, err := client.GetCredential(ctx)
if err != nil {
// only skip check registry if error occurred
log.Debug("Credential error", log.Err(err))
Expand Down

0 comments on commit dff6701

Please sign in to comment.