Skip to content

Commit

Permalink
Support refresh access token
Browse files Browse the repository at this point in the history
Signed-off-by: clyang82 <chuyang@redhat.com>
  • Loading branch information
clyang82 committed Dec 16, 2024
1 parent 8abc9ba commit b70bfd7
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 39 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.1
github.com/jackc/pgx/v5 v5.3.0
github.com/jinzhu/inflection v1.0.0
github.com/lib/pq v1.10.7
github.com/mendsley/gojwk v0.0.0-20141217222730-4d5ec6e58103
Expand Down Expand Up @@ -109,7 +110,6 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
Expand Down
33 changes: 9 additions & 24 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package config

import (
"context"
"fmt"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/openshift-online/maestro/pkg/constants"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/spf13/pflag"

"github.com/openshift-online/maestro/pkg/constants"
)

type DatabaseConfig struct {
Expand All @@ -31,6 +30,7 @@ type DatabaseConfig struct {

AuthMethod string `json:"auth_method"`
TokenRequestScope string `json:"token_request_scope"`
Token *azcore.AccessToken
}

func NewDatabaseConfig() *DatabaseConfig {
Expand Down Expand Up @@ -83,22 +83,7 @@ func (c *DatabaseConfig) ReadFiles() error {
return err
}

if c.AuthMethod == constants.AuthMethodMicrosoftEntra {
// ARO-HCP environment variable configuration is set by the Azure workload identity webhook.
// Use [WorkloadIdentityCredential] directly when not using the webhook or needing more control over its configuration.
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return err
}
// The access token can be expired. but the existing connections are not invalidated.
// TODO: how to reconnect due to the network is broken etc. Right now, gorm does not have this feature.
// refer to https://github.com/go-gorm/gorm/issues/5602 & https://github.com/go-gorm/gorm/pull/1721.
token, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{c.TokenRequestScope}})
if err != nil {
return err
}
c.Password = token.Token
} else {
if c.AuthMethod == constants.AuthMethodPassword {
err = readFileValueString(c.PasswordFile, &c.Password)
if err != nil {
return err
Expand All @@ -117,13 +102,13 @@ func (c *DatabaseConfig) ConnectionStringWithName(name string, withSSL bool) str
var cmd string
if withSSL {
cmd = fmt.Sprintf(
"host=%s port=%d user=%s password='%s' dbname=%s sslmode=%s sslrootcert=%s",
c.Host, c.Port, c.Username, c.Password, name, c.SSLMode, c.RootCertFile,
"host=%s port=%d user=%s dbname=%s sslmode=%s sslrootcert=%s",
c.Host, c.Port, c.Username, name, c.SSLMode, c.RootCertFile,
)
} else {
cmd = fmt.Sprintf(
"host=%s port=%d user=%s password='%s' dbname=%s sslmode=disable",
c.Host, c.Port, c.Username, c.Password, name,
"host=%s port=%d user=%s dbname=%s sslmode=disable",
c.Host, c.Port, c.Username, name,
)
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ const (

AuthMethodPassword = "password" // Standard postgres username/password authentication.
AuthMethodMicrosoftEntra = "az-entra" // Microsoft Entra ID-based token authentication.

// MinTokenLifeThreshold defines the minimum remaining lifetime (in seconds) of the access token before
// it should be refreshed.
MinTokenLifeThreshold = 60.0
)
60 changes: 46 additions & 14 deletions pkg/db/db_session/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@ import (
"fmt"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"github.com/lib/pq"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"

"github.com/lib/pq"

"github.com/openshift-online/maestro/pkg/config"
"github.com/openshift-online/maestro/pkg/constants"
"github.com/openshift-online/maestro/pkg/db"
ocmlogger "github.com/openshift-online/maestro/pkg/logger"
)
Expand Down Expand Up @@ -48,20 +53,16 @@ func (f *Default) Init(config *config.DatabaseConfig) {
err error
)

// Open connection to DB via standard library
dbx, err = sql.Open(config.Dialect, config.ConnectionString(config.SSLMode != disable))
connConfig, err := pgx.ParseConfig(config.ConnectionString(config.SSLMode != disable))
if err != nil {
dbx, err = sql.Open(config.Dialect, config.ConnectionString(false))
if err != nil {
panic(fmt.Sprintf(
"SQL failed to connect to %s database %s with connection string: %s\nError: %s",
config.Dialect,
config.Name,
config.LogSafeConnectionString(config.SSLMode != disable),
err.Error(),
))
}
panic(fmt.Sprintf(
"GORM failed to parse the connection string: %s\nError: %s",
config.LogSafeConnectionString(config.SSLMode != disable),
err.Error(),
))
}

dbx = stdlib.OpenDB(*connConfig, stdlib.OptionBeforeConnect(f.setPassword()))
dbx.SetMaxOpenConns(config.MaxOpenConnections)

// Connect GORM to use the same connection
Expand Down Expand Up @@ -93,6 +94,37 @@ func (f *Default) Init(config *config.DatabaseConfig) {
})
}

func (f *Default) setPassword() func(ctx context.Context, connConfig *pgx.ConnConfig) error {

return func(ctx context.Context, connConfig *pgx.ConnConfig) error {
if f.config.AuthMethod == constants.AuthMethodPassword {
connConfig.Password = f.config.Password
return nil
} else if f.config.AuthMethod == constants.AuthMethodMicrosoftEntra {
if isExpired(f.config.Token) {
// ARO-HCP environment variable configuration is set by the Azure workload identity webhook.
// Use [WorkloadIdentityCredential] directly when not using the webhook or needing more control over its configuration.
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return err
}
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{f.config.TokenRequestScope}})
if err != nil {
return err
}
connConfig.Password = token.Token
f.config.Token = &token
}
}
return nil
}
}

func isExpired(accessToken *azcore.AccessToken) bool {
return accessToken == nil ||
time.Until(accessToken.ExpiresOn).Seconds() < constants.MinTokenLifeThreshold
}

func (f *Default) DirectDB() *sql.DB {
return f.db
}
Expand Down

0 comments on commit b70bfd7

Please sign in to comment.