Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add azure-workload auth to MSSQL scaler #6161

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Here is an overview of all new **experimental** features:
- **GCP Scalers**: Added custom time horizon in GCP scalers ([#5778](https://github.com/kedacore/keda/issues/5778))
- **GitHub Scaler**: Fixed pagination, fetching repository list ([#5738](https://github.com/kedacore/keda/issues/5738))
- **Kafka**: Fix logic to scale to zero on invalid offset even with earliest offsetResetPolicy ([#5689](https://github.com/kedacore/keda/issues/5689))
- **MSSQL Scaler**: Add azure-workload auth ([#6104](https://github.com/kedacore/keda/issues/6104))
- **RabbitMQ Scaler**: Add connection name for AMQP ([#5958](https://github.com/kedacore/keda/issues/5958))
- TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX))

Expand Down
256 changes: 98 additions & 158 deletions pkg/scalers/mssql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,74 +3,56 @@ package scalers
import (
"context"
"database/sql"
"errors"
"fmt"
"net"
"net/url"
"strconv"

// mssql driver required for this scaler
_ "github.com/denisenkom/go-mssqldb"
"github.com/go-logr/logr"
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
"github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
)

var (
// ErrMsSQLNoQuery is returned when "query" is missing from the config.
ErrMsSQLNoQuery = errors.New("no query given")

// ErrMsSQLNoTargetValue is returned when "targetValue" is missing from the config.
ErrMsSQLNoTargetValue = errors.New("no targetValue given")
)

// mssqlScaler exposes a data pointer to mssqlMetadata and sql.DB connection
type mssqlScaler struct {
metricType v2.MetricTargetType
metadata *mssqlMetadata
metadata mssqlMetadata
connection *sql.DB
logger logr.Logger
azureOAuth *azure.ADWorkloadIdentityTokenProvider
}

// mssqlMetadata defines metadata used by KEDA to query a Microsoft SQL database
type mssqlMetadata struct {
// The connection string used to connect to the MSSQL database.
// Both URL syntax (sqlserver://host?database=dbName) and OLEDB syntax is supported.
// +optional
connectionString string
// The username credential for connecting to the MSSQL instance, if not specified in the connection string.
// +optional
username string
// The password credential for connecting to the MSSQL instance, if not specified in the connection string.
// +optional
password string
// The hostname of the MSSQL instance endpoint, if not specified in the connection string.
// +optional
host string
// The port number of the MSSQL instance endpoint, if not specified in the connection string.
// +optional
port int
// The name of the database to query, if not specified in the connection string.
// +optional
database string
// The T-SQL query to run against the target database - e.g. SELECT COUNT(*) FROM table.
// +required
query string
// The threshold that is used as targetAverageValue in the Horizontal Pod Autoscaler.
// +required
targetValue float64
// The threshold that is used in activation phase
// +optional
activationTargetValue float64
// The index of the scaler inside the ScaledObject
// +internal
triggerIndex int
ConnectionString string `keda:"name=connectionString,order=authParams;resolvedEnv;triggerMetadata,optional"`
Username string `keda:"name=username,order=authParams;triggerMetadata,optional"`
Password string `keda:"name=password,order=authParams;resolvedEnv;triggerMetadata,optional"`
Host string `keda:"name=host,order=authParams;triggerMetadata,optional"`
Port int `keda:"name=port,order=authParams;triggerMetadata,optional"`
Database string `keda:"name=database,order=authParams;triggerMetadata,optional"`
Query string `keda:"name=query,order=triggerMetadata"`
TargetValue float64 `keda:"name=targetValue,order=triggerMetadata"`
ActivationTargetValue float64 `keda:"name=activationTargetValue,order=triggerMetadata,optional,default=0"`

TriggerIndex int

WorkloadIdentityResource string `keda:"name=WorkloadIdentityResource,order=authParams;triggerMetadata,optional"`
WorkloadIdentityClientID string
WorkloadIdentityTenantID string
WorkloadIdentityAuthorityHost string
}

func (m *mssqlMetadata) Validate() error {
if m.ConnectionString == "" && m.Host == "" {
return fmt.Errorf("must provide either connectionstring or host")
}
return nil
}

// NewMSSQLScaler creates a new mssql scaler
func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
func NewMSSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
metricType, err := GetMetricTargetType(config)
if err != nil {
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
Expand All @@ -83,155 +65,115 @@ func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
return nil, fmt.Errorf("error parsing mssql metadata: %w", err)
}

conn, err := newMSSQLConnection(meta, logger)
if err != nil {
return nil, fmt.Errorf("error establishing mssql connection: %w", err)
}

return &mssqlScaler{
scaler := &mssqlScaler{
metricType: metricType,
metadata: meta,
connection: conn,
logger: logger,
}, nil
}

// parseMSSQLMetadata takes a ScalerConfig and returns a mssqlMetadata or an error if the config is invalid
func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (*mssqlMetadata, error) {
meta := mssqlMetadata{}

// Query
if val, ok := config.TriggerMetadata["query"]; ok {
meta.query = val
} else {
return nil, ErrMsSQLNoQuery
}

// Target query value
if val, ok := config.TriggerMetadata["targetValue"]; ok {
targetValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("targetValue parsing error %w", err)
}
meta.targetValue = targetValue
} else {
if config.AsMetricSource {
meta.targetValue = 0
} else {
return nil, ErrMsSQLNoTargetValue
}
conn, err := newMSSQLConnection(ctx, scaler)
if err != nil {
return nil, fmt.Errorf("error establishing mssql connection: %w", err)
}

// Activation target value
meta.activationTargetValue = 0
if val, ok := config.TriggerMetadata["activationTargetValue"]; ok {
activationTargetValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, fmt.Errorf("activationTargetValue parsing error %w", err)
}
meta.activationTargetValue = activationTargetValue
}
scaler.connection = conn

// Connection string, which can either be provided explicitly or via the helper fields
switch {
case config.AuthParams["connectionString"] != "":
meta.connectionString = config.AuthParams["connectionString"]
case config.TriggerMetadata["connectionStringFromEnv"] != "":
meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]]
default:
meta.connectionString = ""
var err error

host, err := GetFromAuthOrMeta(config, "host")
if err != nil {
return nil, err
}
meta.host = host

var paramPort string
paramPort, _ = GetFromAuthOrMeta(config, "port")
if paramPort != "" {
port, err := strconv.Atoi(paramPort)
if err != nil {
return nil, fmt.Errorf("port parsing error %w", err)
}
meta.port = port
}
return scaler, nil
}

meta.username, _ = GetFromAuthOrMeta(config, "username")
func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, error) {
meta := mssqlMetadata{}
err := config.TypedConfig(&meta)
if err != nil {
return meta, err
}

// database is optional in SQL s
meta.database, _ = GetFromAuthOrMeta(config, "database")
meta.TriggerIndex = config.TriggerIndex

if config.AuthParams["password"] != "" {
meta.password = config.AuthParams["password"]
} else if config.TriggerMetadata["passwordFromEnv"] != "" {
meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]]
if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload {
if config.AuthParams["workloadIdentityResource"] != "" {
meta.WorkloadIdentityClientID = config.PodIdentity.GetIdentityID()
meta.WorkloadIdentityTenantID = config.PodIdentity.GetIdentityTenantID()
meta.WorkloadIdentityAuthorityHost = config.PodIdentity.GetIdentityAuthorityHost()
meta.WorkloadIdentityResource = config.AuthParams["workloadIdentityResource"]
}
}
meta.triggerIndex = config.TriggerIndex
return &meta, nil

return meta, nil
}

// newMSSQLConnection returns a new, opened SQL connection for the provided mssqlMetadata
func newMSSQLConnection(meta *mssqlMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := getMSSQLConnectionString(meta)
func newMSSQLConnection(ctx context.Context, s *mssqlScaler) (*sql.DB, error) {
connStr := getMSSQLConnectionString(ctx, s)

db, err := sql.Open("sqlserver", connStr)
if err != nil {
logger.Error(err, fmt.Sprintf("Found error opening mssql: %s", err))
s.logger.Error(err, "Found error opening mssql")
return nil, err
}

err = db.Ping()
if err != nil {
logger.Error(err, fmt.Sprintf("Found error pinging mssql: %s", err))
s.logger.Error(err, "Found error pinging mssql")
return nil, err
}

return db, nil
}

// getMSSQLConnectionString returns a connection string from a mssqlMetadata
func getMSSQLConnectionString(meta *mssqlMetadata) string {
var connStr string

if meta.connectionString != "" {
connStr = meta.connectionString
} else {
query := url.Values{}
if meta.database != "" {
query.Add("database", meta.database)
}
func getMSSQLConnectionString(ctx context.Context, s *mssqlScaler) string {
meta := s.metadata
if meta.ConnectionString != "" {
return meta.ConnectionString
}

connectionURL := &url.URL{Scheme: "sqlserver", RawQuery: query.Encode()}
if meta.username != "" {
if meta.password != "" {
connectionURL.User = url.UserPassword(meta.username, meta.password)
} else {
connectionURL.User = url.User(meta.username)
}
}
query := url.Values{}
if meta.Database != "" {
query.Add("database", meta.Database)
}

if meta.port > 0 {
connectionURL.Host = net.JoinHostPort(meta.host, fmt.Sprintf("%d", meta.port))
connectionURL := &url.URL{Scheme: "sqlserver", RawQuery: query.Encode()}
if meta.Username != "" {
if meta.Password != "" {
connectionURL.User = url.UserPassword(meta.Username, meta.Password)
} else {
connectionURL.Host = meta.host
connectionURL.User = url.User(meta.Username)
}
}

if meta.Port > 0 {
connectionURL.Host = net.JoinHostPort(meta.Host, fmt.Sprintf("%d", meta.Port))
} else {
connectionURL.Host = meta.Host
}

if meta.WorkloadIdentityResource != "" {
token := s.getOAuthToken(ctx)
connectionURL.RawQuery += fmt.Sprintf("&access_token=%s", token)
}

connStr = connectionURL.String()
return connectionURL.String()
}

func (s *mssqlScaler) getOAuthToken(ctx context.Context) string {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource)
}

err := s.azureOAuth.Refresh()
if err != nil {
fmt.Println("Error fetching OAuth token:", err)
return ""
}

return connStr
return s.azureOAuth.OAuthToken()
}

// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler
func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, "mssql"),
Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, "mssql"),
},
Target: GetMetricTargetMili(s.metricType, s.metadata.targetValue),
Target: GetMetricTargetMili(s.metricType, s.metadata.TargetValue),
}

metricSpec := v2.MetricSpec{
Expand All @@ -241,7 +183,6 @@ func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
return []v2.MetricSpec{metricSpec}
}

// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric
func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
num, err := s.getQueryResult(ctx)
if err != nil {
Expand All @@ -250,13 +191,13 @@ func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName stri

metric := GenerateMetricInMili(metricName, num)

return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetValue, nil
return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetValue, nil
}

// getQueryResult returns the result of the scaler query
func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) {
var value float64
err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&value)

err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&value)
switch {
case err == sql.ErrNoRows:
value = 0
Expand All @@ -268,7 +209,6 @@ func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) {
return value, nil
}

// Close closes the mssql database connections
func (s *mssqlScaler) Close(context.Context) error {
err := s.connection.Close()
if err != nil {
Expand Down
Loading
Loading