From 06c1100bd39e022e9467110e57b981d88eea342e Mon Sep 17 00:00:00 2001 From: Yongming Ding Date: Thu, 27 Oct 2022 10:09:42 -0700 Subject: [PATCH 1/5] Add NetworkPolicy Recommendation on Snowflake Backend In this commit, we add the NetworkPolicy Recommendation Application on the Snowflake backend. NetworkPolicy Recommendation is implemented as Snowflake UDFs and could be running on Snowflake warehouses. Users could start it using command policy-recommendation in theia-sf CLI tool. NetworkPolicy Recommendation UDFs are written in Python and stored under udf/ directory. Signed-off-by: Yongming Ding --- .github/workflows/python.yml | 39 ++ snowflake/Makefile | 1 + snowflake/README.md | 27 +- snowflake/cmd/policyRecommendation.go | 285 ++++++++ snowflake/main.go | 13 +- snowflake/pkg/infra/constants.go | 5 + snowflake/pkg/infra/manager.go | 176 +++++ snowflake/pkg/snowflake/snowflake.go | 37 ++ snowflake/pkg/udfs/udfs.go | 26 + snowflake/pkg/utils/timestamps/timestamps.go | 52 ++ snowflake/pkg/utils/utils.go | 40 ++ snowflake/udf/Makefile | 13 + .../udf/policy_recommendation/__init__.py | 0 .../udf/policy_recommendation/antrea_crd.py | 625 ++++++++++++++++++ .../policy_recommendation/create_function.sql | 57 ++ .../policy_recommendation_udf.py | 339 ++++++++++ .../policy_recommendation_udf_test.py | 362 ++++++++++ .../policy_recommendation_utils.py | 48 ++ .../preprocessing_udf.py | 86 +++ .../preprocessing_udf_test.py | 104 +++ .../static_policy_recommendation_udf.py | 107 +++ .../static_policy_recommendation_udf_test.py | 123 ++++ .../udf/policy_recommendation/version.txt | 1 + 23 files changed, 2563 insertions(+), 3 deletions(-) create mode 100644 snowflake/cmd/policyRecommendation.go create mode 100644 snowflake/pkg/udfs/udfs.go create mode 100644 snowflake/pkg/utils/timestamps/timestamps.go create mode 100644 snowflake/pkg/utils/utils.go create mode 100644 snowflake/udf/Makefile create mode 100644 snowflake/udf/policy_recommendation/__init__.py create mode 100644 
snowflake/udf/policy_recommendation/antrea_crd.py create mode 100644 snowflake/udf/policy_recommendation/create_function.sql create mode 100644 snowflake/udf/policy_recommendation/policy_recommendation_udf.py create mode 100644 snowflake/udf/policy_recommendation/policy_recommendation_udf_test.py create mode 100644 snowflake/udf/policy_recommendation/policy_recommendation_utils.py create mode 100644 snowflake/udf/policy_recommendation/preprocessing_udf.py create mode 100644 snowflake/udf/policy_recommendation/preprocessing_udf_test.py create mode 100644 snowflake/udf/policy_recommendation/static_policy_recommendation_udf.py create mode 100644 snowflake/udf/policy_recommendation/static_policy_recommendation_udf_test.py create mode 100644 snowflake/udf/policy_recommendation/version.txt diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index d33ac2320..c8131b4e8 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -102,3 +102,42 @@ jobs: uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} + + check-udf-changes: + name: Check whether udf tests need to be run based on diff + runs-on: [ubuntu-latest] + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: antrea-io/has-changes@v2 + id: check_diff + with: + paths: snowflake/udf/* + outputs: + has_changes: ${{ steps.check_diff.outputs.has_changes }} + + test-udf: + needs: check-udf-changes + if: ${{ needs.check-udf-changes.outputs.has_changes == 'yes' }} + name: Udf test + strategy: + matrix: + os: [ubuntu-latest] + python-version: ["3.7"] + runs-on: ${{ matrix.os }} + steps: + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Check-out code + uses: actions/checkout@v3 + - name: Install dependencies + run: | + python -m pip install six psutil python-dateutil urllib3 requests pyyaml + wget 
https://downloads.antrea.io/artifacts/snowflake-udf/k8s-client-python-v24.2.0.zip + unzip k8s-client-python-v24.2.0.zip -d snowflake/udf/ + - name: Run udf tests + run: | + make -C snowflake/udf check diff --git a/snowflake/Makefile b/snowflake/Makefile index 4aa401731..4655e28a4 100644 --- a/snowflake/Makefile +++ b/snowflake/Makefile @@ -5,6 +5,7 @@ all: bin .PHONY: bin bin: + make -C udf/ $(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake .PHONY: .coverage diff --git a/snowflake/README.md b/snowflake/README.md index 36f1b7465..add8cdcb7 100644 --- a/snowflake/README.md +++ b/snowflake/README.md @@ -15,6 +15,7 @@ - [Configure the Flow Aggregator in your cluster(s)](#configure-the-flow-aggregator-in-your-clusters) - [Clean up](#clean-up) - [Running applications](#running-applications) + - [NetworkPolicy Recommendation](#networkpolicy-recommendation) - [Network flow visibility with Grafana](#network-flow-visibility-with-grafana) - [Configure datasource](#configure-datasource) - [Deployments](#deployments) @@ -139,8 +140,30 @@ Snowflake credentials are required. ## Running applications -We are in the process of adding support for applications to Snowflake-powered -Theia, starting with NetworkPolicy recommendation. +### NetworkPolicy Recommendation + +NetworkPolicy Recommendation recommends the NetworkPolicy configuration +to secure Kubernetes network and applications. It analyzes the network flows +stored in the Snowflake database to generate +[Kubernetes NetworkPolicies]( +https://kubernetes.io/docs/concepts/services-networking/network-policies/) +or [Antrea NetworkPolicies]( +https://github.com/antrea-io/antrea/blob/main/docs/antrea-network-policy.md). + +```bash +# make sure you have called onboard before running policy-recommendation +./bin/theia-sf policy-recommendation --database-name > recommended_policies.yml +``` + +Database name can be found in the output of the [onboard](#getting-started) +command. 
+ +NetworkPolicy Recommendation requires a Snowflake warehouse to execute and may +take seconds to minutes depending on the number of flows. We recommend using a +[Medium size warehouse](https://docs.snowflake.com/en/user-guide/warehouses-overview.html) +if you are working on a big dataset. If no warehouse is provided by the +`--warehouse-name` option, we will create a temporary X-Small size warehouse by +default. ## Network flow visibility with Grafana diff --git a/snowflake/cmd/policyRecommendation.go b/snowflake/cmd/policyRecommendation.go new file mode 100644 index 000000000..10ea8a80d --- /dev/null +++ b/snowflake/cmd/policyRecommendation.go @@ -0,0 +1,285 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/spf13/cobra" + + "antrea.io/theia/snowflake/pkg/infra" + "antrea.io/theia/snowflake/pkg/udfs" + "antrea.io/theia/snowflake/pkg/utils/timestamps" +) + +const ( + staticPolicyRecommendationFunctionName = "static_policy_recommendation" + preprocessingFunctionName = "preprocessing" + policyRecommendationFunctionName = "policy_recommendation" + defaultFunctionVersion = "v0.1.0" + defaultWaitTimeout = "10m" + // Limit the number of rows per partition to avoid hitting the 5 minutes end_partition() timeout. 
+ partitionSizeLimit = 30000 +) + +func buildPolicyRecommendationUdfQuery(jobType string, limit uint, isolationMethod int, start string, end string, startTs string, endTs string, nsAllowList string, labelIgnoreList string, clusterUUID string, databaseName string, functionVersion string) (string, error) { + now := time.Now() + recommendationID := uuid.New().String() + functionName := udfs.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion) + query := fmt.Sprintf(`SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM + TABLE(%s( + '%s', + '%s', + %d, + '%s' + ) over (partition by 1)) as r; +`, functionName, jobType, recommendationID, isolationMethod, nsAllowList) + + query += `WITH filtered_flows AS ( +SELECT + sourcePodNamespace, + sourcePodLabels, + destinationIP, + destinationPodNamespace, + destinationPodLabels, + destinationServicePortName, + destinationTransportPort, + protocolIdentifier, + flowType +FROM + flows +` + + query += `WHERE + ingressNetworkPolicyName IS NULL +AND + egressNetworkPolicyName IS NULL +` + + var startTime string + if startTs != "" { + startTime = startTs + } else if start != "" { + var err error + startTime, err = timestamps.ParseTimestamp(start, now) + if err != nil { + return "", err + } + } + if startTime != "" { + query += fmt.Sprintf(`AND + flowStartSeconds >= '%s' +`, startTime) + } + + var endTime string + if endTs != "" { + endTime = endTs + } else if end != "" { + var err error + endTime, err = timestamps.ParseTimestamp(end, now) + if err != nil { + return "", err + } + } + if endTime != "" { + query += fmt.Sprintf(`AND + flowEndSeconds <= '%s' +`, endTime) + } + + if clusterUUID != "" { + _, err := uuid.Parse(clusterUUID) + if err != nil { + return "", err + } + query += fmt.Sprintf(`AND + clusterUUID = '%s' +`, clusterUUID) + } else { + logger.Info("No clusterUUID input, all flows will be considered during policy recommendation.") + } + + query += `GROUP BY +sourcePodNamespace, 
+sourcePodLabels, +destinationIP, +destinationPodNamespace, +destinationPodLabels, +destinationServicePortName, +destinationTransportPort, +protocolIdentifier, +flowType + ` + + if limit > 0 { + query += fmt.Sprintf(` +LIMIT %d`, limit) + } else { + // limit the number unique flow records to 500k to avoid udf timeout + query += ` +LIMIT 500000` + } + + // Choose the destinationIP as the partition field for the preprocessing + // UDTF because flow rows could be divided into the most subsets + functionName = udfs.GetFunctionName(preprocessingFunctionName, functionVersion) + query += fmt.Sprintf(`), processed_flows AS (SELECT r.appliedTo, r.ingress, r.egress FROM filtered_flows AS f, +TABLE(%s( + '%s', + %d, + '%s', + '%s', + f.sourcePodNamespace, + f.sourcePodLabels, + f.destinationIP, + f.destinationPodNamespace, + f.destinationPodLabels, + f.destinationServicePortName, + f.destinationTransportPort, + f.protocolIdentifier, + f.flowType +) over (partition by f.destinationIP)) as r +`, functionName, jobType, isolationMethod, nsAllowList, labelIgnoreList) + + // Scan the row number for each appliedTo group and divide the partitions + // larger than partitionSizeLimit. + query += fmt.Sprintf(`), pf_with_index AS ( +SELECT + pf.appliedTo, + pf.ingress, + pf.egress, + floor((Row_number() over (partition by pf.appliedTo order by egress))/%d) as row_index +FROM processed_flows as pf +`, partitionSizeLimit) + + // Choose the appliedTo as the partition field for the policyRecommendation + // UDTF because each network policy is recommended based on all ingress and + // egress traffic related to an appliedTo group. 
+ functionName = udfs.GetFunctionName(policyRecommendationFunctionName, functionVersion) + query += fmt.Sprintf(`) SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM pf_with_index, +TABLE(%s( + '%s', + '%s', + %d, + '%s', + pf_with_index.appliedTo, + pf_with_index.ingress, + pf_with_index.egress +) over (partition by pf_with_index.appliedTo, pf_with_index.row_index)) as r +`, functionName, jobType, recommendationID, isolationMethod, nsAllowList) + + return query, nil +} + +// policyRecommendationCmd represents the policy-recommendation command +var policyRecommendationCmd = &cobra.Command{ + Use: "policy-recommendation", + Short: "Run the policy recommendation UDF in Snowflake", + Long: `This command runs the policy recommendation UDF in Snowflake. +You need to bring your own Snowflake account and created the policy +recommendation UDF using the create-udfs command first. + +Run policy recommendation with default configuration on database ANTREA_C9JR8KUKUIV4R72S: +"theia-sf policy-recommendation --database-name ANTREA_C9JR8KUKUIV4R72S" + +The "policy-recommendation" command requires a Snowflake warehouse to run policy +recommendation UDFs in Snowflake. By default, it will create a temporary one. +You can also bring your own by using the "--warehouse-name" parameter. 
+`, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + jobType, _ := cmd.Flags().GetString("type") + if jobType != "initial" && jobType != "subsequent" { + return fmt.Errorf("invalid --type argument") + } + limit, _ := cmd.Flags().GetUint("limit") + isolationMethod, _ := cmd.Flags().GetInt("isolationMethod") + if isolationMethod < 1 || isolationMethod > 3 { + return fmt.Errorf("invalid --isolationMethod argument") + } + start, _ := cmd.Flags().GetString("start") + end, _ := cmd.Flags().GetString("end") + startTs, _ := cmd.Flags().GetString("start-ts") + endTs, _ := cmd.Flags().GetString("end-ts") + nsAllowList, _ := cmd.Flags().GetString("ns-allow") + labelIgnoreList, _ := cmd.Flags().GetString("label-ignore") + clusterUUID, _ := cmd.Flags().GetString("cluster-uuid") + databaseName, _ := cmd.Flags().GetString("database-name") + warehouseName, _ := cmd.Flags().GetString("warehouse-name") + functionVersion, _ := cmd.Flags().GetString("udf-version") + waitTimeout, _ := cmd.Flags().GetString("wait-timeout") + waitDuration, err := time.ParseDuration(waitTimeout) + if err != nil { + return fmt.Errorf("invalid --wait-timeout argument, err when parsing it as a duration: %v", err) + } + verbose := verbosity >= 2 + query, err := buildPolicyRecommendationUdfQuery(jobType, limit, isolationMethod, start, end, startTs, endTs, nsAllowList, labelIgnoreList, clusterUUID, databaseName, functionVersion) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), waitDuration) + defer cancel() + // stackName, stateBackendURL, secretsProviderURL, region, workdir are not provided here + // because we only use snowflake client in this command. 
+ mgr := infra.NewManager(logger, "", "", "", "", warehouseName, "", verbose) + rows, err := mgr.RunUdf(ctx, query, databaseName) + if err != nil { + return fmt.Errorf("error when running policy recommendation UDF: %w", err) + } + defer rows.Close() + + var recommendationID string + var timeCreated string + var yamls string + for cont := true; cont; cont = rows.NextResultSet() { + for rows.Next() { + if err := rows.Scan(&jobType, &recommendationID, &timeCreated, &yamls); err != nil { + return fmt.Errorf("invalid row: %w", err) + } + fmt.Printf("%s---\n", yamls) + } + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(policyRecommendationCmd) + + policyRecommendationCmd.Flags().String("type", "initial", "Type of recommendation job (initial|subsequent), we only support initial jobType for now") + policyRecommendationCmd.Flags().Uint("limit", 0, "Limit on the number of flows to read, default it 0 (no limit)") + policyRecommendationCmd.Flags().Int("isolationMethod", 1, `Network isolation preference. 
Currently we have 3 options: +1: Recommending allow ANP/ACNP policies, with default deny rules only on Pods which have an allow rule applied +2: Recommending allow ANP/ACNP policies, with default deny rules for whole cluster +3: Recommending allow K8s NetworkPolicies only`) + policyRecommendationCmd.Flags().String("start", "", "Start time for flows, with reference to the current time (e.g., now-1h)") + policyRecommendationCmd.Flags().String("end", "", "End time for flows, with reference to the current time (e.g., now)") + policyRecommendationCmd.Flags().String("start-ts", "", "Start time for flows, as a RFC3339 UTC timestamp (e.g., 2022-07-01T19:35:31Z)") + policyRecommendationCmd.Flags().String("end-ts", "", "End time for flows, as a RFC3339 UTC timestamp (e.g., 2022-07-01T19:35:31Z)") + policyRecommendationCmd.Flags().String("ns-allow", "kube-system,flow-aggregator,flow-visibility", "Namespaces with no restrictions") + policyRecommendationCmd.Flags().String("label-ignore", "pod-template-hash,controller-revision-hash,pod-template-generation", "Pod labels to be ignored when recommending NetworkPolicy") + policyRecommendationCmd.Flags().String("cluster-uuid", "", `UUID of the cluster for which policy recommendations will be generated +If no UUID is provided, all flows will be considered during policy recommendation`) + policyRecommendationCmd.Flags().String("database-name", "", "Snowflake database name to run policy recommendation, it can be found in the output of the onboard command") + policyRecommendationCmd.MarkFlagRequired("database-name") + policyRecommendationCmd.Flags().String("warehouse-name", "", "Snowflake Virtual Warehouse to use for running policy recommendation, by default we will use a temporary one") + policyRecommendationCmd.Flags().String("udf-version", defaultFunctionVersion, "Version of the UDF function to use") + policyRecommendationCmd.Flags().String("wait-timeout", defaultWaitTimeout, "Wait timeout of the recommendation job (e.g., 5m, 100s)") 
+ +} diff --git a/snowflake/main.go b/snowflake/main.go index 337c4ced7..fc53a50ac 100644 --- a/snowflake/main.go +++ b/snowflake/main.go @@ -14,8 +14,19 @@ package main -import "antrea.io/theia/snowflake/cmd" +import ( + "embed" + + "antrea.io/theia/snowflake/cmd" + "antrea.io/theia/snowflake/pkg/infra" +) + +// Embed the udfs directory here because go:embed doesn't support embeding in subpackages + +//go:embed udf/* +var udfFs embed.FS func main() { + infra.UdfFs = udfFs cmd.Execute() } diff --git a/snowflake/pkg/infra/constants.go b/snowflake/pkg/infra/constants.go index 76038c577..d21c09f4f 100644 --- a/snowflake/pkg/infra/constants.go +++ b/snowflake/pkg/infra/constants.go @@ -54,4 +54,9 @@ const ( flowsTableName = "FLOWS" migrationsDir = "migrations" + + udfVersionPlaceholder = "%VERSION%" + udfCreateFunctionSQLFilename = "create_function.sql" + k8sPythonClientUrl = "https://downloads.antrea.io/artifacts/snowflake-udf/k8s-client-python-v24.2.0.zip" + k8sPythonClientFileName = "kubernetes.zip" ) diff --git a/snowflake/pkg/infra/manager.go b/snowflake/pkg/infra/manager.go index 2c6cbf33d..4022f40e8 100644 --- a/snowflake/pkg/infra/manager.go +++ b/snowflake/pkg/infra/manager.go @@ -19,6 +19,8 @@ import ( "compress/gzip" "context" "database/sql" + "embed" + "errors" "fmt" "io" "io/fs" @@ -26,6 +28,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "github.com/go-logr/logr" "github.com/pulumi/pulumi/sdk/v3/go/auto" @@ -37,8 +40,11 @@ import ( "antrea.io/theia/snowflake/database" sf "antrea.io/theia/snowflake/pkg/snowflake" + utils "antrea.io/theia/snowflake/pkg/utils" ) +var UdfFs embed.FS + type pulumiPlugin struct { name string version string @@ -214,6 +220,14 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str return nil } +func readVersionFromFile(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(b)), nil +} + type Manager struct { logger 
logr.Logger stackName string @@ -468,6 +482,11 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { return nil, err } + err = createUdfs(ctx, logger, outs["databaseName"], warehouseName) + if err != nil { + return nil, err + } + return &Result{ Region: m.region, BucketName: outs["bucketID"], @@ -488,3 +507,160 @@ func (m *Manager) Offboard(ctx context.Context) error { _, err := m.run(ctx, true) return err } + +func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, warehouseName string) error { + logger.Info("creating UDFs") + dsn, _, err := sf.GetDSN() + if err != nil { + return fmt.Errorf("failed to create DSN: %w", err) + } + + db, err := sql.Open("snowflake", dsn) + if err != nil { + return fmt.Errorf("failed to connect to Snowflake: %w", err) + } + defer db.Close() + + sfClient := sf.NewClient(db, logger) + + if err := sfClient.UseDatabase(ctx, databaseName); err != nil { + return err + } + + if err := sfClient.UseSchema(ctx, schemaName); err != nil { + return err + } + + if err := sfClient.UseWarehouse(ctx, warehouseName); err != nil { + return err + } + + // Download and stage Kubernetes python client for policy recommendation udf + err = utils.DownloadFile(k8sPythonClientUrl, k8sPythonClientFileName) + if err != nil { + return err + } + k8sPythonClientFilePath, _ := filepath.Abs(k8sPythonClientFileName) + err = sfClient.StageFile(ctx, k8sPythonClientFilePath, udfStageName) + if err != nil { + return err + } + defer func() { + err = os.Remove(k8sPythonClientFilePath) + if err != nil { + logger.Error(err, "Failed to delete Kubernetes python client zip file, please do it manually", "filepath", k8sPythonClientFilePath) + } + }() + + if err := fs.WalkDir(UdfFs, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if filepath.Ext(path) != ".zip" { + return nil + } + logger.Info("staging", "path", path) + directoryPath := path[:len(path)-4] + functionVersionPath := 
filepath.Join(directoryPath, "version.txt") + var version string + if _, err := os.Stat(functionVersionPath); errors.Is(err, os.ErrNotExist) { + logger.Info("did not find version.txt file for function") + version = "" + } else { + version, err = readVersionFromFile(functionVersionPath) + if err != nil { + return err + } + } + version = strings.ReplaceAll(version, ".", "_") + version = strings.ReplaceAll(version, "-", "_") + absPath, _ := filepath.Abs(path) + var pathWithVersion string + if version != "" { + pathWithVersion = fmt.Sprintf("%s_%s.zip", absPath[:len(absPath)-4], version) + } else { + // Don't add a version suffix if there is no version information + pathWithVersion = absPath + } + err = os.Rename(absPath, pathWithVersion) + if err != nil { + return err + } + err = sfClient.StageFile(ctx, pathWithVersion, udfStageName) + if err != nil { + return err + } + createFunctionSQLPath := filepath.Join(directoryPath, udfCreateFunctionSQLFilename) + if _, err := fs.Stat(UdfFs, createFunctionSQLPath); errors.Is(err, os.ErrNotExist) { + logger.Info("did not find SQL file to create function, skipping") + return nil + } + logger.Info("creating UDF", "from", createFunctionSQLPath, "version", version) + b, err := fs.ReadFile(UdfFs, createFunctionSQLPath) + if err != nil { + return err + } + query := string(b) + if !strings.Contains(query, udfVersionPlaceholder) { + return fmt.Errorf("version placeholder '%s' not found in SQL file", udfVersionPlaceholder) + } + query = strings.ReplaceAll(query, udfVersionPlaceholder, version) + _, err = sfClient.ExecMultiStatementQuery(ctx, query, false) + if err != nil { + return fmt.Errorf("error when creating UDF: %w", err) + } + return nil + }); err != nil { + return fmt.Errorf("creating failed: %w", err) + } + return nil +} + +func (m *Manager) RunUdf(ctx context.Context, query string, databaseName string) (*sql.Rows, error) { + logger := m.logger + logger.Info("Running UDF") + dsn, _, err := sf.GetDSN() + if err != nil { + return 
nil, fmt.Errorf("failed to create DSN: %w", err) + } + + db, err := sql.Open("snowflake", dsn) + if err != nil { + return nil, fmt.Errorf("failed to connect to Snowflake: %w", err) + } + defer db.Close() + + sfClient := sf.NewClient(db, logger) + + if err := sfClient.UseDatabase(ctx, databaseName); err != nil { + return nil, err + } + + if err := sfClient.UseSchema(ctx, schemaName); err != nil { + return nil, err + } + + warehouseName := m.warehouseName + if warehouseName == "" { + temporaryWarehouse := newTemporaryWarehouse(sfClient, logger) + warehouseName = temporaryWarehouse.Name() + if err := temporaryWarehouse.Create(ctx); err != nil { + return nil, err + } + defer func() { + if err := temporaryWarehouse.Delete(ctx); err != nil { + logger.Error(err, "Failed to delete temporary warehouse, please do it manually", "name", warehouseName) + } + }() + } + + if err := sfClient.UseWarehouse(ctx, warehouseName); err != nil { + return nil, err + } + + rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true) + if err != nil { + return nil, fmt.Errorf("error when running UDF: %w", err) + } + return rows, nil +} diff --git a/snowflake/pkg/snowflake/snowflake.go b/snowflake/pkg/snowflake/snowflake.go index 014c74c2b..bbc31c214 100644 --- a/snowflake/pkg/snowflake/snowflake.go +++ b/snowflake/pkg/snowflake/snowflake.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/go-logr/logr" + "github.com/snowflakedb/gosnowflake" ) type WarehouseSizeType string @@ -45,6 +46,9 @@ type Client interface { CreateWarehouse(ctx context.Context, name string, config WarehouseConfig) error UseWarehouse(ctx context.Context, name string) error DropWarehouse(ctx context.Context, name string) error + UseDatabase(ctx context.Context, name string) error + UseSchema(ctx context.Context, name string) error + StageFile(ctx context.Context, path string, stage string) error } type client struct { @@ -101,3 +105,36 @@ func (c *client) DropWarehouse(ctx context.Context, name string) error { _, err 
:= c.db.ExecContext(ctx, query) return err } + +func (c *client) UseDatabase(ctx context.Context, name string) error { + query := fmt.Sprintf("USE DATABASE %s", name) + c.logger.V(2).Info("Snowflake query", "query", query) + _, err := c.db.ExecContext(ctx, query) + return err +} + +func (c *client) UseSchema(ctx context.Context, name string) error { + query := fmt.Sprintf("USE SCHEMA %s", name) + c.logger.Info("Snowflake query", "query", query) + _, err := c.db.ExecContext(ctx, query) + return err +} + +func (c *client) StageFile(ctx context.Context, path string, stage string) error { + query := fmt.Sprintf("PUT file://%s @%s AUTO_COMPRESS = FALSE OVERWRITE = TRUE", path, stage) + c.logger.Info("Snowflake query", "query", query) + _, err := c.db.ExecContext(ctx, query) + return err +} + +func (c *client) ExecMultiStatementQuery(ctx context.Context, query string, result bool) (*sql.Rows, error) { + multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0) + c.logger.Info("Snowflake query", "query", query) + if !result { + _, err := c.db.ExecContext(multi_statement_context, query) + return nil, err + } else { + rows, err := c.db.QueryContext(multi_statement_context, query) + return rows, err + } +} diff --git a/snowflake/pkg/udfs/udfs.go b/snowflake/pkg/udfs/udfs.go new file mode 100644 index 000000000..b644b2663 --- /dev/null +++ b/snowflake/pkg/udfs/udfs.go @@ -0,0 +1,26 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package udfs + +import ( + "fmt" + "strings" +) + +func GetFunctionName(baseName string, version string) string { + version = strings.ReplaceAll(version, ".", "_") + version = strings.ReplaceAll(version, "-", "_") + return fmt.Sprintf("%s_%s", baseName, version) +} diff --git a/snowflake/pkg/utils/timestamps/timestamps.go b/snowflake/pkg/utils/timestamps/timestamps.go new file mode 100644 index 000000000..1ebae5f1b --- /dev/null +++ b/snowflake/pkg/utils/timestamps/timestamps.go @@ -0,0 +1,52 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package timestamps + +import ( + "fmt" + "strings" + "time" +) + +func ParseTimestamp(t string, now time.Time, defaultT ...time.Time) (string, error) { + defaultTimestamp := now + if len(defaultT) > 0 { + defaultTimestamp = defaultT[0] + } + ts, err := func() (time.Time, error) { + fields := strings.Split(t, "-") + if len(fields) == 0 { + return defaultTimestamp, nil + } + if len(fields) > 1 && fields[0] != "now" { + return defaultTimestamp, fmt.Errorf("bad timestamp: %s", t) + } + if len(fields) == 1 { + return now, nil + } + if len(fields) == 2 { + d, err := time.ParseDuration(fields[1]) + if err != nil { + return defaultTimestamp, fmt.Errorf("bad timestamp: %s", t) + } + return now.Add(-d), nil + } + return defaultTimestamp, fmt.Errorf("bad timestamp: %s", t) + }() + if err != nil { + return "", err + } + return ts.UTC().Format(time.RFC3339), nil +} diff --git a/snowflake/pkg/utils/utils.go b/snowflake/pkg/utils/utils.go new file mode 100644 index 000000000..15a74ebda --- /dev/null +++ b/snowflake/pkg/utils/utils.go @@ -0,0 +1,40 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package utils + +import ( + "io" + "net/http" + "os" +) + +// Download a file from the given url to the current directory +func DownloadFile(url string, filename string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil + } + file, err := os.Create(filename) + if err != nil { + return err + } + defer file.Close() + _, err = io.Copy(file, resp.Body) + return err +} diff --git a/snowflake/udf/Makefile b/snowflake/udf/Makefile new file mode 100644 index 000000000..8d5a80907 --- /dev/null +++ b/snowflake/udf/Makefile @@ -0,0 +1,13 @@ +.PHONY: all +all: policy_recommendation.zip + +policy_recommendation.zip: policy_recommendation/*.py + @zip $@ $^ + +.PHONY: clean +clean: + rm -f *.zip + +.PHONY: check +check: + PYTHONPATH="${PYTHONPATH}:$(CURDIR)" python3 -m unittest discover policy_recommendation "*_test.py" diff --git a/snowflake/udf/policy_recommendation/__init__.py b/snowflake/udf/policy_recommendation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/snowflake/udf/policy_recommendation/antrea_crd.py b/snowflake/udf/policy_recommendation/antrea_crd.py new file mode 100644 index 000000000..6a305510a --- /dev/null +++ b/snowflake/udf/policy_recommendation/antrea_crd.py @@ -0,0 +1,625 @@ +# Copyright 2022 Antrea Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This library is used to define Antrea Network Policy related CRDs in Python. 
# Code structure is following the Kubernetes Python Client library (https://github.com/kubernetes-client/python).
# This file could be changed to auto-generated by using openAPI generator in the future like the K8s python lib.


def _serialize_value(value):
    # Recursively convert a model attribute to plain Python data:
    # nested models become dicts (via their to_dict), lists and dicts are
    # serialized element-wise, everything else (including None) is returned
    # unchanged.
    if isinstance(value, list):
        return [item.to_dict() if hasattr(item, "to_dict") else item
                for item in value]
    if hasattr(value, "to_dict"):
        return value.to_dict()
    if isinstance(value, dict):
        return {key: (val.to_dict() if hasattr(val, "to_dict") else val)
                for key, val in value.items()}
    return value


class _CRDModel(object):
    """Base class for all Antrea CRD models.

    Subclasses declare their fields in the ``attribute_types`` class
    attribute; ``to_dict`` serializes exactly those fields. This replaces
    the identical ``to_dict`` implementation that was previously
    copy-pasted into every model class.
    """

    # Mapping of attribute name -> type description (documentation only).
    attribute_types = {}

    def to_dict(self):
        """Returns the model properties as a dict"""
        return {attr: _serialize_value(getattr(self, attr))
                for attr in self.attribute_types}


class NetworkPolicy(_CRDModel):
    """Antrea namespaced NetworkPolicy CRD."""
    attribute_types = {
        "kind": "string",
        "api_version": "string",
        "metadata": "kubernetes.client.V1ObjectMeta",
        "spec": "NetworkPolicySpec",
        "status": "NetworkPolicyStatus"
    }

    def __init__(self, kind=None, api_version=None, metadata=None, spec=None, status=None):
        self.kind = kind
        self.api_version = api_version
        self.metadata = metadata
        self.spec = spec
        self.status = status


class NetworkPolicySpec(_CRDModel):
    """Spec of an Antrea namespaced NetworkPolicy."""
    attribute_types = {
        "tier": "string",
        "priority": "float",
        "applied_to": "list[NetworkPolicyPeer]",
        "ingress": "list[Rule]",
        "egress": "list[Rule]"
    }

    def __init__(self, tier=None, priority=None, applied_to=None, ingress=None, egress=None):
        self.tier = tier
        self.priority = priority
        self.applied_to = applied_to
        self.ingress = ingress
        self.egress = egress


class NetworkPolicyPeer(_CRDModel):
    """A peer (source or destination) referenced by a policy rule."""
    attribute_types = {
        "ip_block": "IPBlock",
        "pod_selector": "kubernetes.client.V1LabelSelector",
        "namespace_selector": "kubernetes.client.V1LabelSelector",
        "namespaces": "PeerNamespaces",
        "external_entity_selector": "kubernetes.client.V1LabelSelector",
        "group": "string",
        "FQDN": "string"
    }

    def __init__(self, ip_block=None, pod_selector=None, namespace_selector=None, namespaces=None, external_entity_selector=None, group=None, FQDN=None):
        self.ip_block = ip_block
        self.pod_selector = pod_selector
        self.namespace_selector = namespace_selector
        self.namespaces = namespaces
        self.external_entity_selector = external_entity_selector
        self.group = group
        self.FQDN = FQDN


class IPBlock(_CRDModel):
    """A CIDR-based peer."""
    attribute_types = {
        "CIDR": "string"
    }

    def __init__(self, CIDR=None):
        self.CIDR = CIDR


class PeerNamespaces(_CRDModel):
    """Namespace-matching strategy for a peer."""
    attribute_types = {
        "Match": "string"
    }

    def __init__(self, Match=None):
        self.Match = Match


class Rule(_CRDModel):
    """A single ingress or egress rule of an Antrea policy."""
    attribute_types = {
        "action": "string",
        "ports": "list[NetworkPolicyPort]",
        "_from": "list[NetworkPolicyPeer]",
        "to": "list[NetworkPolicyPeer]",
        "to_services": "list[NamespacedName]",
        "name": "string",
        "enable_logging": "bool",
        "applied_to": "list[NetworkPolicyPeer]"
    }

    def __init__(self, action=None, ports=None, _from=None, to=None, to_services=None, name=None, enable_logging=None, applied_to=None):
        self.action = action
        self.ports = ports
        # "_from" because "from" is a Python keyword.
        self._from = _from
        self.to = to
        self.to_services = to_services
        self.name = name
        self.enable_logging = enable_logging
        self.applied_to = applied_to


class NetworkPolicyPort(_CRDModel):
    """A port (or port range) matched by a rule."""
    attribute_types = {
        "protocol": "string",
        "port": "int or string",
        "endport": "int",
    }

    def __init__(self, protocol=None, port=None, endport=None):
        self.protocol = protocol
        self.port = port
        self.endport = endport


class ClusterGroup(_CRDModel):
    """Antrea ClusterGroup CRD."""
    attribute_types = {
        "kind": "string",
        "api_version": "string",
        "metadata": "kubernetes.client.V1ObjectMeta",
        "spec": "GroupSpec",
        "status": "GroupStatus"
    }

    def __init__(self, kind=None, api_version=None, metadata=None, spec=None, status=None):
        self.kind = kind
        self.api_version = api_version
        self.metadata = metadata
        self.spec = spec
        self.status = status


class GroupSpec(_CRDModel):
    """Spec of a ClusterGroup."""
    attribute_types = {
        "pod_selector": "kubernetes.client.V1LabelSelector",
        "namespace_selector": "kubernetes.client.V1LabelSelector",
        "ip_blocks": "list[IPBlock]",
        "service_reference": "ServiceReference",
        "external_entity_selector": "kubernetes.client.V1LabelSelector",
        "child_groups": "list[string]"
    }

    def __init__(self, pod_selector=None, namespace_selector=None, ip_blocks=None, service_reference=None, external_entity_selector=None, child_groups=None):
        self.pod_selector = pod_selector
        self.namespace_selector = namespace_selector
        self.ip_blocks = ip_blocks
        self.service_reference = service_reference
        self.external_entity_selector = external_entity_selector
        self.child_groups = child_groups


class ServiceReference(_CRDModel):
    """Reference to a Service by name and namespace."""
    attribute_types = {
        "name": "string",
        "namespace": "string"
    }

    def __init__(self, name=None, namespace=None):
        self.name = name
        self.namespace = namespace


class GroupStatus(_CRDModel):
    """Status of a ClusterGroup."""
    attribute_types = {
        "conditions": "list[GroupCondition]"
    }

    def __init__(self, conditions=None):
        self.conditions = conditions


class GroupCondition(_CRDModel):
    """One condition entry in a GroupStatus."""
    attribute_types = {
        "type": "string",
        "status": "string",
        "last_transition_time": "datetime",
    }

    def __init__(self, type=None, status=None, last_transition_time=None):
        self.type = type
        self.status = status
        self.last_transition_time = last_transition_time


class ClusterNetworkPolicy(_CRDModel):
    """Antrea cluster-scoped NetworkPolicy CRD."""
    attribute_types = {
        "kind": "string",
        "api_version": "string",
        "metadata": "kubernetes.client.V1ObjectMeta",
        "spec": "ClusterNetworkPolicySpec",
        "status": "NetworkPolicyStatus"
    }

    def __init__(self, kind=None, api_version=None, metadata=None, spec=None, status=None):
        self.kind = kind
        self.api_version = api_version
        self.metadata = metadata
        self.spec = spec
        self.status = status


class ClusterNetworkPolicySpec(_CRDModel):
    """Spec of an Antrea ClusterNetworkPolicy."""
    attribute_types = {
        "tier": "string",
        "priority": "float",
        "applied_to": "list[NetworkPolicyPeer]",
        "ingress": "list[Rule]",
        "egress": "list[Rule]"
    }

    def __init__(self, tier=None, priority=None, applied_to=None, ingress=None, egress=None):
        self.tier = tier
        self.priority = priority
        self.applied_to = applied_to
        self.ingress = ingress
        self.egress = egress


class NetworkPolicyStatus(_CRDModel):
    """Realization status of an Antrea policy."""
    attribute_types = {
        "phase": "string",
        "observed_generation": "int",
        "current_nodes_realized": "int",
        "desired_nodes_realized": "int"
    }

    def __init__(self, phase=None, observed_generation=None, current_nodes_realized=None, desired_nodes_realized=None):
        self.phase = phase
        self.observed_generation = observed_generation
        self.current_nodes_realized = current_nodes_realized
        self.desired_nodes_realized = desired_nodes_realized


class NamespacedName(_CRDModel):
    """A (namespace, name) pair identifying a namespaced object."""
    attribute_types = {
        "name": "string",
        "namespace": "string"
    }

    def __init__(self, name=None, namespace=None):
        self.name = name
        self.namespace = namespace
    nsAllowList STRING(10000),
    labelIgnoreList STRING(10000),
    sourcePodNamespace STRING(256),
    sourcePodLabels STRING(10000),
    destinationIP STRING(50),
    destinationPodNamespace STRING(256),
    destinationPodLabels STRING(10000),
    destinationServicePortName STRING(256),
    destinationTransportPort NUMBER(5, 0),
    protocolIdentifier NUMBER(3, 0),
    flowType NUMBER(3, 0)
)
returns table ( appliedTo STRING,
                ingress STRING,
                egress STRING )
language python
runtime_version=3.8
imports=('@UDFS/policy_recommendation_%VERSION%.zip')
handler='policy_recommendation/preprocessing_udf.PreProcessing';

-- Aggregates the preprocessed (appliedTo, ingress, egress) rows and emits
-- recommended policy YAMLs (handler: policy_recommendation_udf.PolicyRecommendation).
-- NOTE(review): the %VERSION% placeholder appears to be substituted by the
-- deployment tooling before this script is executed -- confirm.
create or replace function policy_recommendation_%VERSION%(
    jobType STRING(20),
    recommendationId STRING(40),
    isolationMethod NUMBER(1, 0),
    nsAllowList STRING(10000),
    appliedTo STRING,
    ingress STRING,
    egress STRING
)
returns table ( jobType STRING(20),
                recommendationId STRING(40),
                timeCreated TIMESTAMP_NTZ,
                yamls STRING )
language python
runtime_version=3.8
packages = ('six', 'python-dateutil', 'urllib3', 'requests', 'pyyaml')
imports=('@UDFS/policy_recommendation_%VERSION%.zip', '@UDFS/kubernetes.zip')
handler='policy_recommendation/policy_recommendation_udf.PolicyRecommendation';

-- Emits cluster-wide static policies that do not depend on flow rows
-- (handler: static_policy_recommendation_udf.StaticPolicyRecommendation).
create or replace function static_policy_recommendation_%VERSION%(
    jobType STRING(20),
    recommendationId STRING(40),
    isolationMethod NUMBER(1, 0),
    nsAllowList STRING(10000)
)
returns table ( jobType STRING(20),
                recommendationId STRING(40),
                timeCreated TIMESTAMP_NTZ,
                yamls STRING )
language python
runtime_version=3.8
packages = ('six', 'python-dateutil', 'urllib3', 'requests', 'pyyaml')
imports=('@UDFS/policy_recommendation_%VERSION%.zip', '@UDFS/kubernetes.zip')
handler='policy_recommendation/static_policy_recommendation_udf.StaticPolicyRecommendation';
diff --git a/snowflake/udf/policy_recommendation/policy_recommendation_udf.py b/snowflake/udf/policy_recommendation/policy_recommendation_udf.py
new file mode 100644
index
import datetime
import json
import random
import string
import uuid
import sys

import kubernetes.client

import policy_recommendation.antrea_crd as antrea_crd
from policy_recommendation.policy_recommendation_utils import *
from policy_recommendation.preprocessing_udf import ROW_DELIMITER

# Priority assigned to every generated Antrea (Cluster)NetworkPolicy.
DEFAULT_POLICY_PRIORITY = 5

def generate_policy_name(info):
    """Append a random 5-character lowercase/digit suffix to info to build a
    (probabilistically) unique policy name."""
    return "-".join([info, "".join(random.sample(string.ascii_lowercase + string.digits, 5))])

def generate_k8s_egress_rule(egress):
    """Build a K8s V1NetworkPolicyEgressRule from a ROW_DELIMITER-joined tuple.

    A 4-field tuple (ns, labels, port, protocol) describes a Pod-to-Pod flow;
    a 3-field tuple (destinationIP, port, protocol) describes a
    Pod-to-External flow. Raises ValueError for any other shape.
    """
    if len(egress.split(ROW_DELIMITER)) == 4:
        ns, labels, port, protocolIdentifier = egress.split(ROW_DELIMITER)
        egress_peer = kubernetes.client.V1NetworkPolicyPeer(
            namespace_selector = kubernetes.client.V1LabelSelector(
                match_labels = {
                    "name":ns
                }
            ),
            pod_selector = kubernetes.client.V1LabelSelector(
                match_labels = json.loads(labels)
            ),
        )
    elif len(egress.split(ROW_DELIMITER)) == 3:
        destinationIP, port, protocolIdentifier = egress.split(ROW_DELIMITER)
        # Build a single-address CIDR for the external destination.
        if get_IP_version(destinationIP) == "v4":
            cidr = destinationIP + "/32"
        else:
            cidr = destinationIP + "/128"
        egress_peer = kubernetes.client.V1NetworkPolicyPeer(
            ip_block = kubernetes.client.V1IPBlock(
                cidr = cidr,
            )
        )
    else:
        # Previously sys.exit(1): terminating the interpreter inside a
        # Snowflake UDTF aborts the whole query with no diagnostic. Raise
        # a descriptive error instead.
        raise ValueError("invalid egress tuple for K8s NetworkPolicy: %s" % egress)
    ports = kubernetes.client.V1NetworkPolicyPort(
        port = int(port),
        protocol = protocolIdentifier
    )
    egress_rule = kubernetes.client.V1NetworkPolicyEgressRule(
        to = [egress_peer],
        ports = [ports]
    )
    return egress_rule

def generate_k8s_ingress_rule(ingress):
    """Build a K8s V1NetworkPolicyIngressRule from a 4-field
    (ns, labels, port, protocol) ROW_DELIMITER-joined tuple.

    Raises ValueError for any other shape.
    """
    if len(ingress.split(ROW_DELIMITER)) != 4:
        # See generate_k8s_egress_rule: raise instead of sys.exit(1).
        raise ValueError("invalid ingress tuple for K8s NetworkPolicy: %s" % ingress)
    ns, labels, port, protocolIdentifier = ingress.split(ROW_DELIMITER)
    ingress_peer = kubernetes.client.V1NetworkPolicyPeer(
        namespace_selector = kubernetes.client.V1LabelSelector(
            match_labels = {
                "name":ns
            }
        ),
        pod_selector = kubernetes.client.V1LabelSelector(
            match_labels = json.loads(labels)
        ),
    )
    ports = kubernetes.client.V1NetworkPolicyPort(
        port = int(port),
        protocol = protocolIdentifier
    )
    ingress_rule = kubernetes.client.V1NetworkPolicyIngressRule(
        _from = [ingress_peer],
        ports = [ports]
    )
    return ingress_rule

def generate_k8s_np(applied_to, ingresses, egresses, ns_allow_list):
    """Generate a K8s NetworkPolicy YAML for the workload in applied_to.

    Returns "" when the workload's namespace is in ns_allow_list or when no
    ingress/egress rule can be derived.
    """
    ns, labels = applied_to.split(ROW_DELIMITER)
    if ns in ns_allow_list:
        return ""
    # Sort for a deterministic rule order in the generated YAML.
    ingress_list = sorted(list(ingresses))
    egress_list = sorted(list(egresses))
    egressRules = []
    for egress in egress_list:
        # Entries without ROW_DELIMITER are placeholders for "no flow".
        if ROW_DELIMITER in egress:
            egressRules.append(generate_k8s_egress_rule(egress))
    ingressRules = []
    for ingress in ingress_list:
        if ROW_DELIMITER in ingress:
            ingressRules.append(generate_k8s_ingress_rule(ingress))
    if egressRules or ingressRules:
        policy_types = []
        if egressRules:
            policy_types.append("Egress")
        if ingressRules:
            policy_types.append("Ingress")
        np_name = generate_policy_name("recommend-k8s-np")
        np = kubernetes.client.V1NetworkPolicy(
            api_version = "networking.k8s.io/v1",
            kind = "NetworkPolicy",
            metadata = kubernetes.client.V1ObjectMeta(
                name = np_name,
                namespace = ns
            ),
            spec = kubernetes.client.V1NetworkPolicySpec(
                egress = egressRules,
                ingress = ingressRules,
                pod_selector = kubernetes.client.V1LabelSelector(
                    match_labels = json.loads(labels)
                ),
                policy_types = policy_types
            )
        )
        return dict_to_yaml(np.to_dict())
    else:
        return ""

def generate_anp_egress_rule(egress):
    """Build an Antrea Rule (action Allow) from a ROW_DELIMITER-joined tuple.

    4 fields: Pod-to-Pod flow; 3 fields: Pod-to-External flow; 2 fields
    (svc_ns, svc_name): Pod-to-Service flow. Raises ValueError otherwise.
    """
    if len(egress.split(ROW_DELIMITER)) == 4:
        # Pod-to-Pod flow
        ns, labels, port, protocolIdentifier = egress.split(ROW_DELIMITER)
        egress_peer = antrea_crd.NetworkPolicyPeer(
            namespace_selector = kubernetes.client.V1LabelSelector(
                match_labels = {
                    "kubernetes.io/metadata.name":ns
                }
            ),
            pod_selector = kubernetes.client.V1LabelSelector(
                match_labels = json.loads(labels)
            ),
        )
        ports = antrea_crd.NetworkPolicyPort(
            protocol = protocolIdentifier,
            port = int(port)
        )
        egress_rule = antrea_crd.Rule(
            action = "Allow",
            to = [egress_peer],
            ports = [ports]
        )
    elif len(egress.split(ROW_DELIMITER)) == 3:
        # Pod-to-External flow
        destinationIP, port, protocolIdentifier = egress.split(ROW_DELIMITER)
        if get_IP_version(destinationIP) == "v4":
            cidr = destinationIP + "/32"
        else:
            cidr = destinationIP + "/128"
        egress_peer = antrea_crd.NetworkPolicyPeer(
            ip_block = antrea_crd.IPBlock(
                CIDR = cidr,
            )
        )
        ports = antrea_crd.NetworkPolicyPort(
            protocol = protocolIdentifier,
            port = int(port)
        )
        egress_rule = antrea_crd.Rule(
            action = "Allow",
            to = [egress_peer],
            ports = [ports]
        )
    elif len(egress.split(ROW_DELIMITER)) == 2:
        # Pod-to-Svc flow
        svc_ns, svc_name = egress.split(ROW_DELIMITER)
        egress_rule = antrea_crd.Rule(
            action = "Allow",
            to_services = [
                antrea_crd.NamespacedName(
                    namespace = svc_ns,
                    name = svc_name
                )
            ]
        )
    else:
        # Previously sys.exit(1); raise a descriptive error instead.
        raise ValueError("invalid egress tuple for Antrea NetworkPolicy: %s" % egress)
    return egress_rule

def generate_anp_ingress_rule(ingress):
    """Build an Antrea ingress Rule (action Allow) from a 4-field
    (ns, labels, port, protocol) ROW_DELIMITER-joined tuple.

    Raises ValueError for any other shape.
    """
    if len(ingress.split(ROW_DELIMITER)) != 4:
        # Previously sys.exit(1); raise a descriptive error instead.
        raise ValueError("invalid ingress tuple for Antrea NetworkPolicy: %s" % ingress)
    ns, labels, port, protocolIdentifier = ingress.split(ROW_DELIMITER)
    ingress_peer = antrea_crd.NetworkPolicyPeer(
        namespace_selector = kubernetes.client.V1LabelSelector(
            match_labels = {
                "kubernetes.io/metadata.name":ns
            }
        ),
        pod_selector = kubernetes.client.V1LabelSelector(
            match_labels = json.loads(labels)
        ),
    )
    ports = antrea_crd.NetworkPolicyPort(
        protocol = protocolIdentifier,
        port = int(port)
    )
    ingress_rule = antrea_crd.Rule(
        action = "Allow",
        _from = [ingress_peer],
        ports = [ports]
    )
    return ingress_rule

def generate_anp(applied_to, ingresses, egresses, ns_allow_list):
    """Generate an Antrea NetworkPolicy YAML (Application tier, Allow rules)
    for the workload in applied_to.

    Returns "" when the namespace is allow-listed or no rule can be derived.
    """
    ns, labels = applied_to.split(ROW_DELIMITER)
    if ns in ns_allow_list:
        return ""
    ingress_list = sorted(list(ingresses))
    egress_list = sorted(list(egresses))
    egressRules = []
    for egress in egress_list:
        if ROW_DELIMITER in egress:
            egress_rule = generate_anp_egress_rule(egress)
            if egress_rule:
                egressRules.append(egress_rule)
    ingressRules = []
    for ingress in ingress_list:
        if ROW_DELIMITER in ingress:
            ingress_rule = generate_anp_ingress_rule(ingress)
            if ingress_rule:
                ingressRules.append(ingress_rule)
    if egressRules or ingressRules:
        np_name = generate_policy_name("recommend-allow-anp")
        np = antrea_crd.NetworkPolicy(
            api_version = "crd.antrea.io/v1alpha1",
            kind = "NetworkPolicy",
            metadata = kubernetes.client.V1ObjectMeta(
                name = np_name,
                namespace = ns,
            ),
            spec = antrea_crd.NetworkPolicySpec(
                tier = "Application",
                priority = DEFAULT_POLICY_PRIORITY,
                applied_to = [antrea_crd.NetworkPolicyPeer(
                    pod_selector = kubernetes.client.V1LabelSelector(
                        match_labels = json.loads(labels)
                    ),
                )],
                egress = egressRules,
                ingress = ingressRules,
            )
        )
        return dict_to_yaml(np.to_dict())
    else:
        return ""

def generate_reject_acnp(applied_to, ns_allow_list):
    """Generate a Baseline-tier Antrea ClusterNetworkPolicy YAML that rejects
    all other Pod-to-Pod traffic for the workload in applied_to.

    Returns "" when the namespace is allow-listed.
    """
    ns, labels = applied_to.split(ROW_DELIMITER)
    if ns in ns_allow_list:
        return ""
    np_name = generate_policy_name("recommend-reject-acnp")
    applied_to = antrea_crd.NetworkPolicyPeer(
        pod_selector = kubernetes.client.V1LabelSelector(
            match_labels = json.loads(labels)
        ),
        namespace_selector = kubernetes.client.V1LabelSelector(
            match_labels = {
                "kubernetes.io/metadata.name":ns
            }
        )
    )
    np = antrea_crd.ClusterNetworkPolicy(
        kind = "ClusterNetworkPolicy",
        api_version = "crd.antrea.io/v1alpha1",
        metadata = kubernetes.client.V1ObjectMeta(
            name = np_name,
        ),
        spec = antrea_crd.NetworkPolicySpec(
            tier = "Baseline",
            priority = DEFAULT_POLICY_PRIORITY,
            applied_to = [applied_to],
            egress = [antrea_crd.Rule(
                action = "Reject",
                to = [antrea_crd.NetworkPolicyPeer(
                    pod_selector = kubernetes.client.V1LabelSelector())]
            )],
            ingress = [antrea_crd.Rule(
                action = "Reject",
                _from = [antrea_crd.NetworkPolicyPeer(
                    pod_selector = kubernetes.client.V1LabelSelector())]
            )],
        )
    )
    return dict_to_yaml(np.to_dict())

class Result:
    """One output row of the recommendation UDTF: (job_type,
    recommendation_id, time_created, yamls)."""
    def __init__(self, job_type, recommendation_id, policy):
        self.job_type = job_type
        # Generate a fresh id when the caller did not supply one.
        if not recommendation_id:
            self.recommendation_id = str(uuid.uuid4())
        else:
            self.recommendation_id = recommendation_id
        self.time_created = datetime.datetime.now()
        self.yamls = policy

class PolicyRecommendation:
    """Snowflake UDTF handler: accumulates the (appliedTo, ingress, egress)
    rows of one partition in process() and emits recommended policy YAML
    rows from end_partition()."""
    def __init__(self):
        self._ingresses = set()
        self._egresses = set()

    def process(self,
                jobType,
                recommendationId,
                isolationMethod,
                nsAllowList,
                appliedTo,
                ingress,
                egress):
        assert(jobType == "initial")
        # ideally this would be done in the constructor, but this is not
        # supported in Snowflake (passing arguments once via the constructor)
        # instead we will keep overriding self._jobType with the same value
        self._jobType = jobType
        self._recommendationId = recommendationId
        self._isolationMethod = isolationMethod
        self._nsAllowList = nsAllowList
        self._applied_to = appliedTo
        self._ingresses.add(ingress)
        self._egresses.add(egress)
        # No per-row output: all rows are emitted from end_partition().
        yield None

    def end_partition(self):
        nsAllowList = self._nsAllowList.split(',')
        if self._isolationMethod == 3:
            # Isolation method 3: plain K8s NetworkPolicy.
            allow_policy = generate_k8s_np(self._applied_to, self._ingresses, self._egresses, nsAllowList)
            if allow_policy:
                result = Result(self._jobType, self._recommendationId, allow_policy)
                yield(result.job_type, result.recommendation_id, result.time_created, result.yamls)
        else:
            # Isolation methods 1 and 2: Antrea NetworkPolicy; method 1
            # additionally emits a Baseline-tier reject policy.
            allow_policy = generate_anp(self._applied_to, self._ingresses, self._egresses, nsAllowList)
            if allow_policy:
                result = Result(self._jobType, self._recommendationId, allow_policy)
                yield(result.job_type, result.recommendation_id, result.time_created, result.yamls)
            if self._isolationMethod == 1:
                reject_policy = generate_reject_acnp(self._applied_to, nsAllowList)
                if reject_policy:
                    result = Result(self._jobType, self._recommendationId, reject_policy)
                    yield(result.job_type, result.recommendation_id, result.time_created, result.yamls)
a/snowflake/udf/policy_recommendation/policy_recommendation_udf_test.py b/snowflake/udf/policy_recommendation/policy_recommendation_udf_test.py
new file mode 100644
index 000000000..c19c4d1d5
--- /dev/null
+++ b/snowflake/udf/policy_recommendation/policy_recommendation_udf_test.py
@@ -0,0 +1,362 @@
import unittest
import random

from policy_recommendation_udf import *

class TestPolicyRecommendation(unittest.TestCase):
    """Unit tests for the PolicyRecommendation UDTF handler.

    Each entry of flows_processed is one partition: a list of
    (appliedTo, ingress, egress) rows as emitted by the preprocessing UDF,
    with fields joined by ROW_DELIMITER ('#'). An empty ingress/egress
    string means "no flow in that direction".
    """
    flows_processed = [
        [
            [
                'antrea-test#{"podname": "perftest-a"}',
                '',
                'antrea-test#{"podname": "perftest-b"}#5201#TCP'
            ],
            [
                'antrea-test#{"podname": "perftest-a"}',
                '',
                'antrea-e2e#perftestsvc'
            ],
            [
                'antrea-test#{"podname": "perftest-a"}',
                '',
                '192.168.0.1#80#TCP'
            ],
        ],
        [
            [
                'antrea-test#{"podname": "perftest-b"}',
                'antrea-test#{"podname": "perftest-a"}#5201#TCP',
                ''
            ]
        ],
        [
            [
                'antrea-test#{"podname": "perftest-c"}',
                'antrea-test#{"podname": "perftest-a"}#5201#TCP',
                ''
            ]
        ],
        [
            [
                'antrea-test#{"podname": "perftest-a"}',
                '',
                'antrea-test#{"podname": "perftest-b"}#5201#TCP'
            ],
            [
                'antrea-test#{"podname": "perftest-a"}',
                '',
                'antrea-test#{"podname": "perftest-c"}#5201#TCP'
            ],
            [
                'antrea-test#{"podname": "perftest-a"}',
                '',
                '192.168.0.1#80#TCP'
            ],
        ],
    ]

    # Expected YAML outputs. The random name suffixes (e.g. "y0cq6") are
    # deterministic because test_end_partition seeds the RNG with
    # random.seed(0) before calling end_partition().
    expected_k8s_policies = [
"""apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: recommend-k8s-np-y0cq6
  namespace: antrea-test
spec:
  egress:
  - ports:
    - port: 80
      protocol: TCP
    to:
    - ipBlock:
        cidr: 192.168.0.1/32
  - ports:
    - port: 5201
      protocol: TCP
    to:
    - namespaceSelector:
        matchLabels:
          name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-b
  - ports:
    - port: 5201
      protocol: TCP
    to:
    - namespaceSelector:
        matchLabels:
          name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-c
  ingress: []
  podSelector:
    matchLabels:
      podname: perftest-a
  policyTypes:
  - Egress
""",

"""apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: recommend-k8s-np-y0cq6
  namespace: antrea-test
spec:
  egress: []
  ingress:
  - from:
    - namespaceSelector:
        matchLabels:
          name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-a
    ports:
    - port: 5201
      protocol: TCP
  podSelector:
    matchLabels:
      podname: perftest-b
  policyTypes:
  - Ingress
""",

"""apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
  name: recommend-k8s-np-y0cq6
  namespace: antrea-test
spec:
  egress: []
  ingress:
  - from:
    - namespaceSelector:
        matchLabels:
          name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-a
    ports:
    - port: 5201
      protocol: TCP
  podSelector:
    matchLabels:
      podname: perftest-c
  policyTypes:
  - Ingress
""",
]

    expected_allow_antrea_policies = [
"""apiVersion: crd.antrea.io/v1alpha1
kind: NetworkPolicy
metadata:
  name: recommend-allow-anp-y0cq6
  namespace: antrea-test
spec:
  appliedTo:
  - podSelector:
      matchLabels:
        podname: perftest-a
  egress:
  - action: Allow
    ports:
    - port: 80
      protocol: TCP
    to:
    - ipBlock:
        cidr: 192.168.0.1/32
  - action: Allow
    toServices:
    - name: perftestsvc
      namespace: antrea-e2e
  - action: Allow
    ports:
    - port: 5201
      protocol: TCP
    to:
    - namespaceSelector:
        matchLabels:
          kubernetes.io/metadata.name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-b
  ingress: []
  priority: 5
  tier: Application
""",

"""apiVersion: crd.antrea.io/v1alpha1
kind: NetworkPolicy
metadata:
  name: recommend-allow-anp-y0cq6
  namespace: antrea-test
spec:
  appliedTo:
  - podSelector:
      matchLabels:
        podname: perftest-b
  egress: []
  ingress:
  - action: Allow
    from:
    - namespaceSelector:
        matchLabels:
          kubernetes.io/metadata.name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-a
    ports:
    - port: 5201
      protocol: TCP
  priority: 5
  tier: Application
""",

"""apiVersion: crd.antrea.io/v1alpha1
kind: NetworkPolicy
metadata:
  name: recommend-allow-anp-y0cq6
  namespace: antrea-test
spec:
  appliedTo:
  - podSelector:
      matchLabels:
        podname: perftest-c
  egress: []
  ingress:
  - action: Allow
    from:
    - namespaceSelector:
        matchLabels:
          kubernetes.io/metadata.name: antrea-test
      podSelector:
        matchLabels:
          podname: perftest-a
    ports:
    - port: 5201
      protocol: TCP
  priority: 5
  tier: Application
""",
]


    expected_reject_acnp = [
"""apiVersion: crd.antrea.io/v1alpha1
kind: ClusterNetworkPolicy
metadata:
  name: recommend-reject-acnp-5zt4w
spec:
  appliedTo:
  - namespaceSelector:
      matchLabels:
        kubernetes.io/metadata.name: antrea-test
    podSelector:
      matchLabels:
        podname: perftest-a
  egress:
  - action: Reject
    to:
    - podSelector: {}
  ingress:
  - action: Reject
    from:
    - podSelector: {}
  priority: 5
  tier: Baseline
""",

"""apiVersion: crd.antrea.io/v1alpha1
kind: ClusterNetworkPolicy
metadata:
  name: recommend-reject-acnp-5zt4w
spec:
  appliedTo:
  - namespaceSelector:
      matchLabels:
        kubernetes.io/metadata.name: antrea-test
    podSelector:
      matchLabels:
        podname: perftest-b
  egress:
  - action: Reject
    to:
    - podSelector: {}
  ingress:
  - action: Reject
    from:
    - podSelector: {}
  priority: 5
  tier: Baseline
""",

"""apiVersion: crd.antrea.io/v1alpha1
kind: ClusterNetworkPolicy
metadata:
  name: recommend-reject-acnp-5zt4w
spec:
  appliedTo:
  - namespaceSelector:
      matchLabels:
        kubernetes.io/metadata.name: antrea-test
    podSelector:
      matchLabels:
        podname: perftest-c
  egress:
  - action: Reject
    to:
    - podSelector: {}
  ingress:
  - action: Reject
    from:
    - podSelector: {}
  priority: 5
  tier: Baseline
""",
]


    # NOTE: intentionally named "setup" (not unittest's "setUp"); it is
    # invoked explicitly for each parameter set inside test_end_partition.
    def setup(self):
        self.policy_recommendation = PolicyRecommendation()

    def process_flows(self,
                      flows,
                      jobType="initial",
                      isolationMethod=3,
                      nsAllowList="kube-system,flow-aggregator,flow-visibility"
                      ):
        # process() is a generator that yields a single None per input row,
        # so next() is used to drive it.
        for flow in flows:
            next(self.policy_recommendation.process(
                jobType=jobType,
                recommendationId="",
                isolationMethod=isolationMethod,
                nsAllowList=nsAllowList,
                appliedTo=flow[0],
                ingress=flow[1],
                egress=flow[2]
            ))

    def test_end_partition(self):
        for isolationMethod, flows_processed, expected_policies in [
            (1, self.flows_processed[0], [self.expected_allow_antrea_policies[0]] + [self.expected_reject_acnp[0]]),
            (1, self.flows_processed[1], [self.expected_allow_antrea_policies[1]] + [self.expected_reject_acnp[1]]),
            (1, self.flows_processed[2], [self.expected_allow_antrea_policies[2]] + [self.expected_reject_acnp[2]]),
            (2, self.flows_processed[0], [self.expected_allow_antrea_policies[0]]),
            (2, self.flows_processed[1], [self.expected_allow_antrea_policies[1]]),
            (2, self.flows_processed[2], [self.expected_allow_antrea_policies[2]]),
            (3, self.flows_processed[3], [self.expected_k8s_policies[0]]),
            (3, self.flows_processed[1], [self.expected_k8s_policies[1]]),
            (3, self.flows_processed[2], [self.expected_k8s_policies[2]]),
        ]:
            self.setup()
            self.process_flows(isolationMethod=isolationMethod, flows=flows_processed)
            # Initialize the random number generator to get predictable generated policy names
            random.seed(0)
            for expected_policy, result in zip(expected_policies, self.policy_recommendation.end_partition()):
                job_type, _, _, yamls = result
                self.assertEqual(yamls, expected_policy)

if __name__ == "__main__":
    unittest.main()
diff --git a/snowflake/udf/policy_recommendation/policy_recommendation_utils.py b/snowflake/udf/policy_recommendation/policy_recommendation_utils.py
new file mode 100644
index 000000000..28eb5254e
--- /dev/null
+++ b/snowflake/udf/policy_recommendation/policy_recommendation_utils.py
@@ -0,0 +1,48 @@
# Copyright 2022 Antrea Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ipaddress import ip_address, IPv4Address +import json +from re import sub +import yaml + +def is_intstring(s): + try: + int(s) + return True + except ValueError: + return False + +def get_IP_version(IP): + return "v4" if type(ip_address(IP)) is IPv4Address else "v6" + +def camel(s): + s = sub(r"(_|-)+", " ", s).title().replace(" ", "") + return s[0].lower() + s[1:] if s else "" + +def camel_dict(d): + result = {} + for key, value in d.items(): + if isinstance(value, list): + result[camel(key)] = list(map( + lambda x: camel_dict(x) if isinstance(x, dict) else x, value + )) + elif isinstance(value, dict) and key != "match_labels": + result[camel(key)] = camel_dict(value) + elif value is not None: + result[camel(key)] = value + return result + +def dict_to_yaml(d): + return yaml.dump(yaml.load(json.dumps(camel_dict(d)), Loader=yaml.FullLoader)) diff --git a/snowflake/udf/policy_recommendation/preprocessing_udf.py b/snowflake/udf/policy_recommendation/preprocessing_udf.py new file mode 100644 index 000000000..a4917831d --- /dev/null +++ b/snowflake/udf/policy_recommendation/preprocessing_udf.py @@ -0,0 +1,86 @@ +import json + +ROW_DELIMITER = "#" + +def parseLabels(labels, omitKeys = []): + if not labels: + return "{}" + # Just for PoC, generated records having labels in single-quote + labels = labels.replace("\'", "\"") + labels_dict = json.loads(labels) + labels_dict = { + key: value + for key, value in labels_dict.items() + if key not in omitKeys + } + return json.dumps(labels_dict, sort_keys=True) + +def get_flow_type(flowType, 
destinationServicePortName, destinationPodLabels): + if flowType == 3: + return "pod_to_external" + elif destinationServicePortName: + return "pod_to_svc" + elif destinationPodLabels: + return "pod_to_pod" + else: + return "pod_to_external" + +def get_protocol_string(protocolIdentifier): + if protocolIdentifier == 6: + return "TCP" + elif protocolIdentifier == 17: + return "UDP" + else: + return "UNKNOWN" + +class Result: + def __init__(self, applied_to, ingress, egress): + self.applied_to = applied_to + self.ingress = ingress + self.egress = egress + +class PreProcessing: + def __init__(self): + return + + def process(self, + jobType, + isolationMethod, + nsAllowList, + labelIgnoreList, + sourcePodNamespace, + sourcePodLabels, + destinationIP, + destinationPodNamespace, + destinationPodLabels, + destinationServicePortName, + destinationTransportPort, + protocolIdentifier, + flowType): + labelsToIgnore = [] + if labelIgnoreList: + labelsToIgnore = labelIgnoreList.split(',') + sourcePodLabels = parseLabels(sourcePodLabels, labelsToIgnore) + destinationPodLabels = parseLabels(destinationPodLabels, labelsToIgnore) + flowType = get_flow_type(flowType, destinationServicePortName, destinationPodLabels) + protocolIdentifier = get_protocol_string(protocolIdentifier) + + # Build row for source Pod as applied_to + applied_to = ROW_DELIMITER.join([sourcePodNamespace, sourcePodLabels]) + if flowType == "pod_to_external": + egress = ROW_DELIMITER.join([destinationIP, str(destinationTransportPort), protocolIdentifier]) + elif flowType == "pod_to_svc" and isolationMethod != 3: + # K8s policies don't support Pod to Service rules + svc_ns, svc_name = destinationServicePortName.partition(':')[0].split('/') + egress = ROW_DELIMITER.join([svc_ns, svc_name]) + else: + egress = ROW_DELIMITER.join([destinationPodNamespace, destinationPodLabels, str(destinationTransportPort), protocolIdentifier]) + row = Result(applied_to, "", egress) + yield(row.applied_to, row.ingress, row.egress) + + # 
Build row for destination Pod (if possible) as applied_to + if flowType != "pod_to_external": + applied_to = ROW_DELIMITER.join([destinationPodNamespace, destinationPodLabels]) + ingress = ROW_DELIMITER.join([sourcePodNamespace, sourcePodLabels, str(destinationTransportPort), protocolIdentifier]) + row = Result(applied_to, ingress, "") + yield(row.applied_to, row.ingress, row.egress) diff --git a/snowflake/udf/policy_recommendation/preprocessing_udf_test.py b/snowflake/udf/policy_recommendation/preprocessing_udf_test.py new file mode 100644 index 000000000..3beb4c887 --- /dev/null +++ b/snowflake/udf/policy_recommendation/preprocessing_udf_test.py @@ -0,0 +1,104 @@ +import unittest + +from preprocessing_udf import * + +class TestStaticPolicyRecommendation(unittest.TestCase): + flows_input = [ + ( + "antrea-test", + "{\"podname\":\"perftest-a\"}", + "10.10.0.5", + "antrea-test", + "{\"podname\":\"perftest-b\"}", + "", + 5201, + 6, + 1 + ), + ( + "antrea-test", + "{\"podname\":\"perftest-a\"}", + "10.10.0.6", + "antrea-test", + "{\"podname\":\"perftest-c\"}", + "antrea-e2e/perftestsvc:5201", + 5201, + 6, + 1 + ), + ( + "antrea-test", + "{\"podname\":\"perftest-a\"}", + "192.168.0.1", + "", + "", + "", + 80, + 6, + 3 + ) + ] + + flows_processed = [ + [ + [ + 'antrea-test#{"podname": "perftest-a"}', + '', + 'antrea-test#{"podname": "perftest-b"}#5201#TCP' + ], + [ + 'antrea-test#{"podname": "perftest-b"}', + 'antrea-test#{"podname": "perftest-a"}#5201#TCP', + '' + ] + ], + [ + [ + 'antrea-test#{"podname": "perftest-a"}', + '', + 'antrea-e2e#perftestsvc' + ], + [ + 'antrea-test#{"podname": "perftest-c"}', + 'antrea-test#{"podname": "perftest-a"}#5201#TCP', + '' + ] + ], + [ + [ + 'antrea-test#{"podname": "perftest-a"}', + '', + '192.168.0.1#80#TCP' + ], + ], + ] + + def setup(self): + self.preprocessing = PreProcessing() + + def test_process(self): + self.setup() + for flow_input, expected_flows_processed in zip(self.flows_input, self.flows_processed): + process_result 
= self.preprocessing.process( + jobType="initial", + isolationMethod=1, + nsAllowList="kube-system,flow-aggregator,flow-visibility", + labelIgnoreList="pod-template-hash,controller-revision-hash,pod-template-generation", + sourcePodNamespace=flow_input[0], + sourcePodLabels=flow_input[1], + destinationIP=flow_input[2], + destinationPodNamespace=flow_input[3], + destinationPodLabels=flow_input[4], + destinationServicePortName=flow_input[5], + destinationTransportPort=flow_input[6], + protocolIdentifier=flow_input[7], + flowType=flow_input[8] + ) + for flow_processed, expected_flow_processed in zip(process_result, expected_flows_processed): + applied_to, ingress, egress = flow_processed + self.assertEqual(applied_to, expected_flow_processed[0]) + self.assertEqual(ingress, expected_flow_processed[1]) + self.assertEqual(egress, expected_flow_processed[2]) + +if __name__ == "__main__": + unittest.main() diff --git a/snowflake/udf/policy_recommendation/static_policy_recommendation_udf.py b/snowflake/udf/policy_recommendation/static_policy_recommendation_udf.py new file mode 100644 index 000000000..ccbd1db27 --- /dev/null +++ b/snowflake/udf/policy_recommendation/static_policy_recommendation_udf.py @@ -0,0 +1,107 @@ +import datetime +import uuid + +import kubernetes.client + +import policy_recommendation.antrea_crd as antrea_crd +from policy_recommendation.policy_recommendation_utils import * +from policy_recommendation.policy_recommendation_udf import generate_policy_name, DEFAULT_POLICY_PRIORITY + +def recommend_policies_for_ns_allow_list(ns_allow_list): + policies = [] + for ns in ns_allow_list: + np_name = generate_policy_name("recommend-allow-acnp-{}".format(ns)) + acnp = antrea_crd.ClusterNetworkPolicy( + kind = "ClusterNetworkPolicy", + api_version = "crd.antrea.io/v1alpha1", + metadata = kubernetes.client.V1ObjectMeta( + name = np_name, + ), + spec = antrea_crd.NetworkPolicySpec( + tier = "Platform", + priority = DEFAULT_POLICY_PRIORITY, + applied_to = 
[antrea_crd.NetworkPolicyPeer( + namespace_selector = kubernetes.client.V1LabelSelector( + match_labels = { + "kubernetes.io/metadata.name":ns + } + ) + )], + egress = [antrea_crd.Rule( + action = "Allow", + to = [antrea_crd.NetworkPolicyPeer( + pod_selector = kubernetes.client.V1LabelSelector())] + )], + ingress = [antrea_crd.Rule( + action = "Allow", + _from = [antrea_crd.NetworkPolicyPeer( + pod_selector = kubernetes.client.V1LabelSelector())] + )], + ) + ) + policies.append(dict_to_yaml(acnp.to_dict())) + return policies + +def reject_all_acnp(): + np = antrea_crd.ClusterNetworkPolicy( + kind = "ClusterNetworkPolicy", + api_version = "crd.antrea.io/v1alpha1", + metadata = kubernetes.client.V1ObjectMeta( + name = "recommend-reject-all-acnp", + ), + spec = antrea_crd.NetworkPolicySpec( + tier = "Baseline", + priority = DEFAULT_POLICY_PRIORITY, + applied_to = [antrea_crd.NetworkPolicyPeer( + pod_selector = kubernetes.client.V1LabelSelector(), + namespace_selector = kubernetes.client.V1LabelSelector() + )], + egress = [antrea_crd.Rule( + action = "Reject", + to = [antrea_crd.NetworkPolicyPeer( + pod_selector = kubernetes.client.V1LabelSelector())] + )], + ingress = [antrea_crd.Rule( + action = "Reject", + _from = [antrea_crd.NetworkPolicyPeer( + pod_selector = kubernetes.client.V1LabelSelector())] + )], + ) + ) + return dict_to_yaml(np.to_dict()) + +class Result: + def __init__(self, job_type, recommendation_id, policy): + self.job_type = job_type + if not recommendation_id: + self.recommendation_id = str(uuid.uuid4()) + else: + self.recommendation_id = recommendation_id + self.time_created = datetime.datetime.now() + self.yamls = policy + +class StaticPolicyRecommendation: + def __init__(self): + return + + def process(self, + jobType, + recommendationId, + isolationMethod, + nsAllowList): + self._jobType = jobType + self._recommendationId = recommendationId + self._nsAllowList = nsAllowList + self._isolationMethod = isolationMethod + yield None + + def 
end_partition(self): + if self._nsAllowList: + ns_allow_policies = recommend_policies_for_ns_allow_list(self._nsAllowList.split(',')) + for policy in ns_allow_policies: + result = Result(self._jobType, self._recommendationId, policy) + yield(result.job_type, result.recommendation_id, result.time_created, result.yamls) + if self._isolationMethod == 2: + reject_all_policy = reject_all_acnp() + result = Result(self._jobType, self._recommendationId, reject_all_policy) + yield(result.job_type, result.recommendation_id, result.time_created, result.yamls) diff --git a/snowflake/udf/policy_recommendation/static_policy_recommendation_udf_test.py b/snowflake/udf/policy_recommendation/static_policy_recommendation_udf_test.py new file mode 100644 index 000000000..f84e7fc70 --- /dev/null +++ b/snowflake/udf/policy_recommendation/static_policy_recommendation_udf_test.py @@ -0,0 +1,123 @@ +import unittest +import random + +from static_policy_recommendation_udf import * + +class TestStaticPolicyRecommendation(unittest.TestCase): + expected_ns_allow_policies = [ +"""apiVersion: crd.antrea.io/v1alpha1 +kind: ClusterNetworkPolicy +metadata: + name: recommend-allow-acnp-kube-system-y0cq6 +spec: + appliedTo: + - namespaceSelector: + matchLabels: + kubernetes.io/metadata.name: kube-system + egress: + - action: Allow + to: + - podSelector: {} + ingress: + - action: Allow + from: + - podSelector: {} + priority: 5 + tier: Platform +""", + +"""apiVersion: crd.antrea.io/v1alpha1 +kind: ClusterNetworkPolicy +metadata: + name: recommend-allow-acnp-flow-aggregator-5zt4w +spec: + appliedTo: + - namespaceSelector: + matchLabels: + kubernetes.io/metadata.name: flow-aggregator + egress: + - action: Allow + to: + - podSelector: {} + ingress: + - action: Allow + from: + - podSelector: {} + priority: 5 + tier: Platform +""", + +"""apiVersion: crd.antrea.io/v1alpha1 +kind: ClusterNetworkPolicy +metadata: + name: recommend-allow-acnp-flow-visibility-n6isg +spec: + appliedTo: + - namespaceSelector: + 
matchLabels: + kubernetes.io/metadata.name: flow-visibility + egress: + - action: Allow + to: + - podSelector: {} + ingress: + - action: Allow + from: + - podSelector: {} + priority: 5 + tier: Platform +""", +] + + expected_reject_all_acnp = [ +"""apiVersion: crd.antrea.io/v1alpha1 +kind: ClusterNetworkPolicy +metadata: + name: recommend-reject-all-acnp +spec: + appliedTo: + - namespaceSelector: {} + podSelector: {} + egress: + - action: Reject + to: + - podSelector: {} + ingress: + - action: Reject + from: + - podSelector: {} + priority: 5 + tier: Baseline +""" +] + + def setup(self): + self.static_policy_recommendation = StaticPolicyRecommendation() + + def process(self, + jobType="initial", + recommendationId="", + isolationMethod=1, + nsAllowList=""): + next(self.static_policy_recommendation.process( + jobType=jobType, + recommendationId=recommendationId, + isolationMethod=isolationMethod, + nsAllowList=nsAllowList, + )) + + def test_end_partition(self): + for isolationMethod, nsAllowList, expected_policies in [ + (1, "kube-system,flow-aggregator,flow-visibility", self.expected_ns_allow_policies), + (2, "", self.expected_reject_all_acnp), + (3, "", []), + ]: + self.setup() + self.process(isolationMethod=isolationMethod, nsAllowList=nsAllowList) + random.seed(0) + for expected_policy, result in zip(expected_policies, self.static_policy_recommendation.end_partition()): + _, _, _, yamls = result + self.assertEqual(yamls, expected_policy) + +if __name__ == "__main__": + unittest.main() diff --git a/snowflake/udf/policy_recommendation/version.txt b/snowflake/udf/policy_recommendation/version.txt new file mode 100644 index 000000000..b82608c0b --- /dev/null +++ b/snowflake/udf/policy_recommendation/version.txt @@ -0,0 +1 @@ +v0.1.0 From ee680137279d813491f072fa29811dfbc50e3a8a Mon Sep 17 00:00:00 2001 From: Yongming Ding Date: Wed, 30 Nov 2022 11:59:31 -0800 Subject: [PATCH 2/5] Address comments Signed-off-by: Yongming Ding --- snowflake/README.md | 4 +- 
snowflake/cmd/policyRecommendation.go | 66 +++++++----- snowflake/pkg/infra/manager.go | 107 +------------------ snowflake/pkg/infra/udfs.go | 79 ++++++++++++++ snowflake/pkg/udfs/udfs.go | 26 ----- snowflake/pkg/utils/timestamps/timestamps.go | 14 +-- snowflake/pkg/utils/utils.go | 68 ++++++++++-- 7 files changed, 185 insertions(+), 179 deletions(-) create mode 100644 snowflake/pkg/infra/udfs.go delete mode 100644 snowflake/pkg/udfs/udfs.go diff --git a/snowflake/README.md b/snowflake/README.md index add8cdcb7..e1d63b6fe 100644 --- a/snowflake/README.md +++ b/snowflake/README.md @@ -163,7 +163,9 @@ take seconds to minutes depending on the number of flows. We recommend using a [Medium size warehouse](https://docs.snowflake.com/en/user-guide/warehouses-overview.html) if you are working on a big dataset. If no warehouse is provided by the `--warehouse-name` option, we will create a temporary X-Small size warehouse by -default. +default. Running NetworkPolicy Recommendation will consume Snowflake credits, +the amount of which will depend on the size of the warehouse and the contents +of the database. ## Network flow visibility with Grafana diff --git a/snowflake/cmd/policyRecommendation.go b/snowflake/cmd/policyRecommendation.go index 10ea8a80d..8136556f0 100644 --- a/snowflake/cmd/policyRecommendation.go +++ b/snowflake/cmd/policyRecommendation.go @@ -17,13 +17,13 @@ package cmd import ( "context" "fmt" + "strings" "time" "github.com/google/uuid" "github.com/spf13/cobra" "antrea.io/theia/snowflake/pkg/infra" - "antrea.io/theia/snowflake/pkg/udfs" "antrea.io/theia/snowflake/pkg/utils/timestamps" ) @@ -34,14 +34,28 @@ const ( defaultFunctionVersion = "v0.1.0" defaultWaitTimeout = "10m" // Limit the number of rows per partition to avoid hitting the 5 minutes end_partition() timeout. 
- partitionSizeLimit = 30000 + partitionSizeLimit = 50000 ) -func buildPolicyRecommendationUdfQuery(jobType string, limit uint, isolationMethod int, start string, end string, startTs string, endTs string, nsAllowList string, labelIgnoreList string, clusterUUID string, databaseName string, functionVersion string) (string, error) { +func buildPolicyRecommendationUdfQuery( + jobType string, + limit uint, + isolationMethod int, + start string, + end string, + startTs string, + endTs string, + nsAllowList string, + labelIgnoreList string, + clusterUUID string, + databaseName string, + functionVersion string, +) (string, error) { now := time.Now() recommendationID := uuid.New().String() - functionName := udfs.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion) - query := fmt.Sprintf(`SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM + functionName := infra.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion) + var queryBuilder strings.Builder + fmt.Fprintf(&queryBuilder, `SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM TABLE(%s( '%s', '%s', @@ -50,7 +64,7 @@ func buildPolicyRecommendationUdfQuery(jobType string, limit uint, isolationMeth ) over (partition by 1)) as r; `, functionName, jobType, recommendationID, isolationMethod, nsAllowList) - query += `WITH filtered_flows AS ( + queryBuilder.WriteString(`WITH filtered_flows AS ( SELECT sourcePodNamespace, sourcePodLabels, @@ -63,13 +77,11 @@ SELECT flowType FROM flows -` - - query += `WHERE +WHERE ingressNetworkPolicyName IS NULL AND egressNetworkPolicyName IS NULL -` +`) var startTime string if startTs != "" { @@ -82,7 +94,7 @@ AND } } if startTime != "" { - query += fmt.Sprintf(`AND + fmt.Fprintf(&queryBuilder, `AND flowStartSeconds >= '%s' `, startTime) } @@ -98,7 +110,7 @@ AND } } if endTime != "" { - query += fmt.Sprintf(`AND + fmt.Fprintf(&queryBuilder, `AND flowEndSeconds >= '%s' `, endTime) } @@ -108,14 +120,14 @@ AND if err != nil { 
return "", err } - query += fmt.Sprintf(`AND + fmt.Fprintf(&queryBuilder, `AND clusterUUID = '%s' `, clusterUUID) } else { logger.Info("No clusterUUID input, all flows will be considered during policy recommendation.") } - query += `GROUP BY + queryBuilder.WriteString(`GROUP BY sourcePodNamespace, sourcePodLabels, destinationIP, @@ -125,21 +137,21 @@ destinationServicePortName, destinationTransportPort, protocolIdentifier, flowType - ` +`) if limit > 0 { - query += fmt.Sprintf(` + fmt.Fprintf(&queryBuilder, ` LIMIT %d`, limit) } else { // limit the number unique flow records to 500k to avoid udf timeout - query += ` -LIMIT 500000` + queryBuilder.WriteString(` +LIMIT 500000`) } // Choose the destinationIP as the partition field for the preprocessing // UDTF because flow rows could be divided into the most subsets - functionName = udfs.GetFunctionName(preprocessingFunctionName, functionVersion) - query += fmt.Sprintf(`), processed_flows AS (SELECT r.appliedTo, r.ingress, r.egress FROM filtered_flows AS f, + functionName = infra.GetFunctionName(preprocessingFunctionName, functionVersion) + fmt.Fprintf(&queryBuilder, `), processed_flows AS (SELECT r.appliedTo, r.ingress, r.egress FROM filtered_flows AS f, TABLE(%s( '%s', %d, @@ -159,7 +171,7 @@ TABLE(%s( // Scan the row number for each appliedTo group and divide the partitions // larger than partitionSizeLimit. - query += fmt.Sprintf(`), pf_with_index AS ( + fmt.Fprintf(&queryBuilder, `), pf_with_index AS ( SELECT pf.appliedTo, pf.ingress, @@ -171,8 +183,8 @@ FROM processed_flows as pf // Choose the appliedTo as the partition field for the policyRecommendation // UDTF because each network policy is recommended based on all ingress and // egress traffic related to an appliedTo group. 
- functionName = udfs.GetFunctionName(policyRecommendationFunctionName, functionVersion) - query += fmt.Sprintf(`) SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM pf_with_index, + functionName = infra.GetFunctionName(policyRecommendationFunctionName, functionVersion) + fmt.Fprintf(&queryBuilder, `) SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM pf_with_index, TABLE(%s( '%s', '%s', @@ -184,7 +196,7 @@ TABLE(%s( ) over (partition by pf_with_index.appliedTo, pf_with_index.row_index)) as r `, functionName, jobType, recommendationID, isolationMethod, nsAllowList) - return query, nil + return queryBuilder.String(), nil } // policyRecommendationCmd represents the policy-recommendation command @@ -228,17 +240,13 @@ You can also bring your own by using the "--warehouse-name" parameter. if err != nil { return fmt.Errorf("invalid --wait-timeout argument, err when parsing it as a duration: %v", err) } - verbose := verbosity >= 2 query, err := buildPolicyRecommendationUdfQuery(jobType, limit, isolationMethod, start, end, startTs, endTs, nsAllowList, labelIgnoreList, clusterUUID, databaseName, functionVersion) if err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), waitDuration) defer cancel() - // stackName, stateBackendURL, secretsProviderURL, region, workdir are not provided here - // because we only uses snowflake client in this command. 
- mgr := infra.NewManager(logger, "", "", "", "", warehouseName, "", verbose) - rows, err := mgr.RunUdf(ctx, query, databaseName) + rows, err := infra.RunUdf(ctx, logger, query, databaseName, warehouseName) if err != nil { return fmt.Errorf("error when running policy recommendation UDF: %w", err) } diff --git a/snowflake/pkg/infra/manager.go b/snowflake/pkg/infra/manager.go index 4022f40e8..6a1dd5079 100644 --- a/snowflake/pkg/infra/manager.go +++ b/snowflake/pkg/infra/manager.go @@ -15,8 +15,6 @@ package infra import ( - "archive/tar" - "compress/gzip" "context" "database/sql" "embed" @@ -24,7 +22,6 @@ import ( "fmt" "io" "io/fs" - "net/http" "os" "path/filepath" "runtime" @@ -95,55 +92,6 @@ func writeMigrationsToDisk(fsys fs.FS, migrationsPath string, dest string) error return nil } -func downloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir string) error { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return err - } - client := http.DefaultClient - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - gzr, err := gzip.NewReader(resp.Body) - if err != nil { - return err - } - defer gzr.Close() - tr := tar.NewReader(gzr) - for { - hdr, err := tr.Next() - if err == io.EOF { - break // End of archive - } - if err != nil { - return err - } - dest := filepath.Join(dir, hdr.Name) - logger.V(4).Info("Untarring", "path", hdr.Name) - if hdr.Typeflag != tar.TypeReg { - continue - } - if err := func() error { - f, err := os.OpenFile(dest, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode)) - if err != nil { - return err - } - defer f.Close() - - // copy over contents - if _, err := io.Copy(f, tr); err != nil { - return err - } - return nil - }(); err != nil { - return err - } - } - return nil -} - func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error { logger.Info("Downloading and installing Pulumi", "version", pulumiVersion) cachedVersion, err := 
os.ReadFile(filepath.Join(dir, ".pulumi-version")) @@ -176,7 +124,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil { return err } - if err := downloadAndUntar(ctx, logger, url, dir); err != nil { + if err := utils.DownloadAndUntar(ctx, logger, url, dir, "", true); err != nil { return err } @@ -209,7 +157,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch) } url := fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target) - if err := downloadAndUntar(ctx, logger, url, dir); err != nil { + if err := utils.DownloadAndUntar(ctx, logger, url, dir, "", true); err != nil { return err } @@ -536,7 +484,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } // Download and stage Kubernetes python client for policy recommendation udf - err = utils.DownloadFile(k8sPythonClientUrl, k8sPythonClientFileName) + err = utils.DownloadAndUntar(ctx, logger, k8sPythonClientUrl, ".", k8sPythonClientFileName, false) if err != nil { return err } @@ -615,52 +563,3 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } return nil } - -func (m *Manager) RunUdf(ctx context.Context, query string, databaseName string) (*sql.Rows, error) { - logger := m.logger - logger.Info("Running UDF") - dsn, _, err := sf.GetDSN() - if err != nil { - return nil, fmt.Errorf("failed to create DSN: %w", err) - } - - db, err := sql.Open("snowflake", dsn) - if err != nil { - return nil, fmt.Errorf("failed to connect to Snowflake: %w", err) - } - defer db.Close() - - sfClient := sf.NewClient(db, logger) - - if err := sfClient.UseDatabase(ctx, databaseName); err != nil { - return nil, err - } - - if err := 
sfClient.UseSchema(ctx, schemaName); err != nil { - return nil, err - } - - warehouseName := m.warehouseName - if warehouseName == "" { - temporaryWarehouse := newTemporaryWarehouse(sfClient, logger) - warehouseName = temporaryWarehouse.Name() - if err := temporaryWarehouse.Create(ctx); err != nil { - return nil, err - } - defer func() { - if err := temporaryWarehouse.Delete(ctx); err != nil { - logger.Error(err, "Failed to delete temporary warehouse, please do it manually", "name", warehouseName) - } - }() - } - - if err := sfClient.UseWarehouse(ctx, warehouseName); err != nil { - return nil, err - } - - rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true) - if err != nil { - return nil, fmt.Errorf("error when running UDF: %w", err) - } - return rows, nil -} diff --git a/snowflake/pkg/infra/udfs.go b/snowflake/pkg/infra/udfs.go new file mode 100644 index 000000000..cb2dd7702 --- /dev/null +++ b/snowflake/pkg/infra/udfs.go @@ -0,0 +1,79 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package infra + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/go-logr/logr" + + sf "antrea.io/theia/snowflake/pkg/snowflake" +) + +func GetFunctionName(baseName string, version string) string { + version = strings.ReplaceAll(version, ".", "_") + version = strings.ReplaceAll(version, "-", "_") + return fmt.Sprintf("%s_%s", baseName, version) +} + +func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName string, warehouseName string) (*sql.Rows, error) { + logger.Info("Running UDF") + dsn, _, err := sf.GetDSN() + if err != nil { + return nil, fmt.Errorf("failed to create DSN: %w", err) + } + + db, err := sql.Open("snowflake", dsn) + if err != nil { + return nil, fmt.Errorf("failed to connect to Snowflake: %w", err) + } + defer db.Close() + + sfClient := sf.NewClient(db, logger) + + if err := sfClient.UseDatabase(ctx, databaseName); err != nil { + return nil, err + } + + if err := sfClient.UseSchema(ctx, schemaName); err != nil { + return nil, err + } + + if warehouseName == "" { + temporaryWarehouse := newTemporaryWarehouse(sfClient, logger) + warehouseName = temporaryWarehouse.Name() + if err := temporaryWarehouse.Create(ctx); err != nil { + return nil, err + } + defer func() { + if err := temporaryWarehouse.Delete(ctx); err != nil { + logger.Error(err, "Failed to delete temporary warehouse, please do it manually", "name", warehouseName) + } + }() + } + + if err := sfClient.UseWarehouse(ctx, warehouseName); err != nil { + return nil, err + } + + rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true) + if err != nil { + return nil, fmt.Errorf("error when running UDF: %w", err) + } + return rows, nil +} diff --git a/snowflake/pkg/udfs/udfs.go b/snowflake/pkg/udfs/udfs.go deleted file mode 100644 index b644b2663..000000000 --- a/snowflake/pkg/udfs/udfs.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2022 Antrea Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udfs - -import ( - "fmt" - "strings" -) - -func GetFunctionName(baseName string, version string) string { - version = strings.ReplaceAll(version, ".", "_") - version = strings.ReplaceAll(version, "-", "_") - return fmt.Sprintf("%s_%s", baseName, version) -} diff --git a/snowflake/pkg/utils/timestamps/timestamps.go b/snowflake/pkg/utils/timestamps/timestamps.go index 1ebae5f1b..6179bdffc 100644 --- a/snowflake/pkg/utils/timestamps/timestamps.go +++ b/snowflake/pkg/utils/timestamps/timestamps.go @@ -20,18 +20,14 @@ import ( "time" ) -func ParseTimestamp(t string, now time.Time, defaultT ...time.Time) (string, error) { - defaultTimestamp := now - if len(defaultT) > 0 { - defaultTimestamp = defaultT[0] - } +func ParseTimestamp(t string, now time.Time) (string, error) { ts, err := func() (time.Time, error) { fields := strings.Split(t, "-") if len(fields) == 0 { - return defaultTimestamp, nil + return now, nil } if len(fields) > 1 && fields[0] != "now" { - return defaultTimestamp, fmt.Errorf("bad timestamp: %s", t) + return now, fmt.Errorf("bad timestamp: %s", t) } if len(fields) == 1 { return now, nil @@ -39,11 +35,11 @@ func ParseTimestamp(t string, now time.Time, defaultT ...time.Time) (string, err if len(fields) == 2 { d, err := time.ParseDuration(fields[1]) if err != nil { - return defaultTimestamp, fmt.Errorf("bad timestamp: %s", t) + return now, fmt.Errorf("bad timestamp: %s", t) } return 
now.Add(-d), nil } - return defaultTimestamp, fmt.Errorf("bad timestamp: %s", t) + return now, fmt.Errorf("bad timestamp: %s", t) }() if err != nil { return "", nil diff --git a/snowflake/pkg/utils/utils.go b/snowflake/pkg/utils/utils.go index 15a74ebda..f9c8869d7 100644 --- a/snowflake/pkg/utils/utils.go +++ b/snowflake/pkg/utils/utils.go @@ -15,26 +15,74 @@ package utils import ( + "archive/tar" + "compress/gzip" + "context" "io" "net/http" "os" + "path/filepath" + + "github.com/go-logr/logr" ) -// Download a file from the given url to the current directory -func DownloadFile(url string, filename string) error { - resp, err := http.Get(url) +func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir string, filename string, untar bool) error { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return err + } + client := http.DefaultClient + resp, err := client.Do(req) if err != nil { return err } defer resp.Body.Close() - if resp.StatusCode != 200 { + if untar { + gzr, err := gzip.NewReader(resp.Body) + if err != nil { + return err + } + defer gzr.Close() + tr := tar.NewReader(gzr) + for { + hdr, err := tr.Next() + if err == io.EOF { + break // End of archive + } + if err != nil { + return err + } + dest := filepath.Join(dir, hdr.Name) + logger.V(4).Info("Untarring", "path", hdr.Name) + if hdr.Typeflag != tar.TypeReg { + continue + } + if err := func() error { + f, err := os.OpenFile(dest, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode)) + if err != nil { + return err + } + defer f.Close() + + // copy over contents + if _, err := io.Copy(f, tr); err != nil { + return err + } + return nil + }(); err != nil { + return err + } + } return nil - } - file, err := os.Create(filename) - if err != nil { + } else { + dest := filepath.Join(dir, filename) + logger.V(4).Info("Downloading", "path", dest) + file, err := os.Create(dest) + if err != nil { + return err + } + defer file.Close() + _, err = io.Copy(file, resp.Body) 
return err } - defer file.Close() - _, err = io.Copy(file, resp.Body) - return err } From 6d3889ad982a616da7c68fa21234ba808aa90d0b Mon Sep 17 00:00:00 2001 From: Yongming Ding Date: Thu, 1 Dec 2022 16:45:38 -0800 Subject: [PATCH 3/5] Address 2nd round comments Signed-off-by: Yongming Ding --- snowflake/Makefile | 2 +- snowflake/cmd/policyRecommendation.go | 10 +- snowflake/main.go | 9 -- snowflake/pkg/infra/constants.go | 3 +- snowflake/pkg/infra/manager.go | 78 ++++------- snowflake/pkg/infra/stack.go | 4 +- snowflake/pkg/infra/temporary_warehouse.go | 4 +- snowflake/pkg/{infra => udfs}/udfs.go | 7 +- snowflake/pkg/utils/utils.go | 128 ++++++++++++------ snowflake/udfs/udfs.go | 24 ++++ snowflake/{udf => udfs/udfs}/Makefile | 0 .../udfs}/policy_recommendation/__init__.py | 0 .../udfs}/policy_recommendation/antrea_crd.py | 0 .../policy_recommendation/create_function.sql | 0 .../policy_recommendation_udf.py | 0 .../policy_recommendation_udf_test.py | 0 .../policy_recommendation_utils.py | 0 .../preprocessing_udf.py | 0 .../preprocessing_udf_test.py | 0 .../static_policy_recommendation_udf.py | 0 .../static_policy_recommendation_udf_test.py | 0 .../udfs}/policy_recommendation/version.txt | 0 22 files changed, 153 insertions(+), 116 deletions(-) rename snowflake/pkg/{infra => udfs}/udfs.go (91%) create mode 100644 snowflake/udfs/udfs.go rename snowflake/{udf => udfs/udfs}/Makefile (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/__init__.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/antrea_crd.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/create_function.sql (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/policy_recommendation_udf.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/policy_recommendation_udf_test.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/policy_recommendation_utils.py (100%) rename snowflake/{udf => 
udfs/udfs}/policy_recommendation/preprocessing_udf.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/preprocessing_udf_test.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/static_policy_recommendation_udf.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/static_policy_recommendation_udf_test.py (100%) rename snowflake/{udf => udfs/udfs}/policy_recommendation/version.txt (100%) diff --git a/snowflake/Makefile b/snowflake/Makefile index 4655e28a4..be027706b 100644 --- a/snowflake/Makefile +++ b/snowflake/Makefile @@ -5,7 +5,7 @@ all: bin .PHONY: bin bin: - make -C udf/ + make -C udfs/udfs/ $(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake .PHONY: .coverage diff --git a/snowflake/cmd/policyRecommendation.go b/snowflake/cmd/policyRecommendation.go index 8136556f0..4d7722389 100644 --- a/snowflake/cmd/policyRecommendation.go +++ b/snowflake/cmd/policyRecommendation.go @@ -23,7 +23,7 @@ import ( "github.com/google/uuid" "github.com/spf13/cobra" - "antrea.io/theia/snowflake/pkg/infra" + "antrea.io/theia/snowflake/pkg/udfs" "antrea.io/theia/snowflake/pkg/utils/timestamps" ) @@ -53,7 +53,7 @@ func buildPolicyRecommendationUdfQuery( ) (string, error) { now := time.Now() recommendationID := uuid.New().String() - functionName := infra.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion) + functionName := udfs.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion) var queryBuilder strings.Builder fmt.Fprintf(&queryBuilder, `SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM TABLE(%s( @@ -150,7 +150,7 @@ LIMIT 500000`) // Choose the destinationIP as the partition field for the preprocessing // UDTF because flow rows could be divided into the most subsets - functionName = infra.GetFunctionName(preprocessingFunctionName, functionVersion) + functionName = udfs.GetFunctionName(preprocessingFunctionName, functionVersion) fmt.Fprintf(&queryBuilder, `), processed_flows 
AS (SELECT r.appliedTo, r.ingress, r.egress FROM filtered_flows AS f, TABLE(%s( '%s', @@ -183,7 +183,7 @@ FROM processed_flows as pf // Choose the appliedTo as the partition field for the policyRecommendation // UDTF because each network policy is recommended based on all ingress and // egress traffic related to an appliedTo group. - functionName = infra.GetFunctionName(policyRecommendationFunctionName, functionVersion) + functionName = udfs.GetFunctionName(policyRecommendationFunctionName, functionVersion) fmt.Fprintf(&queryBuilder, `) SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM pf_with_index, TABLE(%s( '%s', @@ -246,7 +246,7 @@ You can also bring your own by using the "--warehouse-name" parameter. } ctx, cancel := context.WithTimeout(context.Background(), waitDuration) defer cancel() - rows, err := infra.RunUdf(ctx, logger, query, databaseName, warehouseName) + rows, err := udfs.RunUdf(ctx, logger, query, databaseName, warehouseName) if err != nil { return fmt.Errorf("error when running policy recommendation UDF: %w", err) } diff --git a/snowflake/main.go b/snowflake/main.go index fc53a50ac..c85af0237 100644 --- a/snowflake/main.go +++ b/snowflake/main.go @@ -15,18 +15,9 @@ package main import ( - "embed" - "antrea.io/theia/snowflake/cmd" - "antrea.io/theia/snowflake/pkg/infra" ) -// Embed the udfs directory here because go:embed doesn't support embeding in subpackages - -//go:embed udf/* -var udfFs embed.FS - func main() { - infra.UdfFs = udfFs cmd.Execute() } diff --git a/snowflake/pkg/infra/constants.go b/snowflake/pkg/infra/constants.go index d21c09f4f..2c1d9c115 100644 --- a/snowflake/pkg/infra/constants.go +++ b/snowflake/pkg/infra/constants.go @@ -43,7 +43,7 @@ const ( databaseNamePrefix = "ANTREA_" - schemaName = "THEIA" + SchemaName = "THEIA" flowRetentionDays = 30 flowDeletionTaskName = "DELETE_STALE_FLOWS" udfStageName = "UDFS" @@ -54,6 +54,7 @@ const ( flowsTableName = "FLOWS" migrationsDir = "migrations" + udfsDir = "udfs" 
udfVersionPlaceholder = "%VERSION%" udfCreateFunctionSQLFilename = "create_function.sql" diff --git a/snowflake/pkg/infra/manager.go b/snowflake/pkg/infra/manager.go index 6a1dd5079..c08834f35 100644 --- a/snowflake/pkg/infra/manager.go +++ b/snowflake/pkg/infra/manager.go @@ -17,7 +17,6 @@ package infra import ( "context" "database/sql" - "embed" "errors" "fmt" "io" @@ -38,10 +37,9 @@ import ( "antrea.io/theia/snowflake/database" sf "antrea.io/theia/snowflake/pkg/snowflake" utils "antrea.io/theia/snowflake/pkg/utils" + "antrea.io/theia/snowflake/udfs" ) -var UdfFs embed.FS - type pulumiPlugin struct { name string version string @@ -55,43 +53,6 @@ func deleteTemporaryWorkdir(d string) { os.RemoveAll(d) } -func writeMigrationsToDisk(fsys fs.FS, migrationsPath string, dest string) error { - if err := os.MkdirAll(dest, 0755); err != nil { - return err - } - entries, err := fs.ReadDir(fsys, migrationsPath) - if err != nil { - return err - } - for _, e := range entries { - if e.IsDir() { - continue - } - if err := func() error { - in, err := fsys.Open(filepath.Join(migrationsPath, e.Name())) - if err != nil { - return err - } - defer in.Close() - - out, err := os.Create(filepath.Join(dest, e.Name())) - if err != nil { - return err - } - defer out.Close() - - _, err = io.Copy(out, in) - if err != nil { - return err - } - return out.Close() - }(); err != nil { - return err - } - } - return nil -} - func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error { logger.Info("Downloading and installing Pulumi", "version", pulumiVersion) cachedVersion, err := os.ReadFile(filepath.Join(dir, ".pulumi-version")) @@ -124,7 +85,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil { return err } - if err := utils.DownloadAndUntar(ctx, logger, url, dir, "", true); err != nil { + if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil { return err } 
@@ -157,7 +118,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch) } url := fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target) - if err := utils.DownloadAndUntar(ctx, logger, url, dir, "", true); err != nil { + if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil { return err } @@ -321,7 +282,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { warehouseName := m.warehouseName if !destroy { logger.Info("Copying database migrations to disk") - if err := writeMigrationsToDisk(database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil { + if err := utils.WriteEmbedDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil { return nil, err } logger.Info("Copied database migrations to disk") @@ -337,7 +298,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { return nil, fmt.Errorf("failed to connect to Snowflake: %w", err) } defer db.Close() - temporaryWarehouse := newTemporaryWarehouse(sf.NewClient(db, logger), logger) + temporaryWarehouse := NewTemporaryWarehouse(sf.NewClient(db, logger), logger) warehouseName = temporaryWarehouse.Name() if err := temporaryWarehouse.Create(ctx); err != nil { return nil, err @@ -430,7 +391,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { return nil, err } - err = createUdfs(ctx, logger, outs["databaseName"], warehouseName) + err = createUdfs(ctx, logger, outs["databaseName"], warehouseName, workdir) if err != nil { return nil, err } @@ -440,7 +401,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { BucketName: outs["bucketID"], BucketFlowsFolder: 
s3BucketFlowsFolder, DatabaseName: outs["databaseName"], - SchemaName: schemaName, + SchemaName: SchemaName, FlowsTableName: flowsTableName, SNSTopicARN: outs["snsTopicARN"], SQSQueueARN: outs["sqsQueueARN"], @@ -456,7 +417,7 @@ func (m *Manager) Offboard(ctx context.Context) error { return err } -func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, warehouseName string) error { +func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, warehouseName string, workdir string) error { logger.Info("creating UDFs") dsn, _, err := sf.GetDSN() if err != nil { @@ -475,7 +436,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa return err } - if err := sfClient.UseSchema(ctx, schemaName); err != nil { + if err := sfClient.UseSchema(ctx, SchemaName); err != nil { return err } @@ -484,11 +445,11 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } // Download and stage Kubernetes python client for policy recommendation udf - err = utils.DownloadAndUntar(ctx, logger, k8sPythonClientUrl, ".", k8sPythonClientFileName, false) + k8sPythonClientFilePath, err := utils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName) if err != nil { return err } - k8sPythonClientFilePath, _ := filepath.Abs(k8sPythonClientFileName) + k8sPythonClientFilePath, _ = filepath.Abs(k8sPythonClientFilePath) err = sfClient.StageFile(ctx, k8sPythonClientFilePath, udfStageName) if err != nil { return err @@ -500,7 +461,14 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } }() - if err := fs.WalkDir(UdfFs, ".", func(path string, d fs.DirEntry, err error) error { + logger.Info("Copying UDFs to disk") + udfsDirPath := filepath.Join(workdir, udfsDir) + if err := utils.WriteEmbedDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil { + return err + } + logger.Info("Copied UDFs to disk") + + if err := 
filepath.WalkDir(udfsDirPath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -512,7 +480,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa functionVersionPath := filepath.Join(directoryPath, "version.txt") var version string if _, err := os.Stat(functionVersionPath); errors.Is(err, os.ErrNotExist) { - logger.Info("did not find version.txt file for function") + logger.Info("did not find version.txt file for function", "functionVersionPath", functionVersionPath) version = "" } else { version, err = readVersionFromFile(functionVersionPath) @@ -539,12 +507,12 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa return err } createFunctionSQLPath := filepath.Join(directoryPath, udfCreateFunctionSQLFilename) - if _, err := fs.Stat(UdfFs, createFunctionSQLPath); errors.Is(err, os.ErrNotExist) { - logger.Info("did not find SQL file to create function, skipping") + if _, err := os.Stat(createFunctionSQLPath); errors.Is(err, os.ErrNotExist) { + logger.Info("did not find SQL file to create function, skipping", "createFunctionSQLPath", createFunctionSQLPath) return nil } logger.Info("creating UDF", "from", createFunctionSQLPath, "version", version) - b, err := fs.ReadFile(UdfFs, createFunctionSQLPath) + b, err := os.ReadFile(createFunctionSQLPath) if err != nil { return err } diff --git a/snowflake/pkg/infra/stack.go b/snowflake/pkg/infra/stack.go index 490eff0c7..cdb75fb19 100644 --- a/snowflake/pkg/infra/stack.go +++ b/snowflake/pkg/infra/stack.go @@ -238,7 +238,7 @@ func declareSnowflakeDatabase( schema, err := snowflake.NewSchema(ctx, "antrea-sf-schema", &snowflake.SchemaArgs{ Database: db.ID(), - Name: pulumi.String(schemaName), + Name: pulumi.String(SchemaName), }, pulumi.Parent(db), pulumi.DeleteBeforeReplace(true)) if err != nil { return nil, err @@ -293,7 +293,7 @@ func declareSnowflakeDatabase( ErrorIntegration: notificationIntegration.ID(), // FQN required for 
table and stage, see https://github.com/pulumi/pulumi-snowflake/issues/129 // 0x27 is the hex representation of single quote. We use it to enclose Pod labels string. - CopyStatement: pulumi.Sprintf("COPY INTO %s.%s.%s FROM @%s.%s.%s FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY='0x27')", databaseName, schemaName, flowsTableName, databaseName, schemaName, ingestionStageName), + CopyStatement: pulumi.Sprintf("COPY INTO %s.%s.%s FROM @%s.%s.%s FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY='0x27')", databaseName, SchemaName, flowsTableName, databaseName, SchemaName, ingestionStageName), }, pulumi.Parent(schema), pulumi.DependsOn([]pulumi.Resource{ingestionStage, dbMigrations}), pulumi.DeleteBeforeReplace(true)) if err != nil { return nil, err diff --git a/snowflake/pkg/infra/temporary_warehouse.go b/snowflake/pkg/infra/temporary_warehouse.go index 9b5978a79..3116d26f8 100644 --- a/snowflake/pkg/infra/temporary_warehouse.go +++ b/snowflake/pkg/infra/temporary_warehouse.go @@ -19,7 +19,7 @@ import ( "fmt" "strings" - "github.com/dustinkirkland/golang-petname" + petname "github.com/dustinkirkland/golang-petname" "github.com/go-logr/logr" sf "antrea.io/theia/snowflake/pkg/snowflake" @@ -31,7 +31,7 @@ type temporaryWarehouse struct { warehouseName string } -func newTemporaryWarehouse(sfClient sf.Client, logger logr.Logger) *temporaryWarehouse { +func NewTemporaryWarehouse(sfClient sf.Client, logger logr.Logger) *temporaryWarehouse { return &temporaryWarehouse{ sfClient: sfClient, logger: logger, diff --git a/snowflake/pkg/infra/udfs.go b/snowflake/pkg/udfs/udfs.go similarity index 91% rename from snowflake/pkg/infra/udfs.go rename to snowflake/pkg/udfs/udfs.go index cb2dd7702..f1cf1506b 100644 --- a/snowflake/pkg/infra/udfs.go +++ b/snowflake/pkg/udfs/udfs.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package infra +package udfs import ( "context" @@ -22,6 +22,7 @@ import ( "github.com/go-logr/logr" + "antrea.io/theia/snowflake/pkg/infra" sf "antrea.io/theia/snowflake/pkg/snowflake" ) @@ -50,12 +51,12 @@ func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName return nil, err } - if err := sfClient.UseSchema(ctx, schemaName); err != nil { + if err := sfClient.UseSchema(ctx, infra.SchemaName); err != nil { return nil, err } if warehouseName == "" { - temporaryWarehouse := newTemporaryWarehouse(sfClient, logger) + temporaryWarehouse := infra.NewTemporaryWarehouse(sfClient, logger) warehouseName = temporaryWarehouse.Name() if err := temporaryWarehouse.Create(ctx); err != nil { return nil, err diff --git a/snowflake/pkg/utils/utils.go b/snowflake/pkg/utils/utils.go index f9c8869d7..8e406790c 100644 --- a/snowflake/pkg/utils/utils.go +++ b/snowflake/pkg/utils/utils.go @@ -19,70 +19,122 @@ import ( "compress/gzip" "context" "io" + "io/fs" "net/http" "os" + "path" "path/filepath" + "strings" "github.com/go-logr/logr" ) -func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir string, filename string, untar bool) error { +func Download(ctx context.Context, logger logr.Logger, url string, dir string, filename string) (string, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { - return err + return "", err } client := http.DefaultClient resp, err := client.Do(req) if err != nil { - return err + return "", err } defer resp.Body.Close() - if untar { - gzr, err := gzip.NewReader(resp.Body) + if filename == "" { + filename = path.Base(req.URL.Path) + } + dest := filepath.Join(dir, filename) + logger.V(4).Info("Downloading", "path", dest) + file, err := os.Create(dest) + if err != nil { + return "", err + } + defer file.Close() + _, err = io.Copy(file, resp.Body) + return dest, err +} + +func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir string) error { + tarFilepath, 
err := Download(ctx, logger, url, dir, "") + if err != nil { + return err + } + f, err := os.Open(tarFilepath) + if err != nil { + return err + } + defer f.Close() + gzr, err := gzip.NewReader(f) + if err != nil { + return err + } + defer gzr.Close() + tr := tar.NewReader(gzr) + for { + hdr, err := tr.Next() + if err == io.EOF { + break // End of archive + } if err != nil { return err } - defer gzr.Close() - tr := tar.NewReader(gzr) - for { - hdr, err := tr.Next() - if err == io.EOF { - break // End of archive - } + dest := filepath.Join(dir, hdr.Name) + logger.V(4).Info("Untarring", "path", hdr.Name) + if hdr.Typeflag != tar.TypeReg { + continue + } + if err := func() error { + f, err := os.OpenFile(dest, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode)) if err != nil { return err } - dest := filepath.Join(dir, hdr.Name) - logger.V(4).Info("Untarring", "path", hdr.Name) - if hdr.Typeflag != tar.TypeReg { - continue - } - if err := func() error { - f, err := os.OpenFile(dest, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode)) - if err != nil { - return err - } - defer f.Close() + defer f.Close() - // copy over contents - if _, err := io.Copy(f, tr); err != nil { - return err - } - return nil - }(); err != nil { + // copy over contents + if _, err := io.Copy(f, tr); err != nil { return err } - } - return nil - } else { - dest := filepath.Join(dir, filename) - logger.V(4).Info("Downloading", "path", dest) - file, err := os.Create(dest) - if err != nil { + return nil + }(); err != nil { return err } - defer file.Close() - _, err = io.Copy(file, resp.Body) + } + return nil +} + +func WriteEmbedDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, embedPath string, dest string) error { + if err := os.MkdirAll(dest, 0755); err != nil { return err } + + return fs.WalkDir(fsys, embedPath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + outpath := filepath.Join(dest, strings.TrimPrefix(path, embedPath)) + + if d.IsDir() { + 
os.MkdirAll(outpath, 0755) + return nil + } + + in, err := fsys.Open(path) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(outpath) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, in) + if err != nil { + return err + } + return nil + }) } diff --git a/snowflake/udfs/udfs.go b/snowflake/udfs/udfs.go new file mode 100644 index 000000000..9eee19ddc --- /dev/null +++ b/snowflake/udfs/udfs.go @@ -0,0 +1,24 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package udfs + +import ( + "embed" +) + +//go:embed udfs +var UdfsFs embed.FS + +const UdfsPath = "udfs" diff --git a/snowflake/udf/Makefile b/snowflake/udfs/udfs/Makefile similarity index 100% rename from snowflake/udf/Makefile rename to snowflake/udfs/udfs/Makefile diff --git a/snowflake/udf/policy_recommendation/__init__.py b/snowflake/udfs/udfs/policy_recommendation/__init__.py similarity index 100% rename from snowflake/udf/policy_recommendation/__init__.py rename to snowflake/udfs/udfs/policy_recommendation/__init__.py diff --git a/snowflake/udf/policy_recommendation/antrea_crd.py b/snowflake/udfs/udfs/policy_recommendation/antrea_crd.py similarity index 100% rename from snowflake/udf/policy_recommendation/antrea_crd.py rename to snowflake/udfs/udfs/policy_recommendation/antrea_crd.py diff --git a/snowflake/udf/policy_recommendation/create_function.sql b/snowflake/udfs/udfs/policy_recommendation/create_function.sql similarity index 100% rename from snowflake/udf/policy_recommendation/create_function.sql rename to snowflake/udfs/udfs/policy_recommendation/create_function.sql diff --git a/snowflake/udf/policy_recommendation/policy_recommendation_udf.py b/snowflake/udfs/udfs/policy_recommendation/policy_recommendation_udf.py similarity index 100% rename from snowflake/udf/policy_recommendation/policy_recommendation_udf.py rename to snowflake/udfs/udfs/policy_recommendation/policy_recommendation_udf.py diff --git a/snowflake/udf/policy_recommendation/policy_recommendation_udf_test.py b/snowflake/udfs/udfs/policy_recommendation/policy_recommendation_udf_test.py similarity index 100% rename from snowflake/udf/policy_recommendation/policy_recommendation_udf_test.py rename to snowflake/udfs/udfs/policy_recommendation/policy_recommendation_udf_test.py diff --git a/snowflake/udf/policy_recommendation/policy_recommendation_utils.py b/snowflake/udfs/udfs/policy_recommendation/policy_recommendation_utils.py similarity index 100% rename from 
snowflake/udf/policy_recommendation/policy_recommendation_utils.py rename to snowflake/udfs/udfs/policy_recommendation/policy_recommendation_utils.py diff --git a/snowflake/udf/policy_recommendation/preprocessing_udf.py b/snowflake/udfs/udfs/policy_recommendation/preprocessing_udf.py similarity index 100% rename from snowflake/udf/policy_recommendation/preprocessing_udf.py rename to snowflake/udfs/udfs/policy_recommendation/preprocessing_udf.py diff --git a/snowflake/udf/policy_recommendation/preprocessing_udf_test.py b/snowflake/udfs/udfs/policy_recommendation/preprocessing_udf_test.py similarity index 100% rename from snowflake/udf/policy_recommendation/preprocessing_udf_test.py rename to snowflake/udfs/udfs/policy_recommendation/preprocessing_udf_test.py diff --git a/snowflake/udf/policy_recommendation/static_policy_recommendation_udf.py b/snowflake/udfs/udfs/policy_recommendation/static_policy_recommendation_udf.py similarity index 100% rename from snowflake/udf/policy_recommendation/static_policy_recommendation_udf.py rename to snowflake/udfs/udfs/policy_recommendation/static_policy_recommendation_udf.py diff --git a/snowflake/udf/policy_recommendation/static_policy_recommendation_udf_test.py b/snowflake/udfs/udfs/policy_recommendation/static_policy_recommendation_udf_test.py similarity index 100% rename from snowflake/udf/policy_recommendation/static_policy_recommendation_udf_test.py rename to snowflake/udfs/udfs/policy_recommendation/static_policy_recommendation_udf_test.py diff --git a/snowflake/udf/policy_recommendation/version.txt b/snowflake/udfs/udfs/policy_recommendation/version.txt similarity index 100% rename from snowflake/udf/policy_recommendation/version.txt rename to snowflake/udfs/udfs/policy_recommendation/version.txt From 07568098d92afd52314c4deff1a2d0f6d8c3d4b7 Mon Sep 17 00:00:00 2001 From: Yongming Ding Date: Fri, 2 Dec 2022 13:22:46 -0800 Subject: [PATCH 4/5] Address 3rd round comments & add UTs Signed-off-by: Yongming Ding --- 
snowflake/Makefile | 7 +- snowflake/pkg/infra/manager.go | 14 ++-- snowflake/pkg/snowflake/snowflake.go | 24 +++---- snowflake/pkg/udfs/udfs.go | 2 +- .../pkg/utils/{utils.go => file/file.go} | 8 +-- .../file/file_test.go} | 14 ++-- snowflake/pkg/utils/timestamps/timestamps.go | 2 +- .../pkg/utils/timestamps/timestamps_test.go | 65 +++++++++++++++++++ 8 files changed, 105 insertions(+), 31 deletions(-) rename snowflake/pkg/utils/{utils.go => file/file.go} (90%) rename snowflake/pkg/{infra/manager_test.go => utils/file/file_test.go} (76%) create mode 100644 snowflake/pkg/utils/timestamps/timestamps_test.go diff --git a/snowflake/Makefile b/snowflake/Makefile index be027706b..2c8fea0fa 100644 --- a/snowflake/Makefile +++ b/snowflake/Makefile @@ -1,11 +1,14 @@ GO ?= go BINDIR := $(CURDIR)/bin -all: bin +all: udfs bin + +.PHONY: udfs +udfs: + make -C udfs/udfs/ .PHONY: bin bin: - make -C udfs/udfs/ $(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake .PHONY: .coverage diff --git a/snowflake/pkg/infra/manager.go b/snowflake/pkg/infra/manager.go index c08834f35..022056d9f 100644 --- a/snowflake/pkg/infra/manager.go +++ b/snowflake/pkg/infra/manager.go @@ -36,7 +36,7 @@ import ( "antrea.io/theia/snowflake/database" sf "antrea.io/theia/snowflake/pkg/snowflake" - utils "antrea.io/theia/snowflake/pkg/utils" + fileutils "antrea.io/theia/snowflake/pkg/utils/file" "antrea.io/theia/snowflake/udfs" ) @@ -85,7 +85,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil { return err } - if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil { + if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil { return err } @@ -118,7 +118,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch) } url := 
fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target) - if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil { + if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil { return err } @@ -282,7 +282,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { warehouseName := m.warehouseName if !destroy { logger.Info("Copying database migrations to disk") - if err := utils.WriteEmbedDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil { + if err := fileutils.WriteFSDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil { return nil, err } logger.Info("Copied database migrations to disk") @@ -445,7 +445,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } // Download and stage Kubernetes python client for policy recommendation udf - k8sPythonClientFilePath, err := utils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName) + k8sPythonClientFilePath, err := fileutils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName) if err != nil { return err } @@ -463,7 +463,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa logger.Info("Copying UDFs to disk") udfsDirPath := filepath.Join(workdir, udfsDir) - if err := utils.WriteEmbedDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil { + if err := fileutils.WriteFSDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil { return err } logger.Info("Copied UDFs to disk") @@ -521,7 +521,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa return fmt.Errorf("version placeholder '%s' not found in SQL file", udfVersionPlaceholder) } query = 
strings.ReplaceAll(query, udfVersionPlaceholder, version) - _, err = sfClient.ExecMultiStatementQuery(ctx, query, false) + err = sfClient.ExecMultiStatement(ctx, query) if err != nil { return fmt.Errorf("error when creating UDF: %w", err) } diff --git a/snowflake/pkg/snowflake/snowflake.go b/snowflake/pkg/snowflake/snowflake.go index bbc31c214..3af25b176 100644 --- a/snowflake/pkg/snowflake/snowflake.go +++ b/snowflake/pkg/snowflake/snowflake.go @@ -115,26 +115,28 @@ func (c *client) UseDatabase(ctx context.Context, name string) error { func (c *client) UseSchema(ctx context.Context, name string) error { query := fmt.Sprintf("USE SCHEMA %s", name) - c.logger.Info("Snowflake query", "query", query) + c.logger.V(2).Info("Snowflake query", "query", query) _, err := c.db.ExecContext(ctx, query) return err } func (c *client) StageFile(ctx context.Context, path string, stage string) error { query := fmt.Sprintf("PUT file://%s @%s AUTO_COMPRESS = FALSE OVERWRITE = TRUE", path, stage) - c.logger.Info("Snowflake query", "query", query) + c.logger.V(2).Info("Snowflake query", "query", query) _, err := c.db.ExecContext(ctx, query) return err } -func (c *client) ExecMultiStatementQuery(ctx context.Context, query string, result bool) (*sql.Rows, error) { +func (c *client) ExecMultiStatement(ctx context.Context, query string) error { multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0) - c.logger.Info("Snowflake query", "query", query) - if !result { - _, err := c.db.ExecContext(multi_statement_context, query) - return nil, err - } else { - rows, err := c.db.QueryContext(multi_statement_context, query) - return rows, err - } + c.logger.V(2).Info("Snowflake query", "query", query) + _, err := c.db.ExecContext(multi_statement_context, query) + return err +} + +func (c *client) QueryMultiStatement(ctx context.Context, query string) (*sql.Rows, error) { + multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0) + c.logger.V(2).Info("Snowflake query", 
"query", query) + rows, err := c.db.QueryContext(multi_statement_context, query) + return rows, err } diff --git a/snowflake/pkg/udfs/udfs.go b/snowflake/pkg/udfs/udfs.go index f1cf1506b..a51bc237a 100644 --- a/snowflake/pkg/udfs/udfs.go +++ b/snowflake/pkg/udfs/udfs.go @@ -72,7 +72,7 @@ func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName return nil, err } - rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true) + rows, err := sfClient.QueryMultiStatement(ctx, query) if err != nil { return nil, fmt.Errorf("error when running UDF: %w", err) } diff --git a/snowflake/pkg/utils/utils.go b/snowflake/pkg/utils/file/file.go similarity index 90% rename from snowflake/pkg/utils/utils.go rename to snowflake/pkg/utils/file/file.go index 8e406790c..e12261d47 100644 --- a/snowflake/pkg/utils/utils.go +++ b/snowflake/pkg/utils/file/file.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package file import ( "archive/tar" @@ -102,17 +102,17 @@ func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir s return nil } -func WriteEmbedDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, embedPath string, dest string) error { +func WriteFSDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, fsysPath string, dest string) error { if err := os.MkdirAll(dest, 0755); err != nil { return err } - return fs.WalkDir(fsys, embedPath, func(path string, d fs.DirEntry, err error) error { + return fs.WalkDir(fsys, fsysPath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } - outpath := filepath.Join(dest, strings.TrimPrefix(path, embedPath)) + outpath := filepath.Join(dest, strings.TrimPrefix(path, fsysPath)) if d.IsDir() { os.MkdirAll(outpath, 0755) diff --git a/snowflake/pkg/infra/manager_test.go b/snowflake/pkg/utils/file/file_test.go similarity index 76% rename from 
snowflake/pkg/infra/manager_test.go rename to snowflake/pkg/utils/file/file_test.go index 7c9c05662..b6ea170ea 100644 --- a/snowflake/pkg/infra/manager_test.go +++ b/snowflake/pkg/utils/file/file_test.go @@ -12,28 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -package infra +package file import ( + "context" "os" "path/filepath" "testing" - "antrea.io/theia/snowflake/database" + "github.com/go-logr/logr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "antrea.io/theia/snowflake/database" ) -func TestWriteMigrationsToDisk(t *testing.T) { +func TestWriteFSDirToDisk(t *testing.T) { + var logger logr.Logger tempDir, err := os.MkdirTemp("", "antrea-pulumi-test") require.NoError(t, err) defer os.RemoveAll(tempDir) - err = writeMigrationsToDisk(database.Migrations, database.MigrationsPath, filepath.Join(tempDir, migrationsDir)) + err = WriteFSDirToDisk(context.TODO(), logger, database.Migrations, database.MigrationsPath, filepath.Join(tempDir, database.MigrationsPath)) require.NoError(t, err) entries, err := database.Migrations.ReadDir(database.MigrationsPath) require.NoError(t, err) for _, entry := range entries { - _, err := os.Stat(filepath.Join(tempDir, migrationsDir, entry.Name())) + _, err := os.Stat(filepath.Join(tempDir, database.MigrationsPath, entry.Name())) assert.NoErrorf(t, err, "Migration file %s not exist", entry.Name()) } } diff --git a/snowflake/pkg/utils/timestamps/timestamps.go b/snowflake/pkg/utils/timestamps/timestamps.go index 6179bdffc..104cda207 100644 --- a/snowflake/pkg/utils/timestamps/timestamps.go +++ b/snowflake/pkg/utils/timestamps/timestamps.go @@ -42,7 +42,7 @@ func ParseTimestamp(t string, now time.Time) (string, error) { return now, fmt.Errorf("bad timestamp: %s", t) }() if err != nil { - return "", nil + return "", err } return ts.UTC().Format(time.RFC3339), nil } diff --git a/snowflake/pkg/utils/timestamps/timestamps_test.go 
b/snowflake/pkg/utils/timestamps/timestamps_test.go new file mode 100644 index 000000000..c7b0c7835 --- /dev/null +++ b/snowflake/pkg/utils/timestamps/timestamps_test.go @@ -0,0 +1,65 @@ +// Copyright 2022 Antrea Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timestamps + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseTimestamp(t *testing.T) { + now := time.Now() + nowTimestamp := now.UTC().Format(time.RFC3339) + for _, tc := range []struct { + name string + inputTimestamp string + expectedTimestamp string + expectedError error + }{ + { + name: "Successful case 1", + inputTimestamp: "now-1h", + expectedTimestamp: now.Add(-time.Hour).UTC().Format(time.RFC3339), + expectedError: nil, + }, + { + name: "Successful case 2", + inputTimestamp: "now", + expectedTimestamp: nowTimestamp, + expectedError: nil, + }, + { + name: "Successful case 3", + inputTimestamp: "", + expectedTimestamp: nowTimestamp, + expectedError: nil, + }, + { + name: "Failed case", + inputTimestamp: "now-1c", + expectedTimestamp: "", + expectedError: fmt.Errorf("bad timestamp: now-1c"), + }, + } { + t.Run(tc.name, func(t *testing.T) { + timestamp, err := ParseTimestamp(tc.inputTimestamp, now) + assert.Equal(t, tc.expectedTimestamp, timestamp) + assert.Equal(t, tc.expectedError, err) + }) + } +} From fe1ac86d82ce1b3a40328bdef18729f6ced1f0a7 Mon Sep 17 00:00:00 2001 From: Yongming Ding Date: Tue, 6 Dec 2022 13:07:32 -0800 
Subject: [PATCH 5/5] Address comments from yanjun Signed-off-by: Yongming Ding --- snowflake/cmd/policyRecommendation.go | 26 ++++++++++++++++++-------- snowflake/pkg/infra/manager.go | 10 +++++----- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/snowflake/cmd/policyRecommendation.go b/snowflake/cmd/policyRecommendation.go index 4d7722389..b66612626 100644 --- a/snowflake/cmd/policyRecommendation.go +++ b/snowflake/cmd/policyRecommendation.go @@ -111,7 +111,7 @@ AND } if endTime != "" { fmt.Fprintf(&queryBuilder, `AND - flowEndSeconds >= '%s' + flowEndSeconds < '%s' `, endTime) } @@ -221,10 +221,20 @@ You can also bring your own by using the "--warehouse-name" parameter. return fmt.Errorf("invalid --type argument") } limit, _ := cmd.Flags().GetUint("limit") - isolationMethod, _ := cmd.Flags().GetInt("isolationMethod") - if isolationMethod < 1 && isolationMethod > 3 { - return fmt.Errorf("invalid -isolationMethod argument") + + policyType, _ := cmd.Flags().GetString("policy-type") + var isolationMethod int + if policyType == "anp-deny-applied" { + isolationMethod = 1 + } else if policyType == "anp-deny-all" { + isolationMethod = 2 + } else if policyType == "k8s-np" { + isolationMethod = 3 + } else { + return fmt.Errorf(`type of generated NetworkPolicy should be +anp-deny-applied or anp-deny-all or k8s-np`) } + start, _ := cmd.Flags().GetString("start") end, _ := cmd.Flags().GetString("end") startTs, _ := cmd.Flags().GetString("start-ts") @@ -272,10 +282,10 @@ func init() { policyRecommendationCmd.Flags().String("type", "initial", "Type of recommendation job (initial|subsequent), we only support initial jobType for now") policyRecommendationCmd.Flags().Uint("limit", 0, "Limit on the number of flows to read, default it 0 (no limit)") - policyRecommendationCmd.Flags().Int("isolationMethod", 1, `Network isolation preference. 
Currently we have 3 options: -1: Recommending allow ANP/ACNP policies, with default deny rules only on Pods which have an allow rule applied -2: Recommending allow ANP/ACNP policies, with default deny rules for whole cluster -3: Recommending allow K8s NetworkPolicies only`) + policyRecommendationCmd.Flags().String("policy-type", "anp-deny-applied", `Types of recommended NetworkPolicy. Currently we have 3 options: +anp-deny-applied: Recommending allow ANP/ACNP policies, with default deny rules only on Pods which have an allow rule applied +anp-deny-all: Recommending allow ANP/ACNP policies, with default deny rules for whole cluster +k8s-np: Recommending allow K8s NetworkPolicies only`) policyRecommendationCmd.Flags().String("start", "", "Start time for flows, with reference to the current time (e.g., now-1h)") policyRecommendationCmd.Flags().String("end", "", "End time for flows, with reference to the current timr (e.g., now)") policyRecommendationCmd.Flags().String("start-ts", "", "Start time for flows, as a RFC3339 UTC timestamp (e.g., 2022-07-01T19:35:31Z)") diff --git a/snowflake/pkg/infra/manager.go b/snowflake/pkg/infra/manager.go index 022056d9f..e02a15967 100644 --- a/snowflake/pkg/infra/manager.go +++ b/snowflake/pkg/infra/manager.go @@ -418,7 +418,7 @@ func (m *Manager) Offboard(ctx context.Context) error { } func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, warehouseName string, workdir string) error { - logger.Info("creating UDFs") + logger.Info("Creating UDFs") dsn, _, err := sf.GetDSN() if err != nil { return fmt.Errorf("failed to create DSN: %w", err) @@ -475,12 +475,12 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa if filepath.Ext(path) != ".zip" { return nil } - logger.Info("staging", "path", path) + logger.Info("Staging", "path", path) directoryPath := path[:len(path)-4] functionVersionPath := filepath.Join(directoryPath, "version.txt") var version string if _, err := 
os.Stat(functionVersionPath); errors.Is(err, os.ErrNotExist) { - logger.Info("did not find version.txt file for function", "functionVersionPath", functionVersionPath) + logger.Info("Did not find version.txt file for function", "functionVersionPath", functionVersionPath) version = "" } else { version, err = readVersionFromFile(functionVersionPath) @@ -508,10 +508,10 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } createFunctionSQLPath := filepath.Join(directoryPath, udfCreateFunctionSQLFilename) if _, err := os.Stat(createFunctionSQLPath); errors.Is(err, os.ErrNotExist) { - logger.Info("did not find SQL file to create function, skipping", "createFunctionSQLPath", createFunctionSQLPath) + logger.Info("Did not find SQL file to create function, skipping", "createFunctionSQLPath", createFunctionSQLPath) return nil } - logger.Info("creating UDF", "from", createFunctionSQLPath, "version", version) + logger.Info("Creating UDF", "from", createFunctionSQLPath, "version", version) b, err := os.ReadFile(createFunctionSQLPath) if err != nil { return err