Skip to content

Commit

Permalink
Address 2nd round comments
Browse files Browse the repository at this point in the history
Signed-off-by: Yongming Ding <dyongming@vmware.com>
  • Loading branch information
Yongming Ding committed Dec 2, 2022
1 parent 2d138c4 commit 10e73e5
Show file tree
Hide file tree
Showing 22 changed files with 153 additions and 116 deletions.
2 changes: 1 addition & 1 deletion snowflake/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ all: bin

.PHONY: bin
bin:
make -C udf/
make -C udfs/udfs/
$(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake

.PHONY: test
Expand Down
10 changes: 5 additions & 5 deletions snowflake/cmd/policyRecommendation.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"github.com/google/uuid"
"github.com/spf13/cobra"

"antrea.io/theia/snowflake/pkg/infra"
"antrea.io/theia/snowflake/pkg/udfs"
"antrea.io/theia/snowflake/pkg/utils/timestamps"
)

Expand Down Expand Up @@ -53,7 +53,7 @@ func buildPolicyRecommendationUdfQuery(
) (string, error) {
now := time.Now()
recommendationID := uuid.New().String()
functionName := infra.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion)
functionName := udfs.GetFunctionName(staticPolicyRecommendationFunctionName, functionVersion)
var queryBuilder strings.Builder
fmt.Fprintf(&queryBuilder, `SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM
TABLE(%s(
Expand Down Expand Up @@ -150,7 +150,7 @@ LIMIT 500000`)

// Choose the destinationIP as the partition field for the preprocessing
// UDTF because flow rows could be divided into the most subsets
functionName = infra.GetFunctionName(preprocessingFunctionName, functionVersion)
functionName = udfs.GetFunctionName(preprocessingFunctionName, functionVersion)
fmt.Fprintf(&queryBuilder, `), processed_flows AS (SELECT r.appliedTo, r.ingress, r.egress FROM filtered_flows AS f,
TABLE(%s(
'%s',
Expand Down Expand Up @@ -183,7 +183,7 @@ FROM processed_flows as pf
// Choose the appliedTo as the partition field for the policyRecommendation
// UDTF because each network policy is recommended based on all ingress and
// egress traffic related to an appliedTo group.
functionName = infra.GetFunctionName(policyRecommendationFunctionName, functionVersion)
functionName = udfs.GetFunctionName(policyRecommendationFunctionName, functionVersion)
fmt.Fprintf(&queryBuilder, `) SELECT r.jobType, r.recommendationId, r.timeCreated, r.yamls FROM pf_with_index,
TABLE(%s(
'%s',
Expand Down Expand Up @@ -246,7 +246,7 @@ You can also bring your own by using the "--warehouse-name" parameter.
}
ctx, cancel := context.WithTimeout(context.Background(), waitDuration)
defer cancel()
rows, err := infra.RunUdf(ctx, logger, query, databaseName, warehouseName)
rows, err := udfs.RunUdf(ctx, logger, query, databaseName, warehouseName)
if err != nil {
return fmt.Errorf("error when running policy recommendation UDF: %w", err)
}
Expand Down
9 changes: 0 additions & 9 deletions snowflake/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,9 @@
package main

import (
"embed"

"antrea.io/theia/snowflake/cmd"
"antrea.io/theia/snowflake/pkg/infra"
)

// Embed the udf directory here because go:embed doesn't support embedding in subpackages

//go:embed udf/*
var udfFs embed.FS

// main is the entry point of the theia-sf binary: it publishes the UDF files
// embedded in this package to the infra package and then runs the CLI.
func main() {
// infra walks UdfFs when it creates the UDFs, so hand over the embedded
// filesystem before any command is executed.
infra.UdfFs = udfFs
cmd.Execute()
}
3 changes: 2 additions & 1 deletion snowflake/pkg/infra/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const (

databaseNamePrefix = "ANTREA_"

schemaName = "THEIA"
SchemaName = "THEIA"
flowRetentionDays = 30
flowDeletionTaskName = "DELETE_STALE_FLOWS"
udfStageName = "UDFS"
Expand All @@ -54,6 +54,7 @@ const (
flowsTableName = "FLOWS"

migrationsDir = "migrations"
udfsDir = "udfs"

udfVersionPlaceholder = "%VERSION%"
udfCreateFunctionSQLFilename = "create_function.sql"
Expand Down
78 changes: 23 additions & 55 deletions snowflake/pkg/infra/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package infra
import (
"context"
"database/sql"
"embed"
"errors"
"fmt"
"io"
Expand All @@ -38,10 +37,9 @@ import (
"antrea.io/theia/snowflake/database"
sf "antrea.io/theia/snowflake/pkg/snowflake"
utils "antrea.io/theia/snowflake/pkg/utils"
"antrea.io/theia/snowflake/udfs"
)

var UdfFs embed.FS

type pulumiPlugin struct {
name string
version string
Expand All @@ -55,43 +53,6 @@ func deleteTemporaryWorkdir(d string) {
os.RemoveAll(d)
}

// writeMigrationsToDisk copies every non-directory entry found directly under
// migrationsPath in fsys into the dest directory on disk. dest (and any missing
// parent) is created first; sub-directories of migrationsPath are skipped, not
// recursed into. The first error encountered is returned and aborts the copy.
//
// Names inside an fs.FS are slash-separated on every operating system, so the
// entries are opened relative to fs.Sub(fsys, migrationsPath) instead of
// building the name with filepath.Join, which would insert backslashes on
// Windows and fail to resolve in an embed.FS.
func writeMigrationsToDisk(fsys fs.FS, migrationsPath string, dest string) error {
	if err := os.MkdirAll(dest, 0755); err != nil {
		return err
	}
	entries, err := fs.ReadDir(fsys, migrationsPath)
	if err != nil {
		return err
	}
	migrations, err := fs.Sub(fsys, migrationsPath)
	if err != nil {
		return err
	}
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		if err := copyMigrationToDisk(migrations, e.Name(), dest); err != nil {
			return err
		}
	}
	return nil
}

// copyMigrationToDisk streams the file called name from fsys into a file with
// the same name inside the dest directory, creating or truncating it.
func copyMigrationToDisk(fsys fs.FS, name string, dest string) error {
	in, err := fsys.Open(name)
	if err != nil {
		return err
	}
	defer in.Close()

	// dest is a real on-disk location, so the OS-specific join is correct here.
	out, err := os.Create(filepath.Join(dest, name))
	if err != nil {
		return err
	}
	if _, err := io.Copy(out, in); err != nil {
		// The copy error is the one worth reporting; closing is best-effort.
		_ = out.Close()
		return err
	}
	// Report the Close error: it is where a failed flush would surface.
	return out.Close()
}

func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error {
logger.Info("Downloading and installing Pulumi", "version", pulumiVersion)
cachedVersion, err := os.ReadFile(filepath.Join(dir, ".pulumi-version"))
Expand Down Expand Up @@ -124,7 +85,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error
if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil {
return err
}
if err := utils.DownloadAndUntar(ctx, logger, url, dir, "", true); err != nil {
if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
return err
}

Expand Down Expand Up @@ -157,7 +118,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str
return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch)
}
url := fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target)
if err := utils.DownloadAndUntar(ctx, logger, url, dir, "", true); err != nil {
if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
return err
}

Expand Down Expand Up @@ -321,7 +282,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) {
warehouseName := m.warehouseName
if !destroy {
logger.Info("Copying database migrations to disk")
if err := writeMigrationsToDisk(database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil {
if err := utils.WriteEmbedDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil {
return nil, err
}
logger.Info("Copied database migrations to disk")
Expand All @@ -337,7 +298,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) {
return nil, fmt.Errorf("failed to connect to Snowflake: %w", err)
}
defer db.Close()
temporaryWarehouse := newTemporaryWarehouse(sf.NewClient(db, logger), logger)
temporaryWarehouse := NewTemporaryWarehouse(sf.NewClient(db, logger), logger)
warehouseName = temporaryWarehouse.Name()
if err := temporaryWarehouse.Create(ctx); err != nil {
return nil, err
Expand Down Expand Up @@ -430,7 +391,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) {
return nil, err
}

err = createUdfs(ctx, logger, outs["databaseName"], warehouseName)
err = createUdfs(ctx, logger, outs["databaseName"], warehouseName, workdir)
if err != nil {
return nil, err
}
Expand All @@ -440,7 +401,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) {
BucketName: outs["bucketID"],
BucketFlowsFolder: s3BucketFlowsFolder,
DatabaseName: outs["databaseName"],
SchemaName: schemaName,
SchemaName: SchemaName,
FlowsTableName: flowsTableName,
SNSTopicARN: outs["snsTopicARN"],
SQSQueueARN: outs["sqsQueueARN"],
Expand All @@ -456,7 +417,7 @@ func (m *Manager) Offboard(ctx context.Context) error {
return err
}

func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, warehouseName string) error {
func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, warehouseName string, workdir string) error {
logger.Info("creating UDFs")
dsn, _, err := sf.GetDSN()
if err != nil {
Expand All @@ -475,7 +436,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
return err
}

if err := sfClient.UseSchema(ctx, schemaName); err != nil {
if err := sfClient.UseSchema(ctx, SchemaName); err != nil {
return err
}

Expand All @@ -484,11 +445,11 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
}

// Download and stage Kubernetes python client for policy recommendation udf
err = utils.DownloadAndUntar(ctx, logger, k8sPythonClientUrl, ".", k8sPythonClientFileName, false)
k8sPythonClientFilePath, err := utils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName)
if err != nil {
return err
}
k8sPythonClientFilePath, _ := filepath.Abs(k8sPythonClientFileName)
k8sPythonClientFilePath, _ = filepath.Abs(k8sPythonClientFilePath)
err = sfClient.StageFile(ctx, k8sPythonClientFilePath, udfStageName)
if err != nil {
return err
Expand All @@ -500,7 +461,14 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
}
}()

if err := fs.WalkDir(UdfFs, ".", func(path string, d fs.DirEntry, err error) error {
logger.Info("Copying UDFs to disk")
udfsDirPath := filepath.Join(workdir, udfsDir)
if err := utils.WriteEmbedDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil {
return err
}
logger.Info("Copied UDFs to disk")

if err := filepath.WalkDir(udfsDirPath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
Expand All @@ -512,7 +480,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
functionVersionPath := filepath.Join(directoryPath, "version.txt")
var version string
if _, err := os.Stat(functionVersionPath); errors.Is(err, os.ErrNotExist) {
logger.Info("did not find version.txt file for function")
logger.Info("did not find version.txt file for function", "functionVersionPath", functionVersionPath)
version = ""
} else {
version, err = readVersionFromFile(functionVersionPath)
Expand All @@ -539,12 +507,12 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
return err
}
createFunctionSQLPath := filepath.Join(directoryPath, udfCreateFunctionSQLFilename)
if _, err := fs.Stat(UdfFs, createFunctionSQLPath); errors.Is(err, os.ErrNotExist) {
logger.Info("did not find SQL file to create function, skipping")
if _, err := os.Stat(createFunctionSQLPath); errors.Is(err, os.ErrNotExist) {
logger.Info("did not find SQL file to create function, skipping", "createFunctionSQLPath", createFunctionSQLPath)
return nil
}
logger.Info("creating UDF", "from", createFunctionSQLPath, "version", version)
b, err := fs.ReadFile(UdfFs, createFunctionSQLPath)
b, err := os.ReadFile(createFunctionSQLPath)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions snowflake/pkg/infra/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func declareSnowflakeDatabase(

schema, err := snowflake.NewSchema(ctx, "antrea-sf-schema", &snowflake.SchemaArgs{
Database: db.ID(),
Name: pulumi.String(schemaName),
Name: pulumi.String(SchemaName),
}, pulumi.Parent(db), pulumi.DeleteBeforeReplace(true))
if err != nil {
return nil, err
Expand Down Expand Up @@ -293,7 +293,7 @@ func declareSnowflakeDatabase(
ErrorIntegration: notificationIntegration.ID(),
// FQN required for table and stage, see https://github.com/pulumi/pulumi-snowflake/issues/129
// 0x27 is the hex representation of single quote. We use it to enclose Pod labels string.
CopyStatement: pulumi.Sprintf("COPY INTO %s.%s.%s FROM @%s.%s.%s FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY='0x27')", databaseName, schemaName, flowsTableName, databaseName, schemaName, ingestionStageName),
CopyStatement: pulumi.Sprintf("COPY INTO %s.%s.%s FROM @%s.%s.%s FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY='0x27')", databaseName, SchemaName, flowsTableName, databaseName, SchemaName, ingestionStageName),
}, pulumi.Parent(schema), pulumi.DependsOn([]pulumi.Resource{ingestionStage, dbMigrations}), pulumi.DeleteBeforeReplace(true))
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions snowflake/pkg/infra/temporary_warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"fmt"
"strings"

"github.com/dustinkirkland/golang-petname"
petname "github.com/dustinkirkland/golang-petname"
"github.com/go-logr/logr"

sf "antrea.io/theia/snowflake/pkg/snowflake"
Expand All @@ -31,7 +31,7 @@ type temporaryWarehouse struct {
warehouseName string
}

func newTemporaryWarehouse(sfClient sf.Client, logger logr.Logger) *temporaryWarehouse {
func NewTemporaryWarehouse(sfClient sf.Client, logger logr.Logger) *temporaryWarehouse {
return &temporaryWarehouse{
sfClient: sfClient,
logger: logger,
Expand Down
7 changes: 4 additions & 3 deletions snowflake/pkg/infra/udfs.go → snowflake/pkg/udfs/udfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package infra
package udfs

import (
"context"
Expand All @@ -22,6 +22,7 @@ import (

"github.com/go-logr/logr"

"antrea.io/theia/snowflake/pkg/infra"
sf "antrea.io/theia/snowflake/pkg/snowflake"
)

Expand Down Expand Up @@ -50,12 +51,12 @@ func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName
return nil, err
}

if err := sfClient.UseSchema(ctx, schemaName); err != nil {
if err := sfClient.UseSchema(ctx, infra.SchemaName); err != nil {
return nil, err
}

if warehouseName == "" {
temporaryWarehouse := newTemporaryWarehouse(sfClient, logger)
temporaryWarehouse := infra.NewTemporaryWarehouse(sfClient, logger)
warehouseName = temporaryWarehouse.Name()
if err := temporaryWarehouse.Create(ctx); err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 10e73e5

Please sign in to comment.