Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor mongo scaler #6261

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 78 additions & 173 deletions pkg/scalers/mongo_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"

"github.com/go-logr/logr"
Expand All @@ -22,60 +20,45 @@ import (
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

// mongoDBScaler is support for mongoDB in keda.
type mongoDBScaler struct {
metricType v2.MetricTargetType
metadata *mongoDBMetadata
metadata mongoDBMetadata
client *mongo.Client
logger logr.Logger
}

// mongoDBMetadata specify mongoDB scaler params.
type mongoDBMetadata struct {
// The string is used by connected with mongoDB.
// +optional
connectionString string
// Specify the prefix to connect to the mongoDB server, default value `mongodb`, if the connectionString be provided, don't need to specify this param.
// +optional
scheme string
// Specify the host to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
host string
// Specify the port to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
port string
// Specify the username to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
username string
// Specify the password to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
password string

// The name of the database to be queried.
// +required
dbName string
// The name of the collection to be queried.
// +required
collection string
// A mongoDB filter doc,used by specify DB.
// +required
query string
// A threshold that is used as targetAverageValue in HPA
// +required
queryValue int64
// A threshold that is used to check if scaler is active
// +optional
activationQueryValue int64

// The index of the scaler inside the ScaledObject
// +internal
triggerIndex int
ConnectionString string `keda:"name=connectionString,order=authParams;triggerMetadata;resolvedEnv,optional"`
Scheme string `keda:"name=scheme,order=authParams;triggerMetadata,default=mongodb,optional"`
Host string `keda:"name=host,order=authParams;triggerMetadata,optional"`
Port string `keda:"name=port,order=authParams;triggerMetadata,optional"`
Username string `keda:"name=username,order=authParams;triggerMetadata,optional"`
Password string `keda:"name=password,order=authParams;triggerMetadata;resolvedEnv,optional"`
DBName string `keda:"name=dbName,order=authParams;triggerMetadata"`
Collection string `keda:"name=collection,order=triggerMetadata"`
Query string `keda:"name=query,order=triggerMetadata"`
QueryValue int64 `keda:"name=queryValue,order=triggerMetadata"`
ActivationQueryValue int64 `keda:"name=activationQueryValue,order=triggerMetadata,default=0"`
TriggerIndex int
}

// Default variables and settings
const (
mongoDBDefaultTimeOut = 10 * time.Second
)
func (m *mongoDBMetadata) Validate() error {
if m.ConnectionString == "" {
if m.Host == "" {
return fmt.Errorf("no host given")
}
if m.Port == "" && m.Scheme != "mongodb+srv" {
return fmt.Errorf("no port given")
}
if m.Username == "" {
return fmt.Errorf("no username given")
}
if m.Password == "" {
return fmt.Errorf("no password given")
}
}
return nil
}

// NewMongoDBScaler creates a new mongoDB scaler
func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
Expand All @@ -84,22 +67,14 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
}

ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut)
defer cancel()

meta, connStr, err := parseMongoDBMetadata(config)
meta, err := parseMongoDBMetadata(config)
if err != nil {
return nil, fmt.Errorf("failed to parsing mongoDB metadata, because of %w", err)
return nil, fmt.Errorf("error parsing mongodb metadata: %w", err)
}

opt := options.Client().ApplyURI(connStr)
client, err := mongo.Connect(ctx, opt)
client, err := createMongoDBClient(ctx, meta)
if err != nil {
return nil, fmt.Errorf("failed to establish connection with mongoDB, because of %w", err)
}

if err = client.Ping(ctx, readpref.Primary()); err != nil {
return nil, fmt.Errorf("failed to ping mongoDB, because of %w", err)
return nil, fmt.Errorf("error creating mongodb client: %w", err)
}

return &mongoDBScaler{
Expand All @@ -110,171 +85,101 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (
}, nil
}

func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (*mongoDBMetadata, string, error) {
var connStr string
var err error
// setting default metadata
func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (mongoDBMetadata, error) {
meta := mongoDBMetadata{}

// parse metaData from ScaledJob config
if val, ok := config.TriggerMetadata["collection"]; ok {
meta.collection = val
} else {
return nil, "", fmt.Errorf("no collection given")
err := config.TypedConfig(&meta)
if err != nil {
return meta, fmt.Errorf("error parsing mongodb metadata: %w", err)
}

if val, ok := config.TriggerMetadata["query"]; ok {
meta.query = val
} else {
return nil, "", fmt.Errorf("no query given")
}
meta.TriggerIndex = config.TriggerIndex
return meta, nil
}

if val, ok := config.TriggerMetadata["queryValue"]; ok {
queryValue, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err)
}
meta.queryValue = queryValue
func createMongoDBClient(ctx context.Context, meta mongoDBMetadata) (*mongo.Client, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

var connString string
if meta.ConnectionString != "" {
connString = meta.ConnectionString
} else {
if config.AsMetricSource {
meta.queryValue = 0
} else {
return nil, "", fmt.Errorf("no queryValue given")
host := meta.Host
if meta.Scheme != "mongodb+srv" {
host = net.JoinHostPort(meta.Host, meta.Port)
}
}

meta.activationQueryValue = 0
if val, ok := config.TriggerMetadata["activationQueryValue"]; ok {
activationQueryValue, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err)
u := &url.URL{
Scheme: meta.Scheme,
User: url.UserPassword(meta.Username, meta.Password),
Host: host,
Path: meta.DBName,
}
meta.activationQueryValue = activationQueryValue
connString = u.String()
}

dbName, err := GetFromAuthOrMeta(config, "dbName")
client, err := mongo.Connect(ctx, options.Client().ApplyURI(connString))
if err != nil {
return nil, "", err
return nil, fmt.Errorf("failed to create mongodb client: %w", err)
}
meta.dbName = dbName

// Resolve connectionString
switch {
case config.AuthParams["connectionString"] != "":
meta.connectionString = config.AuthParams["connectionString"]
case config.TriggerMetadata["connectionStringFromEnv"] != "":
meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]]
default:
meta.connectionString = ""
scheme, err := GetFromAuthOrMeta(config, "scheme")
if err != nil {
meta.scheme = "mongodb"
} else {
meta.scheme = scheme
}

host, err := GetFromAuthOrMeta(config, "host")
if err != nil {
return nil, "", err
}
meta.host = host

if !strings.Contains(scheme, "mongodb+srv") {
port, err := GetFromAuthOrMeta(config, "port")
if err != nil {
return nil, "", err
}
meta.port = port
}

username, err := GetFromAuthOrMeta(config, "username")
if err != nil {
return nil, "", err
}
meta.username = username

if config.AuthParams["password"] != "" {
meta.password = config.AuthParams["password"]
} else if config.TriggerMetadata["passwordFromEnv"] != "" {
meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]]
}
if len(meta.password) == 0 {
return nil, "", fmt.Errorf("no password given")
}
}

switch {
case meta.connectionString != "":
connStr = meta.connectionString
case meta.scheme == "mongodb+srv":
// nosemgrep: db-connection-string
connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), meta.host, meta.dbName)
default:
addr := net.JoinHostPort(meta.host, meta.port)
// nosemgrep: db-connection-string
connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), addr, meta.dbName)
err = client.Ping(ctx, readpref.Primary())
if err != nil {
return nil, fmt.Errorf("failed to ping mongodb: %w", err)
}

meta.triggerIndex = config.TriggerIndex
return &meta, connStr, nil
return client, nil
}

// Close disposes of mongoDB connections
func (s *mongoDBScaler) Close(ctx context.Context) error {
if s.client != nil {
err := s.client.Disconnect(ctx)
if err != nil {
s.logger.Error(err, fmt.Sprintf("failed to close mongoDB connection, because of %v", err))
s.logger.Error(err, "Error closing mongodb connection")
return err
}
}

return nil
}

// getQueryResult query mongoDB by meta.query
func (s *mongoDBScaler) getQueryResult(ctx context.Context) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut)
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

filter, err := json2BsonDoc(s.metadata.query)
collection := s.client.Database(s.metadata.DBName).Collection(s.metadata.Collection)

filter, err := json2BsonDoc(s.metadata.Query)
if err != nil {
s.logger.Error(err, fmt.Sprintf("failed to convert query param to bson.Doc, because of %v", err))
return 0, err
return 0, fmt.Errorf("failed to parse query: %w", err)
}

docsNum, err := s.client.Database(s.metadata.dbName).Collection(s.metadata.collection).CountDocuments(ctx, filter)
count, err := collection.CountDocuments(ctx, filter)
if err != nil {
s.logger.Error(err, fmt.Sprintf("failed to query %v in %v, because of %v", s.metadata.dbName, s.metadata.collection, err))
return 0, err
return 0, fmt.Errorf("failed to execute query: %w", err)
}

return docsNum, nil
return count, nil
}

// GetMetricsAndActivity query from mongoDB,and return to external metrics
func (s *mongoDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
num, err := s.getQueryResult(ctx)
if err != nil {
return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect momgoDB, because of %w", err)
return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect mongodb: %w", err)
}

metric := GenerateMetricInMili(metricName, float64(num))

return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationQueryValue, nil
return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationQueryValue, nil
}

// GetMetricSpecForScaling get the query value for scaling
func (s *mongoDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
metricName := kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.Collection))
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.collection))),
Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName),
},
Target: GetMetricTarget(s.metricType, s.metadata.queryValue),
}
metricSpec := v2.MetricSpec{
External: externalMetric, Type: externalMetricType,
Target: GetMetricTarget(s.metricType, s.metadata.QueryValue),
}
metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType}
return []v2.MetricSpec{metricSpec}
}

Expand Down
15 changes: 9 additions & 6 deletions pkg/scalers/mongo_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"testing"

"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/mongo"
v2 "k8s.io/api/autoscaling/v2"

"github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
)
Expand Down Expand Up @@ -100,7 +100,7 @@ var mongoDBMetricIdentifiers = []mongoDBMetricIdentifier{

func TestParseMongoDBMetadata(t *testing.T) {
for _, testData := range testMONGODBMetadata {
_, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
_, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
if err != nil && !testData.raisesError {
t.Error("Expected success but got error:", err)
}
Expand All @@ -112,21 +112,24 @@ func TestParseMongoDBMetadata(t *testing.T) {

func TestParseMongoDBConnectionString(t *testing.T) {
for _, testData := range mongoDBConnectionStringTestDatas {
_, connStr, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, TriggerMetadata: testData.metadataTestData.metadata, AuthParams: testData.metadataTestData.authParams})
_, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{
ResolvedEnv: testData.metadataTestData.resolvedEnv,
TriggerMetadata: testData.metadataTestData.metadata,
AuthParams: testData.metadataTestData.authParams,
})
if err != nil {
t.Error("Expected success but got error:", err)
}
assert.Equal(t, testData.connectionString, connStr)
}
}

func TestMongoDBGetMetricSpecForScaling(t *testing.T) {
for _, testData := range mongoDBMetricIdentifiers {
meta, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex})
meta, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex})
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
mockMongoDBScaler := mongoDBScaler{"", meta, &mongo.Client{}, logr.Discard()}
mockMongoDBScaler := mongoDBScaler{metricType: v2.AverageValueMetricType, metadata: meta, client: &mongo.Client{}, logger: logr.Discard()}

metricSpec := mockMongoDBScaler.GetMetricSpecForScaling(context.Background())
metricName := metricSpec[0].External.Metric.Name
Expand Down
Loading