diff --git a/cmd/root.go b/cmd/root.go index 828ce981..ac01a021 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -29,15 +29,12 @@ import ( "syscall" "time" - "cloud.google.com/go/alloydbconn" "contrib.go.opencensus.io/exporter/prometheus" "contrib.go.opencensus.io/exporter/stackdriver" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" - "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/gcloud" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" "github.com/spf13/cobra" "go.opencensus.io/trace" - "golang.org/x/oauth2" ) var ( @@ -171,6 +168,9 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { if len(args) == 0 { return newBadCommandError("missing instance uri (e.g., /projects/$PROJECTS/locations/$LOCTION/clusters/$CLUSTER/instances/$INSTANCES)") } + + conf.UserAgent = userAgent + userHasSet := func(f string) bool { return cmd.PersistentFlags().Lookup(f).Changed } @@ -195,31 +195,6 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { if conf.CredentialsFile != "" && conf.GcloudAuth { return newBadCommandError("cannot specify --credentials-file and --gcloud-auth flags at the same time") } - opts := []alloydbconn.Option{ - alloydbconn.WithUserAgent(userAgent), - } - switch { - case conf.Token != "": - cmd.Printf("Authorizing with the -token flag\n") - opts = append(opts, alloydbconn.WithTokenSource( - oauth2.StaticTokenSource(&oauth2.Token{AccessToken: conf.Token}), - )) - case conf.CredentialsFile != "": - cmd.Printf("Authorizing with the credentials file at %q\n", conf.CredentialsFile) - opts = append(opts, alloydbconn.WithCredentialsFile( - conf.CredentialsFile, - )) - case conf.GcloudAuth: - cmd.Println("Authorizing with gcloud user credentials") - ts, err := gcloud.TokenSource() - if err != nil { - return err - } - opts = append(opts, alloydbconn.WithTokenSource(ts)) - default: - cmd.Println("Authorizing with Application Default Credentials") - } - conf.DialerOpts = opts if userHasSet("http-port") && !userHasSet("prometheus-namespace") { return newBadCommandError("cannot specify --http-port without --prometheus-namespace") @@ -394,18 +369,7 @@ func runSignalWrapper(cmd *Command) error { startCh := make(chan *proxy.Client) go func() { defer close(startCh) - // Check if the caller has configured a dialer. - // Otherwise, initialize a new one. - d := cmd.conf.Dialer - if d == nil { - var err error - d, err = alloydbconn.NewDialer(ctx, cmd.conf.DialerOpts...) - if err != nil { - shutdownCh <- fmt.Errorf("error initializing dialer: %v", err) - return - } - } - p, err := proxy.NewClient(ctx, d, cmd.Command, cmd.conf) + p, err := proxy.NewClient(ctx, cmd.Command, cmd.conf) if err != nil { shutdownCh <- fmt.Errorf("unable to start: %v", err) return diff --git a/cmd/root_test.go b/cmd/root_test.go index 4701a3d1..19d30c59 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -27,12 +27,14 @@ import ( "cloud.google.com/go/alloydbconn" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/spf13/cobra" ) func TestNewCommandArguments(t *testing.T) { withDefaults := func(c *proxy.Config) *proxy.Config { + if c.UserAgent == "" { + c.UserAgent = userAgent + } if c.Addr == "" { c.Addr = "127.0.0.1" } @@ -180,8 +182,7 @@ func TestNewCommandArguments(t *testing.T) { t.Fatalf("want error = nil, got = %v", err) } - opts := cmpopts.IgnoreFields(proxy.Config{}, "DialerOpts") - if got := c.conf; !cmp.Equal(tc.want, got, opts) { + if got := c.conf; !cmp.Equal(tc.want, got) { t.Fatalf("want = %#v\ngot = %#v\ndiff = %v", tc.want, got, cmp.Diff(tc.want, got)) } }) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 7da6c753..1952b9b8 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -28,7 +28,9 @@ import ( "cloud.google.com/go/alloydbconn" "github.com/GoogleCloudPlatform/alloydb-auth-proxy/alloydb" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/gcloud" "github.com/spf13/cobra" + "golang.org/x/oauth2" ) // InstanceConnConfig holds the configuration for an individual instance @@ -48,6 +50,10 @@ type InstanceConnConfig struct { // Config contains all the configuration provided by the caller. type Config struct { + // UserAgent is the user agent to use when sending requests to the Admin + // API. + UserAgent string + // Token is the Bearer token used for authorization. Token string @@ -76,10 +82,33 @@ type Config struct { // Dialer specifies the dialer to use when connecting to AlloyDB // instances. Dialer alloydb.Dialer +} + +// DialerOptions builds appropriate list of options from the Config +// values for use by alloydbconn.NewClient() +func (c *Config) DialerOptions() ([]alloydbconn.Option, error) { + opts := []alloydbconn.Option{ + alloydbconn.WithUserAgent(c.UserAgent), + } + switch { + case c.Token != "": + opts = append(opts, alloydbconn.WithTokenSource( + oauth2.StaticTokenSource(&oauth2.Token{AccessToken: c.Token}), + )) + case c.CredentialsFile != "": + opts = append(opts, alloydbconn.WithCredentialsFile( + c.CredentialsFile, + )) + case c.GcloudAuth: + ts, err := gcloud.TokenSource() + if err != nil { + return nil, err + } + opts = append(opts, alloydbconn.WithTokenSource(ts)) + default: + } - // DialerOpts specifies the opts to use when creating a new dialer. This - // value is ignored when a Dialer has been set. - DialerOpts []alloydbconn.Option + return opts, nil } type portConfig struct { @@ -131,7 +160,22 @@ type Client struct { } // NewClient completes the initial setup required to get the proxy to a "steady" state. -func NewClient(ctx context.Context, d alloydb.Dialer, cmd *cobra.Command, conf *Config) (*Client, error) { +func NewClient(ctx context.Context, cmd *cobra.Command, conf *Config) (*Client, error) { + // Check if the caller has configured a dialer. + // Otherwise, initialize a new one. + d := conf.Dialer + if d == nil { + var err error + dialerOpts, err := conf.DialerOptions() + if err != nil { + return nil, fmt.Errorf("error initializing dialer: %v", err) + } + d, err = alloydbconn.NewDialer(ctx, dialerOpts...) + if err != nil { + return nil, fmt.Errorf("error initializing dialer: %v", err) + } + } + pc := newPortConfig(conf.Port) var mnts []*socketMount for _, inst := range conf.Instances { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index d850be76..60e046da 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -185,7 +185,8 @@ func TestClientInitialization(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, err := proxy.NewClient(ctx, fakeDialer{}, &cobra.Command{}, tc.in) + tc.in.Dialer = fakeDialer{} + c, err := proxy.NewClient(ctx, &cobra.Command{}, tc.in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -228,14 +229,15 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) { Instances: []proxy.InstanceConnConfig{ {Name: "/projects/proj/locations/region/clusters/clust/instances/inst1"}, }, + Dialer: fakeDialer{}, } - c, err := proxy.NewClient(ctx, fakeDialer{}, &cobra.Command{}, in) + c, err := proxy.NewClient(ctx, &cobra.Command{}, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) } c.Close() - c, err = proxy.NewClient(ctx, fakeDialer{}, &cobra.Command{}, in) + c, err = proxy.NewClient(ctx, &cobra.Command{}, in) if err != nil { t.Fatalf("want error = nil, got = %v", err) }