From 017d2c0d4e987efb5ee173e25a0d8517d5c38fa4 Mon Sep 17 00:00:00 2001
From: Ahmed ElSayed
Date: Tue, 19 May 2020 20:54:37 -0700
Subject: [PATCH] Refactor Azure Storage handling. Fixes #821

Signed-off-by: Ahmed ElSayed
---
 .gitignore                                    |   3 +-
 Makefile                                      |   2 +-
 cmd/manager/main.go                           |   2 +-
 pkg/apis/keda/v1alpha1/scaledobject_types.go  |   2 +-
 pkg/handler/scale_jobs.go                     |   8 +-
 pkg/scalers/AzureStorage.go                   |  44 -----
 .../{ => azure}/azure_aad_podidentity.go      |   6 +-
 pkg/scalers/azure/azure_blob.go               |  28 ++++
 pkg/scalers/azure/azure_blob_test.go          |  36 +++++
 pkg/scalers/{ => azure}/azure_eventhub.go     |  71 +++-----
 pkg/scalers/azure/azure_eventhub_test.go      |  37 +++++
 pkg/scalers/{ => azure}/azure_monitor.go      |  56 ++++---
 pkg/scalers/azure/azure_monitor_test.go       |  51 ++++++
 pkg/scalers/azure/azure_queue.go              |  30 ++++
 pkg/scalers/azure/azure_queue_test.go         |  36 +++++
 pkg/scalers/azure/azure_storage.go            | 152 ++++++++++++++++++
 pkg/scalers/azure/azure_storage_test.go       |  63 ++++++++
 pkg/scalers/azure_blob.go                     |  60 -------
 pkg/scalers/azure_blob_scaler.go              |  39 ++---
 ...blob_test.go => azure_blob_scaler_test.go} |  40 +----
 pkg/scalers/azure_eventhub_scaler.go          |  46 +++---
 ..._test.go => azure_eventhub_scaler_test.go} |  74 +++------
 pkg/scalers/azure_monitor_scaler.go           |  46 +++---
 ...r_test.go => azure_monitor_scaler_test.go} |  50 +----
 pkg/scalers/azure_queue.go                    |  59 -------
 pkg/scalers/azure_queue_scaler.go             |   5 +-
 ...eue_test.go => azure_queue_scaler_test.go} |  77 +--------
 pkg/scalers/azure_servicebus_scaler.go        |   3 +-
 pkg/scalers/mysql_scaler_test.go              |   9 +-
 pkg/scalers/prometheus.go                     |   2 +-
 pkg/scalers/stan_scaler.go                    |   2 +-
 31 files changed, 600 insertions(+), 539 deletions(-)
 delete mode 100644 pkg/scalers/AzureStorage.go
 rename pkg/scalers/{ => azure}/azure_aad_podidentity.go (84%)
 create mode 100644 pkg/scalers/azure/azure_blob.go
 create mode 100644 pkg/scalers/azure/azure_blob_test.go
 rename pkg/scalers/{ => azure}/azure_eventhub.go (56%)
 create mode 100644 pkg/scalers/azure/azure_eventhub_test.go
 rename pkg/scalers/{ => azure}/azure_monitor.go (78%)
 create mode 100644 pkg/scalers/azure/azure_monitor_test.go
 create mode 100644 pkg/scalers/azure/azure_queue.go
 create mode 100644 pkg/scalers/azure/azure_queue_test.go
 create mode 100644 pkg/scalers/azure/azure_storage.go
 create mode 100644 pkg/scalers/azure/azure_storage_test.go
 delete mode 100644 pkg/scalers/azure_blob.go
 rename pkg/scalers/{azure_blob_test.go => azure_blob_scaler_test.go} (68%)
 rename pkg/scalers/{azure_eventhub_test.go => azure_eventhub_scaler_test.go} (81%)
 rename pkg/scalers/{azure_monitor_test.go => azure_monitor_scaler_test.go} (66%)
 delete mode 100644 pkg/scalers/azure_queue.go
 rename pkg/scalers/{azure_queue_test.go => azure_queue_scaler_test.go} (55%)

diff --git a/.gitignore b/.gitignore
index 05afb243f6c..70c6e755146 100644
--- a/.gitignore
+++ b/.gitignore
@@ -343,4 +343,5 @@ config
 .vscode
 
 # GO Vendor
-vendor
\ No newline at end of file
+vendor
+cover.out
diff --git a/Makefile b/Makefile
index 659d2c7fe3d..797df7355e8 100644
--- a/Makefile
+++ b/Makefile
@@ -27,7 +27,7 @@ all: test build
 ##################################################
 .PHONY: test
 test:
-	go test ./...
+	go test ./... 
-covermode=atomic -coverprofile cover.out .PHONY: e2e-test e2e-test: diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 425f7511379..5b3222780d1 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -12,8 +12,8 @@ import ( "k8s.io/client-go/rest" "github.com/kedacore/keda/pkg/apis" - "github.com/kedacore/keda/version" "github.com/kedacore/keda/pkg/controller" + "github.com/kedacore/keda/version" "github.com/operator-framework/operator-sdk/pkg/k8sutil" kubemetrics "github.com/operator-framework/operator-sdk/pkg/kube-metrics" diff --git a/pkg/apis/keda/v1alpha1/scaledobject_types.go b/pkg/apis/keda/v1alpha1/scaledobject_types.go index f7cef539251..774387f4605 100644 --- a/pkg/apis/keda/v1alpha1/scaledobject_types.go +++ b/pkg/apis/keda/v1alpha1/scaledobject_types.go @@ -79,7 +79,7 @@ type ScaleTriggers struct { // +optional type ScaledObjectStatus struct { // +optional - LastActiveTime *metav1.Time `json:"lastActiveTime,omitempty"` + LastActiveTime *metav1.Time `json:"lastActiveTime,omitempty"` // +optional // +listType ExternalMetricNames []string `json:"externalMetricNames,omitempty"` diff --git a/pkg/handler/scale_jobs.go b/pkg/handler/scale_jobs.go index acef2e0979b..9089f573fca 100644 --- a/pkg/handler/scale_jobs.go +++ b/pkg/handler/scale_jobs.go @@ -61,11 +61,11 @@ func (h *ScaleHandler) createJobs(scaledObject *kedav1alpha1.ScaledObject, scale GenerateName: scaledObject.GetName() + "-", Namespace: scaledObject.GetNamespace(), Labels: map[string]string{ - "app.kubernetes.io/name": scaledObject.GetName(), - "app.kubernetes.io/version": version.Version, - "app.kubernetes.io/part-of": scaledObject.GetName(), + "app.kubernetes.io/name": scaledObject.GetName(), + "app.kubernetes.io/version": version.Version, + "app.kubernetes.io/part-of": scaledObject.GetName(), "app.kubernetes.io/managed-by": "keda-operator", - "scaledobject": scaledObject.GetName(), + "scaledobject": scaledObject.GetName(), }, }, Spec: *scaledObject.Spec.JobTargetRef.DeepCopy(), diff --git a/pkg/scalers/AzureStorage.go b/pkg/scalers/AzureStorage.go deleted file mode 100644 index 629ff7b9105..00000000000 --- a/pkg/scalers/AzureStorage.go +++ /dev/null @@ -1,44 +0,0 @@ -package scalers - -import ( - "errors" - "strings" -) - -/* ParseAzureStorageConnectionString parses a storage account connection string into (endpointProtocol, accountName, key, endpointSuffix) - Connection string should be in following format: - DefaultEndpointsProtocol=https;AccountName=yourStorageAccountName;AccountKey=yourStorageAccountKey;EndpointSuffix=core.windows.net -*/ -func ParseAzureStorageConnectionString(connectionString string) (string, string, string, string, error) { - parts := strings.Split(connectionString, ";") - - var endpointProtocol, name, key, endpointSuffix string - for _, v := range parts { - if strings.HasPrefix(v, "DefaultEndpointsProtocol") { - protocolParts := strings.SplitN(v, "=", 2) - if len(protocolParts) == 2 { - endpointProtocol = protocolParts[1] - } - } else if strings.HasPrefix(v, "AccountName") { - accountParts := strings.SplitN(v, "=", 2) - if len(accountParts) == 2 { - name = accountParts[1] - } - } else if strings.HasPrefix(v, "AccountKey") { - keyParts := strings.SplitN(v, "=", 2) - if len(keyParts) == 2 { - key = keyParts[1] - } - } else if strings.HasPrefix(v, "EndpointSuffix") { - suffixParts := strings.SplitN(v, "=", 2) - if len(suffixParts) == 2 { - endpointSuffix = suffixParts[1] - } - } - } - if name == "" || key == "" || endpointProtocol == "" || endpointSuffix == "" { - return "", "", 
"", "", errors.New("Can't parse storage connection string") - } - - return endpointProtocol, name, key, endpointSuffix, nil -} diff --git a/pkg/scalers/azure_aad_podidentity.go b/pkg/scalers/azure/azure_aad_podidentity.go similarity index 84% rename from pkg/scalers/azure_aad_podidentity.go rename to pkg/scalers/azure/azure_aad_podidentity.go index ea758fb9577..337c331c6e9 100644 --- a/pkg/scalers/azure_aad_podidentity.go +++ b/pkg/scalers/azure/azure_aad_podidentity.go @@ -1,4 +1,4 @@ -package scalers +package azure import ( "encoding/json" @@ -13,11 +13,11 @@ const ( msiURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s" ) -func getAzureADPodIdentityToken(uri string) (AADToken, error) { +func GetAzureADPodIdentityToken(audience string) (AADToken, error) { var token AADToken - resp, err := http.Get(fmt.Sprintf(msiURL, url.QueryEscape(uri))) + resp, err := http.Get(fmt.Sprintf(msiURL, url.QueryEscape(audience))) if err != nil { return token, err } diff --git a/pkg/scalers/azure/azure_blob.go b/pkg/scalers/azure/azure_blob.go new file mode 100644 index 00000000000..8cedb8eea45 --- /dev/null +++ b/pkg/scalers/azure/azure_blob.go @@ -0,0 +1,28 @@ +package azure + +import ( + "context" + "github.com/Azure/azure-storage-blob-go/azblob" +) + +// GetAzureBlobListLength returns the count of the blobs in blob container in int +func GetAzureBlobListLength(ctx context.Context, podIdentity string, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) { + credential, endpoint, err := ParseAzureStorageBlobConnection(podIdentity, connectionString, accountName) + if err != nil { + return -1, err + } + + listBlobsSegmentOptions := azblob.ListBlobsSegmentOptions{ + Prefix: blobPrefix, + } + p := azblob.NewPipeline(credential, azblob.PipelineOptions{}) + serviceURL := azblob.NewServiceURL(*endpoint, p) + containerURL := serviceURL.NewContainerURL(blobContainerName) + + props, err := containerURL.ListBlobsHierarchySegment(ctx, azblob.Marker{}, blobDelimiter, listBlobsSegmentOptions) + if err != nil { + return -1, err + } + + return len(props.Segment.BlobItems), nil +} diff --git a/pkg/scalers/azure/azure_blob_test.go b/pkg/scalers/azure/azure_blob_test.go new file mode 100644 index 00000000000..471e1a76229 --- /dev/null +++ b/pkg/scalers/azure/azure_blob_test.go @@ -0,0 +1,36 @@ +package azure + +import ( + "context" + "strings" + "testing" +) + +func TestGetBlobLength(t *testing.T) { + length, err := GetAzureBlobListLength(context.TODO(), "", "", "blobContainerName", "", "", "") + if length != -1 { + t.Error("Expected length to be -1, but got", length) + } + + if err == nil { + t.Error("Expected error for empty connection string, but got nil") + } + + if !strings.Contains(err.Error(), "parse storage connection string") { + t.Error("Expected error to contain parsing error message, but got", err.Error()) + } + + length, err = GetAzureBlobListLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "") + + if length != -1 { + t.Error("Expected length to be -1, but got", length) + } + + if err == nil { + t.Error("Expected error for empty connection string, but got nil") + } + + if !strings.Contains(err.Error(), "illegal base64") { + t.Error("Expected error to contain base64 error message, but got", err.Error()) + } +} diff --git a/pkg/scalers/azure_eventhub.go b/pkg/scalers/azure/azure_eventhub.go similarity 
index 56% rename from pkg/scalers/azure_eventhub.go rename to pkg/scalers/azure/azure_eventhub.go index cbcde83dbdb..03483f82abc 100644 --- a/pkg/scalers/azure_eventhub.go +++ b/pkg/scalers/azure/azure_eventhub.go @@ -1,4 +1,4 @@ -package scalers +package azure import ( "bytes" @@ -13,11 +13,6 @@ import ( eventhub "github.com/Azure/azure-event-hubs-go" "github.com/Azure/azure-storage-blob-go/azblob" - "github.com/Azure/go-autorest/autorest/azure" -) - -const ( - environmentName = "AzurePublicCloud" ) type baseCheckpoint struct { @@ -43,29 +38,16 @@ type pythonCheckpoint struct { SequenceNumber int64 `json:"sequence_number"` } -// GetStorageCredentials returns azure env and storage credentials -func GetStorageCredentials(storageConnection string) (azure.Environment, *azblob.SharedKeyCredential, error) { - _, storageAccountName, storageAccountKey, _, err := ParseAzureStorageConnectionString(storageConnection) - if err != nil { - return azure.Environment{}, &azblob.SharedKeyCredential{}, fmt.Errorf("unable to parse connection string: %s", storageConnection) - } - - azureEnv, err := azure.EnvironmentFromName(environmentName) - if err != nil { - return azureEnv, nil, fmt.Errorf("could not get azure.Environment struct: %s", err) - } - - cred, err := azblob.NewSharedKeyCredential(storageAccountName, storageAccountKey) - if err != nil { - return azureEnv, nil, fmt.Errorf("could not prepare a blob storage credential: %s", err) - } - - return azureEnv, cred, nil +type EventHubInfo struct { + EventHubConnection string + EventHubConsumerGroup string + StorageConnection string + BlobContainer string } // GetEventHubClient returns eventhub client -func GetEventHubClient(connectionString string) (*eventhub.Hub, error) { - hub, err := eventhub.NewHubFromConnectionString(connectionString) +func GetEventHubClient(info EventHubInfo) (*eventhub.Hub, error) { + hub, err := eventhub.NewHubFromConnectionString(info.EventHubConnection) if err != nil { return nil, fmt.Errorf("failed to create hub client: %s", err) } @@ -74,39 +56,34 @@ func GetEventHubClient(connectionString string) (*eventhub.Hub, error) { } // GetCheckpointFromBlobStorage accesses Blob storage and gets checkpoint information of a partition -func GetCheckpointFromBlobStorage(ctx context.Context, partitionID string, eventHubMetadata EventHubMetadata) (Checkpoint, error) { - endpointProtocol, storageAccountName, _, endpointSuffix, err := ParseAzureStorageConnectionString(eventHubMetadata.storageConnection) +func GetCheckpointFromBlobStorage(ctx context.Context, info EventHubInfo, partitionID string) (Checkpoint, error) { + + blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection("none", info.StorageConnection, "") if err != nil { - return Checkpoint{}, fmt.Errorf("unable to parse storage connection string: %s", err) + return Checkpoint{}, err } - // Remove trailing spaces from endpointSuffix - endpointSuffix = strings.TrimSpace(endpointSuffix) - - eventHubNamespace, eventHubName, err := ParseAzureEventHubConnectionString(eventHubMetadata.eventHubConnection) + eventHubNamespace, eventHubName, err := ParseAzureEventHubConnectionString(info.EventHubConnection) if err != nil { - return Checkpoint{}, fmt.Errorf("unable to parse event hub connection string: %s", err) + return Checkpoint{}, err } // TODO: add more ways to read from different types of storage and read checkpoints/leases written in different JSON formats - var u *url.URL + var baseURL *url.URL // Checking blob store for C# and Java applications - if 
eventHubMetadata.blobContainer != "" { - // URL format - ://.blob./// - u, _ = url.Parse(fmt.Sprintf("%s://%s.blob.%s/%s/%s/%s", endpointProtocol, storageAccountName, endpointSuffix, eventHubMetadata.blobContainer, eventHubMetadata.eventHubConsumerGroup, partitionID)) + if info.BlobContainer != "" { + // URL format - /// + path, _ := url.Parse(fmt.Sprintf("/%s/%s/%s", info.BlobContainer, info.EventHubConsumerGroup, partitionID)) + baseURL = storageEndpoint.ResolveReference(path) } else { // Checking blob store for Azure functions - // URL format - ://.blob./azure-webjobs-eventhub//// - u, _ = url.Parse(fmt.Sprintf("%s://%s.blob.%s/azure-webjobs-eventhub/%s/%s/%s/%s", endpointProtocol, storageAccountName, endpointSuffix, eventHubNamespace, eventHubName, eventHubMetadata.eventHubConsumerGroup, partitionID)) - } - - _, cred, err := GetStorageCredentials(eventHubMetadata.storageConnection) - if err != nil { - return Checkpoint{}, fmt.Errorf("unable to get storage credentials: %s", err) + // URL format - /azure-webjobs-eventhub//// + path, _ := url.Parse(fmt.Sprintf("/azure-webjobs-eventhub/%s/%s/%s/%s", eventHubNamespace, eventHubName, info.EventHubConsumerGroup, partitionID)) + baseURL = storageEndpoint.ResolveReference(path) } // Create a BlockBlobURL object to a blob in the container. - blobURL := azblob.NewBlockBlobURL(*u, azblob.NewPipeline(cred, azblob.PipelineOptions{})) + blobURL := azblob.NewBlockBlobURL(*baseURL, azblob.NewPipeline(blobCreds, azblob.PipelineOptions{})) get, err := blobURL.Download(ctx, 0, 0, azblob.BlobAccessConditions{}, false) if err != nil { @@ -164,7 +141,7 @@ func ParseAzureEventHubConnectionString(connectionString string) (string, string } if eventHubNamespace == "" || eventHubName == "" { - return "", "", errors.New("Can't parse event hub connection string") + return "", "", errors.New("can't parse event hub connection string. 
Missing eventHubNamespace or eventHubName") } return eventHubNamespace, eventHubName, nil diff --git a/pkg/scalers/azure/azure_eventhub_test.go b/pkg/scalers/azure/azure_eventhub_test.go new file mode 100644 index 00000000000..c7e89fe4a46 --- /dev/null +++ b/pkg/scalers/azure/azure_eventhub_test.go @@ -0,0 +1,37 @@ +package azure + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const csharpSdkCheckpoint = `{ + "Epoch": 123456, + "Offset": "test offset", + "Owner": "test owner", + "PartitionId": "test partitionId", + "SequenceNumber": 12345 + }` + +const pythonSdkCheckpoint = `{ + "epoch": 123456, + "offset": "test offset", + "owner": "test owner", + "partition_id": "test partitionId", + "sequence_number": 12345 + }` + +func TestGetCheckpoint(t *testing.T) { + cckp, err := getCheckpoint([]byte(csharpSdkCheckpoint)) + if err != nil { + t.Error(err) + } + + pckp, err := getCheckpoint([]byte(pythonSdkCheckpoint)) + if err != nil { + t.Error(err) + } + + assert.Equal(t, cckp, pckp) +} diff --git a/pkg/scalers/azure_monitor.go b/pkg/scalers/azure/azure_monitor.go similarity index 78% rename from pkg/scalers/azure_monitor.go rename to pkg/scalers/azure/azure_monitor.go index d1f0f4f965d..8dd65f6eb35 100644 --- a/pkg/scalers/azure_monitor.go +++ b/pkg/scalers/azure/azure_monitor.go @@ -1,4 +1,4 @@ -package scalers +package azure import ( "context" @@ -28,21 +28,34 @@ type azureExternalMetricRequest struct { ResourceGroup string } +type AzureMonitorInfo struct { + ResourceURI string + TenantID string + SubscriptionID string + ResourceGroupName string + Name string + Filter string + AggregationInterval string + AggregationType string + ClientID string + ClientPassword string +} + // GetAzureMetricValue returns the value of an Azure Monitor metric, rounded to the nearest int -func GetAzureMetricValue(ctx context.Context, metricMetadata *azureMonitorMetadata) (int32, error) { - client := createMetricsClient(metricMetadata) +func GetAzureMetricValue(ctx context.Context, info AzureMonitorInfo) (int32, error) { + client := createMetricsClient(info) - requestPtr, err := createMetricsRequest(metricMetadata) + requestPtr, err := createMetricsRequest(info) if err != nil { return -1, err } - return executeRequest(client, requestPtr) + return executeRequest(ctx, client, requestPtr) } -func createMetricsClient(metadata *azureMonitorMetadata) insights.MetricsClient { - client := insights.NewMetricsClient(metadata.subscriptionID) - config := auth.NewClientCredentialsConfig(metadata.clientID, metadata.clientPassword, metadata.tenantID) +func createMetricsClient(info AzureMonitorInfo) insights.MetricsClient { + client := insights.NewMetricsClient(info.SubscriptionID) + config := auth.NewClientCredentialsConfig(info.ClientID, info.ClientPassword, info.TenantID) authorizer, _ := config.Authorizer() client.Authorizer = authorizer @@ -50,22 +63,22 @@ func createMetricsClient(metadata *azureMonitorMetadata) insights.MetricsClient return client } -func createMetricsRequest(metadata *azureMonitorMetadata) (*azureExternalMetricRequest, error) { +func createMetricsRequest(info AzureMonitorInfo) (*azureExternalMetricRequest, error) { metricRequest := azureExternalMetricRequest{ - MetricName: metadata.name, - SubscriptionID: metadata.subscriptionID, - Aggregation: metadata.aggregationType, - Filter: metadata.filter, - ResourceGroup: metadata.resourceGroupName, + MetricName: info.Name, + SubscriptionID: info.SubscriptionID, + Aggregation: info.AggregationType, + Filter: info.Filter, + ResourceGroup: 
info.ResourceGroupName, } - resourceInfo := strings.Split(metadata.resourceURI, "/") + resourceInfo := strings.Split(info.ResourceURI, "/") metricRequest.ResourceProviderNamespace = resourceInfo[0] metricRequest.ResourceType = resourceInfo[1] metricRequest.ResourceName = resourceInfo[2] // if no timespan is provided, defaults to 5 minutes - timespan, err := formatTimeSpan(metadata.aggregationInterval) + timespan, err := formatTimeSpan(info.AggregationInterval) if err != nil { return nil, err } @@ -75,11 +88,10 @@ func createMetricsRequest(metadata *azureMonitorMetadata) (*azureExternalMetricR return &metricRequest, nil } -func executeRequest(client insights.MetricsClient, request *azureExternalMetricRequest) (int32, error) { - metricResponse, err := getAzureMetric(client, *request) +func executeRequest(ctx context.Context, client insights.MetricsClient, request *azureExternalMetricRequest) (int32, error) { + metricResponse, err := getAzureMetric(ctx, client, *request) if err != nil { - azureMonitorLog.Error(err, "error getting azure monitor metric") - return -1, fmt.Errorf("Error getting azure monitor metric %s: %s", request.MetricName, err.Error()) + return -1, fmt.Errorf("error getting azure monitor metric %s: %w", request.MetricName, err) } // casting drops everything after decimal, so round first @@ -88,7 +100,7 @@ func executeRequest(client insights.MetricsClient, request *azureExternalMetricR return metricValue, nil } -func getAzureMetric(client insights.MetricsClient, azMetricRequest azureExternalMetricRequest) (float64, error) { +func getAzureMetric(ctx context.Context, client insights.MetricsClient, azMetricRequest azureExternalMetricRequest) (float64, error) { err := azMetricRequest.validate() if err != nil { return -1, err @@ -97,7 +109,7 @@ func getAzureMetric(client insights.MetricsClient, azMetricRequest azureExternal metricResourceURI := azMetricRequest.metricResourceURI() klog.V(2).Infof("resource uri: %s", metricResourceURI) - metricResult, err := client.List(context.Background(), metricResourceURI, + metricResult, err := client.List(ctx, metricResourceURI, azMetricRequest.Timespan, nil, azMetricRequest.MetricName, azMetricRequest.Aggregation, nil, "", azMetricRequest.Filter, "", "") diff --git a/pkg/scalers/azure/azure_monitor_test.go b/pkg/scalers/azure/azure_monitor_test.go new file mode 100644 index 00000000000..453fa372d82 --- /dev/null +++ b/pkg/scalers/azure/azure_monitor_test.go @@ -0,0 +1,51 @@ +package azure + +import ( + "testing" + + "github.com/Azure/azure-sdk-for-go/services/preview/monitor/mgmt/2018-03-01/insights" +) + +type testExtractAzMonitorTestData struct { + testName string + isError bool + expectedValue float64 + metricRequest azureExternalMetricRequest + metricResult insights.Response +} + +var testExtractAzMonitordata = []testExtractAzMonitorTestData{ + {"nothing returned", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{}}}, + {"timeseries null", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: nil}}}}, + {"timeseries empty", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{}}}}}, + {"data nil", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: nil}}}}}}, + {"data empty", true, -1, azureExternalMetricRequest{}, insights.Response{Value: 
&[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{}}}}}}}, + {"Total Aggregation requested", false, 40, azureExternalMetricRequest{Aggregation: "Total"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Total: returnFloat64Ptr(40)}}}}}}}}, + {"Average Aggregation requested", false, 41, azureExternalMetricRequest{Aggregation: "Average"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Average: returnFloat64Ptr(41)}}}}}}}}, + {"Maximum Aggregation requested", false, 42, azureExternalMetricRequest{Aggregation: "Maximum"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Maximum: returnFloat64Ptr(42)}}}}}}}}, + {"Minimum Aggregation requested", false, 43, azureExternalMetricRequest{Aggregation: "Minimum"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Minimum: returnFloat64Ptr(43)}}}}}}}}, + {"Count Aggregation requested", false, 44, azureExternalMetricRequest{Aggregation: "Count"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Count: returnint64Ptr(44)}}}}}}}}, +} + +func returnFloat64Ptr(x float64) *float64 { + return &x +} + +func returnint64Ptr(x int64) *int64 { + return &x +} + +func TestAzMonitorextractValue(t *testing.T) { + for _, testData := range testExtractAzMonitordata { + value, err := extractValue(testData.metricRequest, testData.metricResult) + if err != nil && !testData.isError { + t.Errorf("Test: %v; Expected success but got error: %v", testData.testName, err) + } + if testData.isError && err == nil { + t.Errorf("Test: %v; Expected error but got success. 
testData: %v", testData.testName, testData) + } + if err != nil && value != testData.expectedValue { + t.Errorf("Test: %v; Expected value %v but got %v testData: %v", testData.testName, testData.expectedValue, value, testData) + } + } +} diff --git a/pkg/scalers/azure/azure_queue.go b/pkg/scalers/azure/azure_queue.go new file mode 100644 index 00000000000..642d1109ba3 --- /dev/null +++ b/pkg/scalers/azure/azure_queue.go @@ -0,0 +1,30 @@ +package azure + +import ( + "context" + "github.com/Azure/azure-storage-queue-go/azqueue" +) + +// GetAzureQueueLength returns the length of a queue in int +func GetAzureQueueLength(ctx context.Context, podIdentity string, connectionString, queueName string, accountName string) (int32, error) { + + credential, endpoint, err := ParseAzureStorageQueueConnection(podIdentity, connectionString, accountName) + if err != nil { + return -1, err + } + + p := azqueue.NewPipeline(credential, azqueue.PipelineOptions{}) + serviceURL := azqueue.NewServiceURL(*endpoint, p) + queueURL := serviceURL.NewQueueURL(queueName) + _, err = queueURL.Create(ctx, azqueue.Metadata{}) + if err != nil { + return -1, err + } + + props, err := queueURL.GetProperties(ctx) + if err != nil { + return -1, err + } + + return props.ApproximateMessagesCount(), nil +} diff --git a/pkg/scalers/azure/azure_queue_test.go b/pkg/scalers/azure/azure_queue_test.go new file mode 100644 index 00000000000..15c55899ea8 --- /dev/null +++ b/pkg/scalers/azure/azure_queue_test.go @@ -0,0 +1,36 @@ +package azure + +import ( + "context" + "strings" + "testing" +) + +func TestGetQueueLength(t *testing.T) { + length, err := GetAzureQueueLength(context.TODO(), "", "", "queueName", "") + if length != -1 { + t.Error("Expected length to be -1, but got", length) + } + + if err == nil { + t.Error("Expected error for empty connection string, but got nil") + } + + if !strings.Contains(err.Error(), "parse storage connection string") { + t.Error("Expected error to contain parsing error message, but got", err.Error()) + } + + length, err = GetAzureQueueLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "") + + if length != -1 { + t.Error("Expected length to be -1, but got", length) + } + + if err == nil { + t.Error("Expected error for empty connection string, but got nil") + } + + if !strings.Contains(err.Error(), "illegal base64") { + t.Error("Expected error to contain base64 error message, but got", err.Error()) + } +} diff --git a/pkg/scalers/azure/azure_storage.go b/pkg/scalers/azure/azure_storage.go new file mode 100644 index 00000000000..8b1480f329c --- /dev/null +++ b/pkg/scalers/azure/azure_storage.go @@ -0,0 +1,152 @@ +package azure + +import ( + "errors" + "fmt" + "github.com/Azure/azure-storage-blob-go/azblob" + "github.com/Azure/azure-storage-queue-go/azqueue" + "net/url" + "strings" +) + +/* ParseAzureStorageConnectionString parses a storage account connection string into (endpointProtocol, accountName, key, endpointSuffix) + Connection string should be in following format: + DefaultEndpointsProtocol=https;AccountName=yourStorageAccountName;AccountKey=yourStorageAccountKey;EndpointSuffix=core.windows.net +*/ + +type AzureStorageEndpointType int + +const ( + BlobEndpoint AzureStorageEndpointType = iota + QueueEndpoint + TableEndpoint + FileEndpoint +) + +func (e AzureStorageEndpointType) Prefix() string { + return [...]string{"BlobEndpoint", "QueueEndpoint", "TableEndpoint", "FileEndpoint"}[e] +} + +func (e 
AzureStorageEndpointType) Name() string {
+	return [...]string{"blob", "queue", "table", "file"}[e]
+}
+
+func ParseAzureStorageQueueConnection(podIdentity, connectionString, accountName string) (azqueue.Credential, *url.URL, error) {
+	switch podIdentity {
+	case "azure":
+		token, err := GetAzureADPodIdentityToken("https://storage.azure.com/")
+		if err != nil {
+			return nil, nil, err
+		}
+
+		if accountName == "" {
+			return nil, nil, fmt.Errorf("accountName is required for podIdentity azure")
+		}
+
+		credential := azqueue.NewTokenCredential(token.AccessToken, nil)
+		endpoint, _ := url.Parse(fmt.Sprintf("https://%s.queue.core.windows.net", accountName))
+		return credential, endpoint, nil
+	case "", "none":
+		endpoint, accountName, accountKey, err := parseAzureStorageConnectionString(connectionString, QueueEndpoint)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		credential, err := azqueue.NewSharedKeyCredential(accountName, accountKey)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		return credential, endpoint, nil
+	default:
+		return nil, nil, fmt.Errorf("azure queues doesn't support %s pod identity type", podIdentity)
+	}
+}
+
+func ParseAzureStorageBlobConnection(podIdentity, connectionString, accountName string) (azblob.Credential, *url.URL, error) {
+	switch podIdentity {
+	case "azure":
+		token, err := GetAzureADPodIdentityToken("https://storage.azure.com/")
+		if err != nil {
+			return nil, nil, err
+		}
+
+		if accountName == "" {
+			return nil, nil, fmt.Errorf("accountName is required for podIdentity azure")
+		}
+
+		credential := azblob.NewTokenCredential(token.AccessToken, nil)
+		endpoint, _ := url.Parse(fmt.Sprintf("https://%s.blob.core.windows.net", accountName))
+		return credential, endpoint, nil
+	case "", "none":
+		endpoint, accountName, accountKey, err := parseAzureStorageConnectionString(connectionString, BlobEndpoint)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		credential, err := azblob.NewSharedKeyCredential(accountName, accountKey)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		return credential, endpoint, nil
+	default:
+		return nil, nil, fmt.Errorf("azure blobs doesn't support %s pod identity type", podIdentity)
+	}
+}
+
+func parseAzureStorageConnectionString(connectionString string, endpointType AzureStorageEndpointType) (*url.URL, string, string, error) {
+	parts := strings.Split(connectionString, ";")
+
+	getValue := func(pair string) string {
+		parts := strings.SplitN(pair, "=", 2)
+		if len(parts) == 2 {
+			return parts[1]
+		}
+		return ""
+	}
+
+	var endpointProtocol, name, key, endpointSuffix, endpoint string
+	for _, v := range parts {
+		if strings.HasPrefix(v, "DefaultEndpointsProtocol") {
+			endpointProtocol = getValue(v)
+		} else if strings.HasPrefix(v, "AccountName") {
+			name = getValue(v)
+		} else if strings.HasPrefix(v, "AccountKey") {
+			key = getValue(v)
+		} else if strings.HasPrefix(v, "EndpointSuffix") {
+			endpointSuffix = getValue(v)
+		} else if endpointType == BlobEndpoint && strings.HasPrefix(v, endpointType.Prefix()) {
+			endpoint = getValue(v)
+		} else if endpointType == QueueEndpoint && strings.HasPrefix(v, endpointType.Prefix()) {
+			endpoint = getValue(v)
+		} else if endpointType == TableEndpoint && strings.HasPrefix(v, endpointType.Prefix()) {
+			endpoint = getValue(v)
+		} else if endpointType == FileEndpoint && strings.HasPrefix(v, endpointType.Prefix()) {
+			endpoint = getValue(v)
+		}
+	}
+
+	if name == "" || key == "" {
+		return nil, "", "", errors.New("can't parse storage connection string. 
Missing key or name") + } + + if endpoint != "" { + u, err := url.Parse(endpoint) + if err != nil { + return nil, "", "", err + } + return u, name, key, nil + } + + if endpointProtocol == "" || endpointSuffix == "" { + return nil, "", "", errors.New("can't parse storage connection string. Missing DefaultEndpointsProtocol or EndpointSuffix") + } + + u, err := url.Parse(fmt.Sprintf("%s://%s.%s.%s", endpointProtocol, name, endpointType.Name(), endpointSuffix)) + if err != nil { + return nil, "", "", err + } + + return u, name, key, nil +} diff --git a/pkg/scalers/azure/azure_storage_test.go b/pkg/scalers/azure/azure_storage_test.go new file mode 100644 index 00000000000..3a185fee6e6 --- /dev/null +++ b/pkg/scalers/azure/azure_storage_test.go @@ -0,0 +1,63 @@ +package azure + +import "testing" + +type parseConnectionStringTestData struct { + connectionString string + accountName string + accountKey string + endpoint string + endpointType AzureStorageEndpointType + isError bool +} + +var parseConnectionStringTestDataset = []parseConnectionStringTestData{ + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net", "testing", "key==", "https://testing.queue.core.windows.net", QueueEndpoint, false}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net", "testing", "key==", "https://testing.blob.core.windows.net", BlobEndpoint, false}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net", "testing", "key==", "https://testing.table.core.windows.net", TableEndpoint, false}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net", "testing", "key==", "https://testing.file.core.windows.net", FileEndpoint, false}, + {"AccountName=testingAccountKey=key==", "", "", "", QueueEndpoint, true}, + {"", "", "", "", QueueEndpoint, true}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net;QueueEndpoint=https://queue.net", "testing", "key==", "https://queue.net", QueueEndpoint, false}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net;BlobEndpoint=https://blob.net", "testing", "key==", "https://blob.net", BlobEndpoint, false}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net;TableEndpoint=https://table.net", "testing", "key==", "https://table.net", TableEndpoint, false}, + {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net;FileEndpoint=https://file.net", "testing", "key==", "https://file.net", FileEndpoint, false}, +} + +func TestParseStorageConnectionString(t *testing.T) { + for _, testData := range parseConnectionStringTestDataset { + endpoint, accountName, accountKey, err := parseAzureStorageConnectionString(testData.connectionString, testData.endpointType) + + if !testData.isError && err != nil { + t.Error("Expected success but got err", err) + } + + if testData.isError && err == nil { + t.Error("Expected error but got nil") + } + + if accountName != testData.accountName { + t.Error( + "For", testData.connectionString, + "expected accountName=", testData.accountName, + "but got", accountName) + } + + if accountKey != testData.accountKey { + t.Error( + "For", testData.connectionString, + "expected accountKey=", testData.accountKey, + "but got", accountKey) + } + + if err == nil { + if endpoint.String() != 
testData.endpoint { + t.Error( + "For", testData.connectionString, + "expected endpoint=", testData.endpoint, + "but got", endpoint) + } + } + + } +} diff --git a/pkg/scalers/azure_blob.go b/pkg/scalers/azure_blob.go deleted file mode 100644 index 13c2d49a65b..00000000000 --- a/pkg/scalers/azure_blob.go +++ /dev/null @@ -1,60 +0,0 @@ -package scalers - -import ( - "context" - "fmt" - "net/url" - - "github.com/Azure/azure-storage-blob-go/azblob" -) - -// GetAzureBlobListLength returns the count of the blobs in blob container in int -func GetAzureBlobListLength(ctx context.Context, podIdentity string, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) { - - var credential azblob.Credential - var listBlobsSegmentOptions azblob.ListBlobsSegmentOptions - var err error - - if podIdentity == "" || podIdentity == "none" { - - var accountKey string - - _, accountName, accountKey, _, err = ParseAzureStorageConnectionString(connectionString) - - if err != nil { - return -1, err - } - - credential, err = azblob.NewSharedKeyCredential(accountName, accountKey) - if err != nil { - return -1, err - } - } else if podIdentity == "azure" { - token, err := getAzureADPodIdentityToken("https://storage.azure.com/") - if err != nil { - azureBlobLog.Error(err, "Error fetching token cannot determine blob list count") - return -1, nil - } - - credential = azblob.NewTokenCredential(token.AccessToken, nil) - } else { - return -1, fmt.Errorf("Azure blobs doesn't support %s pod identity type", podIdentity) - - } - - if blobPrefix != "" { - listBlobsSegmentOptions.Prefix = blobPrefix - } - - p := azblob.NewPipeline(credential, azblob.PipelineOptions{}) - u, _ := url.Parse(fmt.Sprintf("https://%s.blob.core.windows.net", accountName)) - serviceURL := azblob.NewServiceURL(*u, p) - containerURL := serviceURL.NewContainerURL(blobContainerName) - - props, err := containerURL.ListBlobsHierarchySegment(ctx, azblob.Marker{} , blobDelimiter, listBlobsSegmentOptions) - if err != nil { - return -1, err - } - - return len(props.Segment.BlobItems) , nil -} diff --git a/pkg/scalers/azure_blob_scaler.go b/pkg/scalers/azure_blob_scaler.go index f4a9bdd4382..c35c1012207 100644 --- a/pkg/scalers/azure_blob_scaler.go +++ b/pkg/scalers/azure_blob_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "github.com/kedacore/keda/pkg/scalers/azure" "strconv" v2beta1 "k8s.io/api/autoscaling/v2beta1" @@ -14,15 +15,15 @@ import ( ) const ( - blobCountMetricName = "blobCount" - defaultTargetBlobCount = 5 - defaultBlobDelimiter = "/" - defaultBlobPrefix = "" + blobCountMetricName = "blobCount" + defaultTargetBlobCount = 5 + defaultBlobDelimiter = "/" + defaultBlobPrefix = "" defaultBlobConnectionSetting = "AzureWebJobsStorage" ) type azureBlobScaler struct { - metadata *azureBlobMetadata + metadata *azureBlobMetadata podIdentity string } @@ -30,7 +31,7 @@ type azureBlobMetadata struct { targetBlobCount int blobContainerName string blobDelimiter string - blobPrefix string + blobPrefix string connection string useAAdPodIdentity bool accountName string @@ -46,7 +47,7 @@ func NewAzureBlobScaler(resolvedEnv, metadata, authParams map[string]string, pod } return &azureBlobScaler{ - metadata: meta, + metadata: meta, podIdentity: podIdentity, }, nil } @@ -101,17 +102,17 @@ func parseAzureBlobMetadata(metadata, resolvedEnv, authParams map[string]string, // Found the connection in a parameter from TriggerAuthentication meta.connection = connection } else { - connectionSetting := 
defaultBlobConnectionSetting - if val, ok := metadata["connection"]; ok && val != "" { - connectionSetting = val + connectionSetting := defaultBlobConnectionSetting + if val, ok := metadata["connection"]; ok && val != "" { + connectionSetting = val + } + + if val, ok := resolvedEnv[connectionSetting]; ok { + meta.connection = val + } else { + return nil, "", fmt.Errorf("no connection setting given") + } } - - if val, ok := resolvedEnv[connectionSetting]; ok { - meta.connection = val - } else { - return nil, "", fmt.Errorf("no connection setting given") - } - } } else if podAuth == "azure" { // If the Use AAD Pod Identity is present then check account name if val, ok := metadata["accountName"]; ok && val != "" { @@ -128,7 +129,7 @@ func parseAzureBlobMetadata(metadata, resolvedEnv, authParams map[string]string, // GetScaleDecision is a func func (s *azureBlobScaler) IsActive(ctx context.Context) (bool, error) { - length, err := GetAzureBlobListLength( + length, err := azure.GetAzureBlobListLength( ctx, s.podIdentity, s.metadata.connection, @@ -159,7 +160,7 @@ func (s *azureBlobScaler) GetMetricSpecForScaling() []v2beta1.MetricSpec { //GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *azureBlobScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - bloblen, err := GetAzureBlobListLength( + bloblen, err := azure.GetAzureBlobListLength( ctx, s.podIdentity, s.metadata.connection, diff --git a/pkg/scalers/azure_blob_test.go b/pkg/scalers/azure_blob_scaler_test.go similarity index 68% rename from pkg/scalers/azure_blob_test.go rename to pkg/scalers/azure_blob_scaler_test.go index 3d611d368dd..a69f4bc5373 100644 --- a/pkg/scalers/azure_blob_test.go +++ b/pkg/scalers/azure_blob_scaler_test.go @@ -1,47 +1,14 @@ package scalers -import ( - "context" - "strings" - "testing" -) - -func TestGetBlobLength(t *testing.T) { - length, err := GetAzureBlobListLength(context.TODO(), "", "", "blobContainerName", "", "","") - if length != -1 { - t.Error("Expected length to be -1, but got", length) - } - - if err == nil { - t.Error("Expected error for empty connection string, but got nil") - } - - if !strings.Contains(err.Error(), "parse storage connection string") { - t.Error("Expected error to contain parsing error message, but got", err.Error()) - } - - length, err = GetAzureBlobListLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "","") - - if length != -1 { - t.Error("Expected length to be -1, but got", length) - } - - if err == nil { - t.Error("Expected error for empty connection string, but got nil") - } - - if !strings.Contains(err.Error(), "illegal base64") { - t.Error("Expected error to contain base64 error message, but got", err.Error()) - } -} +import "testing" var testAzBlobResolvedEnv = map[string]string{ "CONNECTION": "SAMPLE", } type parseAzBlobMetadataTestData struct { - metadata map[string]string - isError bool + metadata map[string]string + isError bool resolvedEnv map[string]string authParams map[string]string podIdentity string @@ -64,7 +31,6 @@ var testAzBlobMetadata = []parseAzBlobMetadataTestData{ {map[string]string{"accountName": "sample_acc", "blobContainerName": ""}, true, testAzBlobResolvedEnv, map[string]string{}, "azure"}, // connection from authParams {map[string]string{"blobContainerName": "sample_container", "blobCount": "5"}, false, 
testAzBlobResolvedEnv, map[string]string{"connection": "value"}, "none"}, - } func TestAzBlobParseMetadata(t *testing.T) { diff --git a/pkg/scalers/azure_eventhub_scaler.go b/pkg/scalers/azure_eventhub_scaler.go index 9e1de7fd70a..6592ff73d6b 100644 --- a/pkg/scalers/azure_eventhub_scaler.go +++ b/pkg/scalers/azure_eventhub_scaler.go @@ -3,11 +3,11 @@ package scalers import ( "context" "fmt" + "github.com/kedacore/keda/pkg/scalers/azure" "math" "strconv" eventhub "github.com/Azure/azure-event-hubs-go" - "github.com/Azure/azure-storage-blob-go/azblob" "k8s.io/api/autoscaling/v2beta1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -29,17 +29,13 @@ const ( var eventhubLog = logf.Log.WithName("azure_eventhub_scaler") type AzureEventHubScaler struct { - metadata *EventHubMetadata - client *eventhub.Hub - storageCredentials *azblob.SharedKeyCredential + metadata *EventHubMetadata + client *eventhub.Hub } type EventHubMetadata struct { - eventHubConnection string - eventHubConsumerGroup string - threshold int64 - storageConnection string - blobContainer string + eventHubInfo azure.EventHubInfo + threshold int64 } // NewAzureEventHubScaler creates a new scaler for eventHub @@ -49,26 +45,22 @@ func NewAzureEventHubScaler(resolvedEnv, metadata map[string]string) (Scaler, er return nil, fmt.Errorf("unable to get eventhub metadata: %s", err) } - _, cred, err := GetStorageCredentials(parsedMetadata.storageConnection) - if err != nil { - return nil, fmt.Errorf("unable to get storage credentials: %s", err) - } - - hub, err := GetEventHubClient(parsedMetadata.eventHubConnection) + hub, err := azure.GetEventHubClient(parsedMetadata.eventHubInfo) if err != nil { return nil, fmt.Errorf("unable to get eventhub client: %s", err) } return &AzureEventHubScaler{ - metadata: parsedMetadata, - storageCredentials: cred, - client: hub, + metadata: parsedMetadata, + client: hub, }, nil } // parseAzureEventHubMetadata parses metadata func parseAzureEventHubMetadata(metadata, resolvedEnv map[string]string) (*EventHubMetadata, error) { - meta := EventHubMetadata{} + meta := EventHubMetadata{ + eventHubInfo: azure.EventHubInfo{}, + } meta.threshold = defaultEventHubMessageThreshold if val, ok := metadata[thresholdMetricName]; ok { @@ -86,7 +78,7 @@ func parseAzureEventHubMetadata(metadata, resolvedEnv map[string]string) (*Event } if val, ok := resolvedEnv[storageConnectionSetting]; ok { - meta.storageConnection = val + meta.eventHubInfo.StorageConnection = val } else { return nil, fmt.Errorf("no storage connection string given") } @@ -97,19 +89,19 @@ func parseAzureEventHubMetadata(metadata, resolvedEnv map[string]string) (*Event } if val, ok := resolvedEnv[eventHubConnectionSetting]; ok { - meta.eventHubConnection = val + meta.eventHubInfo.EventHubConnection = val } else { return nil, fmt.Errorf("no event hub connection string given") } - meta.eventHubConsumerGroup = defaultEventHubConsumerGroup + meta.eventHubInfo.EventHubConsumerGroup = defaultEventHubConsumerGroup if val, ok := metadata["consumerGroup"]; ok { - meta.eventHubConsumerGroup = val + meta.eventHubInfo.EventHubConsumerGroup = val } - meta.blobContainer = defaultBlobContainer + meta.eventHubInfo.BlobContainer = defaultBlobContainer if val, ok := metadata["blobContainer"]; ok { - meta.blobContainer = val + meta.eventHubInfo.BlobContainer = val } return &meta, nil @@ -122,7 +114,7 @@ func (scaler *AzureEventHubScaler) GetUnprocessedEventCountInPartition(ctx conte return -1, fmt.Errorf("unable to get partition info: %s", 
err) } - checkpoint, err := GetCheckpointFromBlobStorage(ctx, partitionID, *scaler.metadata) + checkpoint, err := azure.GetCheckpointFromBlobStorage(ctx, scaler.metadata.eventHubInfo, partitionID) if err != nil { return -1, fmt.Errorf("unable to get checkpoint from storage: %s", err) } @@ -202,7 +194,7 @@ func (scaler *AzureEventHubScaler) GetMetrics(ctx context.Context, metricName st return []external_metrics.ExternalMetricValue{}, fmt.Errorf("unable to get partitionRuntimeInfo for metrics: %s", err) } - checkpoint, err := GetCheckpointFromBlobStorage(ctx, partitionID, *scaler.metadata) + checkpoint, err := azure.GetCheckpointFromBlobStorage(ctx, scaler.metadata.eventHubInfo, partitionID) if err != nil { return []external_metrics.ExternalMetricValue{}, fmt.Errorf("unable to get checkpoint from storage: %s", err) } diff --git a/pkg/scalers/azure_eventhub_test.go b/pkg/scalers/azure_eventhub_scaler_test.go similarity index 81% rename from pkg/scalers/azure_eventhub_test.go rename to pkg/scalers/azure_eventhub_scaler_test.go index 29d6ec0ddef..e2e4ac5a533 100644 --- a/pkg/scalers/azure_eventhub_test.go +++ b/pkg/scalers/azure_eventhub_scaler_test.go @@ -3,13 +3,11 @@ package scalers import ( "context" "fmt" + "github.com/kedacore/keda/pkg/scalers/azure" "net/url" "os" - "strings" "testing" - "github.com/stretchr/testify/assert" - eventhub "github.com/Azure/azure-event-hubs-go" "github.com/Azure/azure-storage-blob-go/azblob" ) @@ -54,8 +52,10 @@ var parseEventHubMetadataDataset = []parseEventHubMetadataTestData{ var testEventHubScaler = AzureEventHubScaler{ metadata: &EventHubMetadata{ - eventHubConnection: "none", - storageConnection: "none", + eventHubInfo: azure.EventHubInfo{ + EventHubConnection: "none", + StorageConnection: "none", + }, }, } @@ -83,18 +83,17 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) { if eventHubKey != "" && storageConnectionString != "" { eventHubConnectionString := fmt.Sprintf("Endpoint=sb://%s.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=%s;EntityPath=%s", testEventHubNamespace, eventHubKey, testEventHubName) - storageAccountName := strings.Split(strings.Split(storageConnectionString, ";")[1], "=")[1] + storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection("none", storageConnectionString, "") + if err != nil { + t.Error(err) + t.FailNow() + } t.Log("Creating event hub client...") hubOption := eventhub.HubWithPartitionedSender("0") client, err := eventhub.NewHubFromConnectionString(eventHubConnectionString, hubOption) if err != nil { - t.Errorf("Expected to create event hub client but got error: %s", err) - } - - _, storageCredentials, err := GetStorageCredentials(storageConnectionString) - if err != nil { - t.Errorf("Expected to generate storage credentials but got error: %s", err) + t.Fatalf("Expected to create event hub client but got error: %s", err) } if eventHubConnectionString == "" { @@ -106,11 +105,10 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) { } // Can actually test that numbers return - testEventHubScaler.metadata.eventHubConnection = eventHubConnectionString - testEventHubScaler.metadata.storageConnection = storageConnectionString + testEventHubScaler.metadata.eventHubInfo.EventHubConnection = eventHubConnectionString + testEventHubScaler.metadata.eventHubInfo.StorageConnection = storageConnectionString testEventHubScaler.client = client - testEventHubScaler.storageCredentials = storageCredentials - testEventHubScaler.metadata.eventHubConsumerGroup = 
"$Default" + testEventHubScaler.metadata.eventHubInfo.EventHubConsumerGroup = "$Default" // Send 1 message to event hub first t.Log("Sending message to event hub") @@ -121,7 +119,7 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) { // Create fake checkpoint with path azure-webjobs-eventhub/.servicebus.windows.net//$Default t.Log("Creating container..") - ctx, err := CreateNewCheckpointInStorage(storageAccountName, storageCredentials, client) + ctx, err := CreateNewCheckpointInStorage(endpoint, storageCredentials, client) if err != nil { t.Errorf("err creating container: %s", err) } @@ -145,49 +143,20 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) { // Delete container - this will also delete checkpoint t.Log("Deleting container...") - err = DeleteContainerInStorage(ctx, storageAccountName, storageCredentials) + err = DeleteContainerInStorage(ctx, endpoint, storageCredentials) if err != nil { t.Error(err) } } } -const csharpSdkCheckpoint = `{ - "Epoch": 123456, - "Offset": "test offset", - "Owner": "test owner", - "PartitionId": "test partitionId", - "SequenceNumber": 12345 - }` - -const pythonSdkCheckpoint = `{ - "epoch": 123456, - "offset": "test offset", - "owner": "test owner", - "partition_id": "test partitionId", - "sequence_number": 12345 - }` - -func TestGetCheckpoint(t *testing.T) { - cckp, err := getCheckpoint([]byte(csharpSdkCheckpoint)) - if err != nil { - t.Error(err) - } - - pckp, err := getCheckpoint([]byte(pythonSdkCheckpoint)) - if err != nil { - t.Error(err) - } - - assert.Equal(t, cckp, pckp) -} - -func CreateNewCheckpointInStorage(storageAccountName string, credential *azblob.SharedKeyCredential, client *eventhub.Hub) (context.Context, error) { +func CreateNewCheckpointInStorage(endpoint *url.URL, credential azblob.Credential, client *eventhub.Hub) (context.Context, error) { urlPath := fmt.Sprintf("%s.servicebus.windows.net/%s/$Default/", testEventHubNamespace, testEventHubName) // Create container ctx := context.Background() - url, _ := url.Parse(fmt.Sprintf("https://%s.blob.core.windows.net/%s", storageAccountName, testContainerName)) + path, _ := url.Parse(testContainerName) + url := endpoint.ResolveReference(path) containerURL := azblob.NewContainerURL(*url, azblob.NewPipeline(credential, azblob.PipelineOptions{})) _, err := containerURL.Create(ctx, azblob.Metadata{}, azblob.PublicAccessNone) if err != nil { @@ -283,8 +252,9 @@ func SendMessageToEventHub(client *eventhub.Hub) error { return nil } -func DeleteContainerInStorage(ctx context.Context, storageAccountName string, credential *azblob.SharedKeyCredential) error { - url, _ := url.Parse(fmt.Sprintf("https://%s.blob.core.windows.net/%s", storageAccountName, testContainerName)) +func DeleteContainerInStorage(ctx context.Context, endpoint *url.URL, credential azblob.Credential) error { + path, _ := url.Parse(testContainerName) + url := endpoint.ResolveReference(path) containerURL := azblob.NewContainerURL(*url, azblob.NewPipeline(credential, azblob.PipelineOptions{})) _, err := containerURL.Delete(ctx, azblob.ContainerAccessConditions{ diff --git a/pkg/scalers/azure_monitor_scaler.go b/pkg/scalers/azure_monitor_scaler.go index a27d83a34a7..cbab9c36603 100644 --- a/pkg/scalers/azure_monitor_scaler.go +++ b/pkg/scalers/azure_monitor_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "github.com/kedacore/keda/pkg/scalers/azure" "strconv" "strings" @@ -26,17 +27,8 @@ type azureMonitorScaler struct { } type azureMonitorMetadata struct { - resourceURI string - tenantID 
string - subscriptionID string - resourceGroupName string - name string - filter string - aggregationInterval string - aggregationType string - clientID string - clientPassword string - targetValue int + azureMonitorInfo azure.AzureMonitorInfo + targetValue int } var azureMonitorLog = logf.Log.WithName("azure_monitor_scaler") @@ -54,7 +46,9 @@ func NewAzureMonitorScaler(resolvedEnv, metadata, authParams map[string]string) } func parseAzureMonitorMetadata(metadata, resolvedEnv, authParams map[string]string) (*azureMonitorMetadata, error) { - meta := azureMonitorMetadata{} + meta := azureMonitorMetadata{ + azureMonitorInfo: azure.AzureMonitorInfo{}, + } if val, ok := metadata[targetValueName]; ok && val != "" { targetValue, err := strconv.Atoi(val) @@ -72,31 +66,31 @@ func parseAzureMonitorMetadata(metadata, resolvedEnv, authParams map[string]stri if len(resourceURI) != 3 { return nil, fmt.Errorf("resourceURI not in the correct format. Should be namespace/resource_type/resource_name") } - meta.resourceURI = val + meta.azureMonitorInfo.ResourceURI = val } else { return nil, fmt.Errorf("no resourceURI given") } if val, ok := metadata["resourceGroupName"]; ok && val != "" { - meta.resourceGroupName = val + meta.azureMonitorInfo.ResourceGroupName = val } else { return nil, fmt.Errorf("no resourceGroupName given") } if val, ok := metadata[azureMonitorMetricName]; ok && val != "" { - meta.name = val + meta.azureMonitorInfo.Name = val } else { return nil, fmt.Errorf("no metricName given") } if val, ok := metadata["metricAggregationType"]; ok && val != "" { - meta.aggregationType = val + meta.azureMonitorInfo.AggregationType = val } else { return nil, fmt.Errorf("no metricAggregationType given") } if val, ok := metadata["metricFilter"]; ok && val != "" { - meta.filter = val + meta.azureMonitorInfo.Filter = val } if val, ok := metadata["metricAggregationInterval"]; ok && val != "" { @@ -104,25 +98,25 @@ func parseAzureMonitorMetadata(metadata, resolvedEnv, authParams map[string]stri if len(aggregationInterval) != 3 { return nil, fmt.Errorf("metricAggregationInterval not in the correct format. 
Should be hh:mm:ss") } - meta.aggregationInterval = val + meta.azureMonitorInfo.AggregationInterval = val } // Required authentication parameters below if val, ok := metadata["subscriptionId"]; ok && val != "" { - meta.subscriptionID = val + meta.azureMonitorInfo.SubscriptionID = val } else { return nil, fmt.Errorf("no subscriptionId given") } if val, ok := metadata["tenantId"]; ok && val != "" { - meta.tenantID = val + meta.azureMonitorInfo.TenantID = val } else { return nil, fmt.Errorf("no tenantId given") } if val, ok := authParams["activeDirectoryClientId"]; ok && val != "" { - meta.clientID = val + meta.azureMonitorInfo.ClientID = val } else { clientIDSetting := defaultClientIDSetting if val, ok := metadata["activeDirectoryClientId"]; ok && val != "" { @@ -130,14 +124,14 @@ func parseAzureMonitorMetadata(metadata, resolvedEnv, authParams map[string]stri } if val, ok := resolvedEnv[clientIDSetting]; ok { - meta.clientID = val + meta.azureMonitorInfo.ClientID = val } else { return nil, fmt.Errorf("no activeDirectoryClientId given") } } if val, ok := authParams["activeDirectoryClientPassword"]; ok && val != "" { - meta.clientPassword = val + meta.azureMonitorInfo.ClientPassword = val } else { clientPasswordSetting := defaultClientPasswordSetting if val, ok := metadata["activeDirectoryClientPassword"]; ok && val != "" { @@ -145,7 +139,7 @@ func parseAzureMonitorMetadata(metadata, resolvedEnv, authParams map[string]stri } if val, ok := resolvedEnv[clientPasswordSetting]; ok { - meta.clientPassword = val + meta.azureMonitorInfo.ClientPassword = val } else { return nil, fmt.Errorf("no activeDirectoryClientPassword given") } @@ -156,7 +150,7 @@ func parseAzureMonitorMetadata(metadata, resolvedEnv, authParams map[string]stri // Returns true if the Azure Monitor metric value is greater than zero func (s *azureMonitorScaler) IsActive(ctx context.Context) (bool, error) { - val, err := GetAzureMetricValue(ctx, s.metadata) + val, err := azure.GetAzureMetricValue(ctx, s.metadata.azureMonitorInfo) if err != nil { azureMonitorLog.Error(err, "error getting azure monitor metric") return false, err @@ -178,7 +172,7 @@ func (s *azureMonitorScaler) GetMetricSpecForScaling() []v2beta1.MetricSpec { // GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *azureMonitorScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - val, err := GetAzureMetricValue(ctx, s.metadata) + val, err := azure.GetAzureMetricValue(ctx, s.metadata.azureMonitorInfo) if err != nil { azureMonitorLog.Error(err, "error getting azure monitor metric") return []external_metrics.ExternalMetricValue{}, err diff --git a/pkg/scalers/azure_monitor_test.go b/pkg/scalers/azure_monitor_scaler_test.go similarity index 66% rename from pkg/scalers/azure_monitor_test.go rename to pkg/scalers/azure_monitor_scaler_test.go index 8c499944d5a..5a3c496cd87 100644 --- a/pkg/scalers/azure_monitor_test.go +++ b/pkg/scalers/azure_monitor_scaler_test.go @@ -1,10 +1,6 @@ package scalers -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/services/preview/monitor/mgmt/2018-03-01/insights" -) +import "testing" type parseAzMonitorMetadataTestData struct { metadata map[string]string @@ -64,47 +60,3 @@ func TestAzMonitorParseMetadata(t *testing.T) { } } } - -type testExtractAzMonitorTestData struct { - testName string - isError bool - expectedValue float64 - metricRequest azureExternalMetricRequest - metricResult 
-}
-
-var testExtractAzMonitordata = []testExtractAzMonitorTestData{
-	{"nothing returned", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{}}},
-	{"timeseries null", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: nil}}}},
-	{"timeseries empty", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{}}}}},
-	{"data nil", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: nil}}}}}},
-	{"data empty", true, -1, azureExternalMetricRequest{}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{}}}}}}},
-	{"Total Aggregation requested", false, 40, azureExternalMetricRequest{Aggregation: "Total"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Total: returnFloat64Ptr(40)}}}}}}}},
-	{"Average Aggregation requested", false, 41, azureExternalMetricRequest{Aggregation: "Average"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Average: returnFloat64Ptr(41)}}}}}}}},
-	{"Maximum Aggregation requested", false, 42, azureExternalMetricRequest{Aggregation: "Maximum"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Maximum: returnFloat64Ptr(42)}}}}}}}},
-	{"Minimum Aggregation requested", false, 43, azureExternalMetricRequest{Aggregation: "Minimum"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Minimum: returnFloat64Ptr(43)}}}}}}}},
-	{"Count Aggregation requested", false, 44, azureExternalMetricRequest{Aggregation: "Count"}, insights.Response{Value: &[]insights.Metric{insights.Metric{Timeseries: &[]insights.TimeSeriesElement{insights.TimeSeriesElement{Data: &[]insights.MetricValue{insights.MetricValue{Count: returnint64Ptr(44)}}}}}}}},
-}
-
-func returnFloat64Ptr(x float64) *float64 {
-	return &x
-}
-
-func returnint64Ptr(x int64) *int64 {
-	return &x
-}
-
-func TestAzMonitorextractValue(t *testing.T) {
-	for _, testData := range testExtractAzMonitordata {
-		value, err := extractValue(testData.metricRequest, testData.metricResult)
-		if err != nil && !testData.isError {
-			t.Errorf("Test: %v; Expected success but got error: %v", testData.testName, err)
-		}
-		if testData.isError && err == nil {
-			t.Errorf("Test: %v; Expected error but got success. testData: %v", testData.testName, testData)
-		}
-		if err != nil && value != testData.expectedValue {
-			t.Errorf("Test: %v; Expected value %v but got %v testData: %v", testData.testName, testData.expectedValue, value, testData)
-		}
-	}
-}
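
The extractValue cases deleted here presumably move with the function into the new azure package's tests. A small, hypothetical helper shows the same insights fixture shape without the one-line literals; newTotalResponse and the test name are inventions, and only the insights types and their nesting come from the deleted table:

package azure

import (
	"testing"

	"github.com/Azure/azure-sdk-for-go/services/preview/monitor/mgmt/2018-03-01/insights"
)

// newTotalResponse rebuilds the nested fixture the deleted table spelled out
// inline: one Metric, one TimeSeriesElement, one MetricValue carrying Total.
func newTotalResponse(total float64) insights.Response {
	return insights.Response{
		Value: &[]insights.Metric{
			{Timeseries: &[]insights.TimeSeriesElement{
				{Data: &[]insights.MetricValue{{Total: &total}}},
			}},
		},
	}
}

func TestNewTotalResponse(t *testing.T) {
	resp := newTotalResponse(40)
	// Unwrap the pointer-to-slice chain the insights API uses.
	data := *(*(*resp.Value)[0].Timeseries)[0].Data
	if *data[0].Total != 40 {
		t.Errorf("expected Total=40, got %v", *data[0].Total)
	}
}
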
testData: %v", testData.testName, testData) - } - if err != nil && value != testData.expectedValue { - t.Errorf("Test: %v; Expected value %v but got %v testData: %v", testData.testName, testData.expectedValue, value, testData) - } - } -} diff --git a/pkg/scalers/azure_queue.go b/pkg/scalers/azure_queue.go deleted file mode 100644 index d7ef62dbf07..00000000000 --- a/pkg/scalers/azure_queue.go +++ /dev/null @@ -1,59 +0,0 @@ -package scalers - -import ( - "context" - "fmt" - "net/url" - - "github.com/Azure/azure-storage-queue-go/azqueue" -) - -// GetAzureQueueLength returns the length of a queue in int -func GetAzureQueueLength(ctx context.Context, podIdentity string, connectionString, queueName string, accountName string) (int32, error) { - - var credential azqueue.Credential - var err error - - if podIdentity == "" || podIdentity == "none" { - - var accountKey string - - _, accountName, accountKey, _, err = ParseAzureStorageConnectionString(connectionString) - - if err != nil { - return -1, err - } - - credential, err = azqueue.NewSharedKeyCredential(accountName, accountKey) - if err != nil { - return -1, err - } - } else if podIdentity == "azure" { - token, err := getAzureADPodIdentityToken("https://storage.azure.com/") - if err != nil { - azureQueueLog.Error(err, "Error fetching token cannot determine queue size") - return -1, nil - } - - credential = azqueue.NewTokenCredential(token.AccessToken, nil) - } else { - return -1, fmt.Errorf("Azure queues doesn't support %s pod identity type", podIdentity) - - } - - p := azqueue.NewPipeline(credential, azqueue.PipelineOptions{}) - u, _ := url.Parse(fmt.Sprintf("https://%s.queue.core.windows.net", accountName)) - serviceURL := azqueue.NewServiceURL(*u, p) - queueURL := serviceURL.NewQueueURL(queueName) - _, err = queueURL.Create(ctx, azqueue.Metadata{}) - if err != nil { - return -1, err - } - - props, err := queueURL.GetProperties(ctx) - if err != nil { - return -1, err - } - - return props.ApproximateMessagesCount(), nil -} diff --git a/pkg/scalers/azure_queue_scaler.go b/pkg/scalers/azure_queue_scaler.go index 4fb20b33b10..7102220a608 100644 --- a/pkg/scalers/azure_queue_scaler.go +++ b/pkg/scalers/azure_queue_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "github.com/kedacore/keda/pkg/scalers/azure" "strconv" v2beta1 "k8s.io/api/autoscaling/v2beta1" @@ -111,7 +112,7 @@ func parseAzureQueueMetadata(metadata, resolvedEnv, authParams map[string]string // GetScaleDecision is a func func (s *azureQueueScaler) IsActive(ctx context.Context) (bool, error) { - length, err := GetAzureQueueLength( + length, err := azure.GetAzureQueueLength( ctx, s.podIdentity, s.metadata.connection, @@ -140,7 +141,7 @@ func (s *azureQueueScaler) GetMetricSpecForScaling() []v2beta1.MetricSpec { //GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *azureQueueScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - queuelen, err := GetAzureQueueLength( + queuelen, err := azure.GetAzureQueueLength( ctx, s.podIdentity, s.metadata.connection, diff --git a/pkg/scalers/azure_queue_test.go b/pkg/scalers/azure_queue_scaler_test.go similarity index 55% rename from pkg/scalers/azure_queue_test.go rename to pkg/scalers/azure_queue_scaler_test.go index 5471ad975d7..bca9b1b0b96 100644 --- a/pkg/scalers/azure_queue_test.go +++ b/pkg/scalers/azure_queue_scaler_test.go @@ -1,81 +1,6 @@ package scalers -import ( - 
"context" - "strings" - "testing" -) - -type parseConnectionStringTestData struct { - connectionString string - accountName string - accountKey string - isError bool -} - -var parseConnectionStringTestDataset = []parseConnectionStringTestData{ - {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net", "testing", "key==", false}, - {"DefaultEndpointsProtocol=https;AccountName=testing;AccountKey=key==;EndpointSuffix=core.windows.net", "testing", "key==", false}, - {"AccountName=testingAccountKey=key==", "", "", true}, - {"", "", "", true}, -} - -func TestParseStorageConnectionString(t *testing.T) { - for _, testData := range parseConnectionStringTestDataset { - _, accountName, accountKey, _, err := ParseAzureStorageConnectionString(testData.connectionString) - - if !testData.isError && err != nil { - t.Error("Expected success but got err", err) - } - - if testData.isError && err == nil { - t.Error("Expected error but got nil") - } - - if accountName != testData.accountName { - t.Error( - "For", testData.connectionString, - "expected accountName=", testData.accountName, - "but got", accountName) - } - - if accountKey != testData.accountKey { - t.Error( - "For", testData.connectionString, - "expected accountKey=", testData.accountKey, - "but got", accountKey) - } - } -} - -func TestGetQueueLength(t *testing.T) { - length, err := GetAzureQueueLength(context.TODO(), "", "", "queueName", "") - if length != -1 { - t.Error("Expected length to be -1, but got", length) - } - - if err == nil { - t.Error("Expected error for empty connection string, but got nil") - } - - if !strings.Contains(err.Error(), "parse storage connection string") { - t.Error("Expected error to contain parsing error message, but got", err.Error()) - } - - length, err = GetAzureQueueLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "") - - if length != -1 { - t.Error("Expected length to be -1, but got", length) - } - - if err == nil { - t.Error("Expected error for empty connection string, but got nil") - } - - if !strings.Contains(err.Error(), "illegal base64") { - t.Error("Expected error to contain base64 error message, but got", err.Error()) - } -} +import "testing" var testAzQueueResolvedEnv = map[string]string{ "CONNECTION": "SAMPLE", diff --git a/pkg/scalers/azure_servicebus_scaler.go b/pkg/scalers/azure_servicebus_scaler.go index ca079416e1e..e04a52768ed 100755 --- a/pkg/scalers/azure_servicebus_scaler.go +++ b/pkg/scalers/azure_servicebus_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "github.com/kedacore/keda/pkg/scalers/azure" "strconv" servicebus "github.com/Azure/azure-service-bus-go" @@ -173,7 +174,7 @@ type azureTokenProvider struct { // GetToken implements TokenProvider interface for azureTokenProvider func (azureTokenProvider) GetToken(uri string) (*auth.Token, error) { - token, err := getAzureADPodIdentityToken("https://servicebus.azure.net") + token, err := azure.GetAzureADPodIdentityToken("https://servicebus.azure.net") if err != nil { return nil, err } diff --git a/pkg/scalers/mysql_scaler_test.go b/pkg/scalers/mysql_scaler_test.go index 2efcb0f0f70..75fc46b494b 100644 --- a/pkg/scalers/mysql_scaler_test.go +++ b/pkg/scalers/mysql_scaler_test.go @@ -10,17 +10,17 @@ var testMySQLResolvedEnv = map[string]string{ } type parseMySQLMetadataTestData struct { - metdadata map[string] string + metdadata map[string]string raisesError bool } var testMySQLMetdata = 
diff --git a/pkg/scalers/mysql_scaler_test.go b/pkg/scalers/mysql_scaler_test.go
index 2efcb0f0f70..75fc46b494b 100644
--- a/pkg/scalers/mysql_scaler_test.go
+++ b/pkg/scalers/mysql_scaler_test.go
@@ -10,17 +10,17 @@ var testMySQLResolvedEnv = map[string]string{
 }
 
 type parseMySQLMetadataTestData struct {
-	metdadata map[string] string
+	metdadata map[string]string
 	raisesError bool
 }
 
 var testMySQLMetdata = []parseMySQLMetadataTestData{
 	// No metadata
-	{metdadata: map[string]string{}, raisesError:true},
+	{metdadata: map[string]string{}, raisesError: true},
 	// connectionString
-	{metdadata: map[string]string{"query": "query", "queryValue": "12", "connectionString": "test_value"}, raisesError:false},
+	{metdadata: map[string]string{"query": "query", "queryValue": "12", "connectionString": "test_value"}, raisesError: false},
 	// Params instead of conn str
-	{metdadata: map[string]string{"query": "query", "queryValue": "12", "host": "test_host", "port": "test_port", "username": "test_username", "password": "test_password", "dbName": "test_dbname"}, raisesError:false},
+	{metdadata: map[string]string{"query": "query", "queryValue": "12", "host": "test_host", "port": "test_port", "username": "test_username", "password": "test_password", "dbName": "test_dbname"}, raisesError: false},
 }
 
 func TestParseMySQLMetadata(t *testing.T) {
@@ -55,4 +55,3 @@ func TestMetadataToConnectionStrBuildNew(t *testing.T) {
 		t.Errorf("%s != %s", expected, connStr)
 	}
 }
-
diff --git a/pkg/scalers/prometheus.go b/pkg/scalers/prometheus.go
index ec0400a85ee..691098e4f1e 100644
--- a/pkg/scalers/prometheus.go
+++ b/pkg/scalers/prometheus.go
@@ -101,7 +101,7 @@ func (s *prometheusScaler) IsActive(ctx context.Context) (bool, error) {
 		prometheusLog.Error(err, "error executing prometheus query")
 		return false, err
 	}
-
+
 	return val > 0, nil
 }
 
diff --git a/pkg/scalers/stan_scaler.go b/pkg/scalers/stan_scaler.go
index ab9ea637c04..d8c983b9439 100644
--- a/pkg/scalers/stan_scaler.go
+++ b/pkg/scalers/stan_scaler.go
@@ -156,7 +156,7 @@ func (s *stanScaler) getMaxMsgLag() int64 {
 	return s.channelInfo.LastSequence - maxValue
 }
 
-func (s *stanScaler) hasPendingMessage() bool {
+func (s *stanScaler) hasPendingMessage() bool {
 	subscriberFound := false
 	combinedQueueName := s.metadata.durableName + ":" + s.metadata.queueGroup