diff --git a/pkg/db/db_session/default.go b/pkg/db/db_session/default.go index 4967a535..85c4083e 100755 --- a/pkg/db/db_session/default.go +++ b/pkg/db/db_session/default.go @@ -62,7 +62,7 @@ func (f *Default) Init(config *config.DatabaseConfig) { )) } - dbx = stdlib.OpenDB(*connConfig, stdlib.OptionBeforeConnect(f.setPassword())) + dbx = stdlib.OpenDB(*connConfig, stdlib.OptionBeforeConnect(setPassword(config))) dbx.SetMaxOpenConns(config.MaxOpenConnections) // Connect GORM to use the same connection @@ -94,32 +94,39 @@ func (f *Default) Init(config *config.DatabaseConfig) { }) } -func (f *Default) setPassword() func(ctx context.Context, connConfig *pgx.ConnConfig) error { - +func setPassword(dbConfig *config.DatabaseConfig) func(ctx context.Context, connConfig *pgx.ConnConfig) error { return func(ctx context.Context, connConfig *pgx.ConnConfig) error { - if f.config.AuthMethod == constants.AuthMethodPassword { - connConfig.Password = f.config.Password + if dbConfig.AuthMethod == constants.AuthMethodPassword { + connConfig.Password = dbConfig.Password return nil - } else if f.config.AuthMethod == constants.AuthMethodMicrosoftEntra { - if isExpired(f.config.Token) { - // ARO-HCP environment variable configuration is set by the Azure workload identity webhook. - // Use [WorkloadIdentityCredential] directly when not using the webhook or needing more control over its configuration. - cred, err := azidentity.NewDefaultAzureCredential(nil) - if err != nil { - return err - } - token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{f.config.TokenRequestScope}}) + } else if dbConfig.AuthMethod == constants.AuthMethodMicrosoftEntra { + if isExpired(dbConfig.Token) { + token, err := getAccessToken(ctx, dbConfig) if err != nil { return err } connConfig.Password = token.Token - f.config.Token = &token + dbConfig.Token = token } } return nil } } +func getAccessToken(ctx context.Context, dbConfig *config.DatabaseConfig) (*azcore.AccessToken, error) { + // ARO-HCP environment variable configuration is set by the Azure workload identity webhook. + // Use [WorkloadIdentityCredential] directly when not using the webhook or needing more control over its configuration. + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, err + } + token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{dbConfig.TokenRequestScope}}) + if err != nil { + return nil, err + } + return &token, nil +} + func isExpired(accessToken *azcore.AccessToken) bool { return accessToken == nil || time.Until(accessToken.ExpiresOn).Seconds() < constants.MinTokenLifeThreshold @@ -144,13 +151,14 @@ func waitForNotification(ctx context.Context, l *pq.Listener, callback func(id s case <-time.After(10 * time.Second): logger.V(10).Infof("Received no events on channel during interval. Pinging source") go func() { + // TODO: Need to handle the error, especially in cases of network failure. l.Ping() }() } } } -func newListener(ctx context.Context, connstr, channel string, callback func(id string)) { +func newListener(ctx context.Context, dbConfig *config.DatabaseConfig, channel string, callback func(id string)) { logger := ocmlogger.NewOCMLogger(ctx) plog := func(ev pq.ListenerEventType, err error) { @@ -158,6 +166,18 @@ func newListener(ctx context.Context, connstr, channel string, callback func(id logger.Error(err.Error()) } } + connstr := dbConfig.ConnectionString(true) + // append the password to the connection string + if dbConfig.AuthMethod == constants.AuthMethodPassword { + connstr += fmt.Sprintf(" password='%s'", dbConfig.Password) + } else if dbConfig.AuthMethod == constants.AuthMethodMicrosoftEntra { + token, err := getAccessToken(ctx, dbConfig) + if err != nil { + panic(err) + } + connstr += fmt.Sprintf(" password='%s'", token.Token) + } + listener := pq.NewListener(connstr, 10*time.Second, time.Minute, plog) err := listener.Listen(channel) if err != nil { @@ -169,7 +189,7 @@ func newListener(ctx context.Context, connstr, channel string, callback func(id } func (f *Default) NewListener(ctx context.Context, channel string, callback func(id string)) { - newListener(ctx, f.config.ConnectionString(true), channel, callback) + newListener(ctx, f.config, channel, callback) } func (f *Default) New(ctx context.Context) *gorm.DB { diff --git a/pkg/db/db_session/test.go b/pkg/db/db_session/test.go index 672afe94..eac616c1 100755 --- a/pkg/db/db_session/test.go +++ b/pkg/db/db_session/test.go @@ -217,5 +217,5 @@ func (f *Test) ResetDB() { } func (f *Test) NewListener(ctx context.Context, channel string, callback func(id string)) { - newListener(ctx, f.config.ConnectionString(true), channel, callback) + newListener(ctx, f.config, channel, callback) }