Skip to content

Commit

Permalink
Prometheus Scaler Add custom headers and custom auth support (#4208)
Browse files Browse the repository at this point in the history
  • Loading branch information
prashant-shahi authored Feb 24, 2023
1 parent 2a773a2 commit 38a0e1c
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 19 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Here is an overview of all new **experimental** features:
- **Azure Pipelines Scaler**: New configuration parameter `requireAllDemands` to scale only if jobs request all demands provided by the scaling definition ([#4138](https://github.com/kedacore/keda/issues/4138))
- **Hashicorp Vault**: Add support to secrets backend version 1 ([#2645](https://github.com/kedacore/keda/issues/2645))
- **Kafka Scaler**: Improve error logging for `GetBlock` method ([#4232](https://github.com/kedacore/keda/issues/4232))
- **Prometheus Scaler**: Add custom headers and custom auth support ([#4208](https://github.com/kedacore/keda/issues/4208))
- **RabbitMQ Scaler**: Add TLS support ([#967](https://github.com/kedacore/keda/issues/967))
- **Redis Scalers**: Add support to Redis 7 ([#4052](https://github.com/kedacore/keda/issues/4052))
- **Selenium Grid Scaler**: Add 'platformName' to selenium-grid scaler metadata structure ([#4038](https://github.com/kedacore/keda/issues/4038))
Expand All @@ -93,7 +94,7 @@ You can find all deprecations in [this overview](https://github.com/kedacore/ked

New deprecation(s):

- TODO
- **Prometheus Scaler**: `cortexOrgId` metadata deprecated in favor of custom headers ([#4208](https://github.com/kedacore/keda/issues/4208))

### Other

Expand Down
11 changes: 11 additions & 0 deletions pkg/scalers/authentication/authentication_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ func GetAuthConfigs(triggerMetadata, authParams map[string]string) (out *AuthMet

out.Key = authParams["key"]
out.EnableTLS = true
case CustomAuthType:
if len(authParams["customAuthHeader"]) == 0 {
return nil, errors.New("no custom auth header given")
}
out.CustomAuthHeader = authParams["customAuthHeader"]

if len(authParams["customAuthValue"]) == 0 {
return nil, errors.New("no custom auth value given")
}
out.CustomAuthValue = authParams["customAuthValue"]
out.EnableCustomAuth = true
default:
return nil, fmt.Errorf("incorrect value for authMode is given: %s", t)
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/scalers/authentication/authentication_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ const (
TLSAuthType Type = "tls"
// BearerAuthType is a auth type using a bearer token
BearerAuthType Type = "bearer"
// CustomAuthType is a auth type using a custom header
CustomAuthType Type = "custom"
)

// TransportType is type of http transport
Expand All @@ -39,6 +41,11 @@ type AuthMeta struct {
Cert string
Key string
CA string

// custom auth header
EnableCustomAuth bool
CustomAuthHeader string
CustomAuthValue string
}

type HTTPTransport struct {
Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/loki_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestLokiScalerExecuteLogQLQuery(t *testing.T) {
}
}

func TestLokiScalerCortexHeader(t *testing.T) {
func TestLokiScalerTenantHeader(t *testing.T) {
testData := lokiQromQueryResultTestData{
name: "no values",
bodyStr: `{"data":{"result":[]}}`,
Expand All @@ -227,7 +227,7 @@ func TestLokiScalerCortexHeader(t *testing.T) {
}
tenantName := "Tenant1"
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
reqHeader := request.Header.Get(promCortexHeaderKey)
reqHeader := request.Header.Get(tenantNameHeaderKey)
assert.Equal(t, reqHeader, tenantName)
writer.WriteHeader(testData.responseStatus)
if _, err := writer.Write([]byte(testData.bodyStr)); err != nil {
Expand Down
32 changes: 23 additions & 9 deletions pkg/scalers/prometheus_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (
promActivationThreshold = "activationThreshold"
promNamespace = "namespace"
promCortexScopeOrgID = "cortexOrgID"
promCortexHeaderKey = "X-Scope-OrgID"
promCustomHeaders = "customHeaders"
ignoreNullValues = "ignoreNullValues"
unsafeSsl = "unsafeSsl"
)
Expand All @@ -52,7 +52,7 @@ type prometheusMetadata struct {
prometheusAuth *authentication.AuthMeta
namespace string
scalerIndex int
cortexOrgID string
customHeaders map[string]string
// sometimes should consider there is an error we can accept
// default value is true/t, to ignore the null value return from prometheus
// change to false/f if can not accept prometheus return null values
Expand Down Expand Up @@ -157,7 +157,16 @@ func parsePrometheusMetadata(config *ScalerConfig) (meta *prometheusMetadata, er
}

if val, ok := config.TriggerMetadata[promCortexScopeOrgID]; ok && val != "" {
meta.cortexOrgID = val
return nil, fmt.Errorf("cortexOrgID is deprecated, please use customHeaders instead")
}

if val, ok := config.TriggerMetadata[promCustomHeaders]; ok && val != "" {
customHeaders, err := kedautil.ParseStringList(val)
if err != nil {
return nil, fmt.Errorf("error parsing %s: %w", promCustomHeaders, err)
}

meta.customHeaders = customHeaders
}

meta.ignoreNullValues = defaultIgnoreNullValues
Expand Down Expand Up @@ -225,14 +234,19 @@ func (s *prometheusScaler) ExecutePromQuery(ctx context.Context) (float64, error
return -1, err
}

if s.metadata.prometheusAuth != nil && s.metadata.prometheusAuth.EnableBearerAuth {
req.Header.Add("Authorization", authentication.GetBearerToken(s.metadata.prometheusAuth))
} else if s.metadata.prometheusAuth != nil && s.metadata.prometheusAuth.EnableBasicAuth {
req.SetBasicAuth(s.metadata.prometheusAuth.Username, s.metadata.prometheusAuth.Password)
for headerName, headerValue := range s.metadata.customHeaders {
req.Header.Add(headerName, headerValue)
}

if s.metadata.cortexOrgID != "" {
req.Header.Add(promCortexHeaderKey, s.metadata.cortexOrgID)
switch {
case s.metadata.prometheusAuth == nil:
break
case s.metadata.prometheusAuth.EnableBearerAuth:
req.Header.Set("Authorization", authentication.GetBearerToken(s.metadata.prometheusAuth))
case s.metadata.prometheusAuth.EnableBasicAuth:
req.SetBasicAuth(s.metadata.prometheusAuth.Username, s.metadata.prometheusAuth.Password)
case s.metadata.prometheusAuth.EnableCustomAuth:
req.Header.Set(s.metadata.prometheusAuth.CustomAuthHeader, s.metadata.prometheusAuth.CustomAuthValue)
}

r, err := s.httpClient.Do(req)
Expand Down
34 changes: 27 additions & 7 deletions pkg/scalers/prometheus_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,14 @@ var testPromMetadata = []parsePrometheusMetadataTestData{
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": ""}, true},
// ignoreNullValues with wrong value
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "ignoreNullValues": "xxxx"}, true},

// unsafeSsl
{map[string]string{"serverAddress": "https://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "unsafeSsl": "true"}, false},
// customHeaders
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "customHeaders": "key1=value1,key2=value2"}, false},
// customHeaders with wrong format
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "customHeaders": "key1=value1,key2"}, true},
// deprecated cortexOrgID
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "cortexOrgID": "my-org"}, true},
}

var prometheusMetricIdentifiers = []prometheusMetricIdentifier{
Expand Down Expand Up @@ -82,6 +88,12 @@ var testPrometheusAuthMetadata = []prometheusAuthMetadataTestData{
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "authModes": "tls, basic"}, map[string]string{"ca": "caaa", "cert": "ceert", "key": "keey", "username": "user", "password": "pass"}, false},

{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "authModes": "tls,basic"}, map[string]string{"username": "user", "password": "pass"}, true},
// success custom auth
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "authModes": "custom"}, map[string]string{"customAuthHeader": "header", "customAuthValue": "value"}, false},
// fail custom auth with no customAuthHeader
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "authModes": "custom"}, map[string]string{"customAuthHeader": ""}, true},
// fail custom auth with no customAuthValue
{map[string]string{"serverAddress": "http://localhost:9090", "metricName": "http_requests_total", "threshold": "100", "query": "up", "authModes": "custom"}, map[string]string{"customAuthValue": ""}, true},
}

func TestPrometheusParseMetadata(t *testing.T) {
Expand Down Expand Up @@ -129,7 +141,8 @@ func TestPrometheusScalerAuthParams(t *testing.T) {
if err == nil {
if (meta.prometheusAuth.EnableBearerAuth && !strings.Contains(testData.metadata["authModes"], "bearer")) ||
(meta.prometheusAuth.EnableBasicAuth && !strings.Contains(testData.metadata["authModes"], "basic")) ||
(meta.prometheusAuth.EnableTLS && !strings.Contains(testData.metadata["authModes"], "tls")) {
(meta.prometheusAuth.EnableTLS && !strings.Contains(testData.metadata["authModes"], "tls")) ||
(meta.prometheusAuth.EnableCustomAuth && !strings.Contains(testData.metadata["authModes"], "custom")) {
t.Error("wrong auth mode detected")
}
}
Expand Down Expand Up @@ -300,7 +313,7 @@ func TestPrometheusScalerExecutePromQuery(t *testing.T) {
}
}

func TestPrometheusScalerCortexHeader(t *testing.T) {
func TestPrometheusScalerCustomHeaders(t *testing.T) {
testData := prometheusQromQueryResultTestData{
name: "no values",
bodyStr: `{"data":{"result":[]}}`,
Expand All @@ -309,10 +322,17 @@ func TestPrometheusScalerCortexHeader(t *testing.T) {
isError: false,
ignoreNullValues: true,
}
cortexOrgValue := "my-org"
customHeadersValue := map[string]string{
"X-Client-Id": "cid",
"X-Tenant-Id": "tid",
"X-Organization-Token": "oid",
}
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
reqHeader := request.Header.Get(promCortexHeaderKey)
assert.Equal(t, reqHeader, cortexOrgValue)
for headerName, headerValue := range customHeadersValue {
reqHeader := request.Header.Get(headerName)
assert.Equal(t, reqHeader, headerValue)
}

writer.WriteHeader(testData.responseStatus)
if _, err := writer.Write([]byte(testData.bodyStr)); err != nil {
t.Fatal(err)
Expand All @@ -322,7 +342,7 @@ func TestPrometheusScalerCortexHeader(t *testing.T) {
scaler := prometheusScaler{
metadata: &prometheusMetadata{
serverAddress: server.URL,
cortexOrgID: cortexOrgValue,
customHeaders: customHeadersValue,
ignoreNullValues: testData.ignoreNullValues,
},
httpClient: http.DefaultClient,
Expand Down
22 changes: 22 additions & 0 deletions pkg/util/parse_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,25 @@ func ParseInt32List(pattern string) ([]int32, error) {
}
return parsed, nil
}

func ParseStringList(pattern string) (map[string]string, error) {
parsed := make(map[string]string)
pattern = strings.TrimSpace(pattern)
if pattern == "" {
return parsed, nil
}
pairs := strings.Split(pattern, ",")
for _, pair := range pairs {
keyvalue := strings.Split(pair, "=")
if len(keyvalue) != 2 {
return nil, fmt.Errorf("error in key-value syntax, got '%s'", pair)
}
key := strings.TrimSpace(keyvalue[0])
value := strings.TrimSpace(keyvalue[1])
if _, ok := parsed[key]; ok {
return nil, fmt.Errorf("duplicate key found: %s", key)
}
parsed[key] = value
}
return parsed, nil
}
43 changes: 43 additions & 0 deletions pkg/util/parse_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,46 @@ func TestParseint32List(t *testing.T) {
})
}
}

func TestParseStringList(t *testing.T) {
testData := []struct {
name string
pattern string
exp map[string]string
isError bool
}{
{"success, no key-value", "", map[string]string{}, false},
{"success, one key, no value", "key1=", map[string]string{"key1": ""}, false},
{"success, one key, no value, with spaces", "key1 = ", map[string]string{"key1": ""}, false},
{"success, one pair", "key1=value1", map[string]string{"key1": "value1"}, false},
{"success, one pair with spaces", "key1 = value1", map[string]string{"key1": "value1"}, false},
{"success, one pair with spaces and no value", "key1 = ", map[string]string{"key1": ""}, false},
{"success, two keys, no value", "key1=,key2=", map[string]string{"key1": "", "key2": ""}, false},
{"success, two keys, no value, with spaces", "key1 = , key2 = ", map[string]string{"key1": "", "key2": ""}, false},
{"success, two pairs", "key1=value1,key2=value2", map[string]string{"key1": "value1", "key2": "value2"}, false},
{"success, two pairs with spaces", "key1 = value1, key2 = value2", map[string]string{"key1": "value1", "key2": "value2"}, false},
{"failure, one key", "key1", nil, true},
{"failure, duplicate keys", "key1=value1,key1=value2", nil, true},
{"failure, one key ending with two successive equals to", "key1==", nil, true},
{"failure, one valid pair and invalid one key", "key1=value1,key2", nil, true},
{"failure, two valid pairs and invalid two keys", "key1=value1,key2=value2,key3,key4", nil, true},
}

for _, tt := range testData {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseStringList(tt.pattern)

if err != nil && !tt.isError {
t.Errorf("Expected no error but got %s\n", err)
}

if err == nil && tt.isError {
t.Errorf("Expected error but got %s\n", err)
}

if !reflect.DeepEqual(tt.exp, got) {
t.Errorf("Expected %v but got %v\n", tt.exp, got)
}
})
}
}

0 comments on commit 38a0e1c

Please sign in to comment.