From adc238cf9ced111d045ef1634ab8ebf50fb6dc7f Mon Sep 17 00:00:00 2001 From: Yongming Ding Date: Fri, 2 Dec 2022 13:22:46 -0800 Subject: [PATCH] Address 3rd round comments Signed-off-by: Yongming Ding --- snowflake/Makefile | 7 +++++-- snowflake/pkg/infra/manager.go | 14 ++++++------- snowflake/pkg/snowflake/snowflake.go | 20 ++++++++++--------- snowflake/pkg/udfs/udfs.go | 2 +- .../pkg/utils/{utils.go => file/file.go} | 8 ++++---- 5 files changed, 28 insertions(+), 23 deletions(-) rename snowflake/pkg/utils/{utils.go => file/file.go} (90%) diff --git a/snowflake/Makefile b/snowflake/Makefile index 3bba35f87..899725e68 100644 --- a/snowflake/Makefile +++ b/snowflake/Makefile @@ -1,11 +1,14 @@ GO ?= go BINDIR := $(CURDIR)/bin -all: bin +all: udfs bin + +.PHONY: udfs +udfs: + make -C udfs/udfs/ .PHONY: bin bin: - make -C udfs/udfs/ $(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake .PHONY: test diff --git a/snowflake/pkg/infra/manager.go b/snowflake/pkg/infra/manager.go index c08834f35..022056d9f 100644 --- a/snowflake/pkg/infra/manager.go +++ b/snowflake/pkg/infra/manager.go @@ -36,7 +36,7 @@ import ( "antrea.io/theia/snowflake/database" sf "antrea.io/theia/snowflake/pkg/snowflake" - utils "antrea.io/theia/snowflake/pkg/utils" + fileutils "antrea.io/theia/snowflake/pkg/utils/file" "antrea.io/theia/snowflake/udfs" ) @@ -85,7 +85,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil { return err } - if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil { + if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil { return err } @@ -118,7 +118,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch) } url := fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target) - if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil { + if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil { return err } @@ -282,7 +282,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) { warehouseName := m.warehouseName if !destroy { logger.Info("Copying database migrations to disk") - if err := utils.WriteEmbedDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil { + if err := fileutils.WriteFSDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil { return nil, err } logger.Info("Copied database migrations to disk") @@ -445,7 +445,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa } // Download and stage Kubernetes python client for policy recommendation udf - k8sPythonClientFilePath, err := utils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName) + k8sPythonClientFilePath, err := fileutils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName) if err != nil { return err } @@ -463,7 +463,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa logger.Info("Copying UDFs to disk") udfsDirPath := filepath.Join(workdir, udfsDir) - if err := utils.WriteEmbedDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil { + if err := fileutils.WriteFSDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil { return err } logger.Info("Copied UDFs to disk") @@ -521,7 +521,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa return fmt.Errorf("version placeholder '%s' not found in SQL file", udfVersionPlaceholder) } query = strings.ReplaceAll(query, udfVersionPlaceholder, version) - _, err = sfClient.ExecMultiStatementQuery(ctx, query, false) + err = sfClient.ExecMultiStatement(ctx, query) if err != nil { return fmt.Errorf("error when creating UDF: %w", err) } diff --git a/snowflake/pkg/snowflake/snowflake.go b/snowflake/pkg/snowflake/snowflake.go index bbc31c214..283ef34f4 100644 --- a/snowflake/pkg/snowflake/snowflake.go +++ b/snowflake/pkg/snowflake/snowflake.go @@ -115,7 +115,7 @@ func (c *client) UseDatabase(ctx context.Context, name string) error { func (c *client) UseSchema(ctx context.Context, name string) error { query := fmt.Sprintf("USE SCHEMA %s", name) - c.logger.Info("Snowflake query", "query", query) + c.logger.V(2).Info("Snowflake query", "query", query) _, err := c.db.ExecContext(ctx, query) return err } @@ -127,14 +127,16 @@ func (c *client) StageFile(ctx context.Context, path string, stage string) error return err } -func (c *client) ExecMultiStatementQuery(ctx context.Context, query string, result bool) (*sql.Rows, error) { +func (c *client) ExecMultiStatement(ctx context.Context, query string) error { multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0) c.logger.Info("Snowflake query", "query", query) - if !result { - _, err := c.db.ExecContext(multi_statement_context, query) - return nil, err - } else { - rows, err := c.db.QueryContext(multi_statement_context, query) - return rows, err - } + _, err := c.db.ExecContext(multi_statement_context, query) + return err +} + +func (c *client) QueryMultiStatement(ctx context.Context, query string) (*sql.Rows, error) { + multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0) + c.logger.Info("Snowflake query", "query", query) + rows, err := c.db.QueryContext(multi_statement_context, query) + return rows, err } diff --git a/snowflake/pkg/udfs/udfs.go b/snowflake/pkg/udfs/udfs.go index f1cf1506b..a51bc237a 100644 --- a/snowflake/pkg/udfs/udfs.go +++ b/snowflake/pkg/udfs/udfs.go @@ -72,7 +72,7 @@ func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName return nil, err } - rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true) + rows, err := sfClient.QueryMultiStatement(ctx, query) if err != nil { return nil, fmt.Errorf("error when running UDF: %w", err) } diff --git a/snowflake/pkg/utils/utils.go b/snowflake/pkg/utils/file/file.go similarity index 90% rename from snowflake/pkg/utils/utils.go rename to snowflake/pkg/utils/file/file.go index 8e406790c..e12261d47 100644 --- a/snowflake/pkg/utils/utils.go +++ b/snowflake/pkg/utils/file/file.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package utils +package file import ( "archive/tar" @@ -102,17 +102,17 @@ func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir s return nil } -func WriteEmbedDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, embedPath string, dest string) error { +func WriteFSDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, fsysPath string, dest string) error { if err := os.MkdirAll(dest, 0755); err != nil { return err } - return fs.WalkDir(fsys, embedPath, func(path string, d fs.DirEntry, err error) error { + return fs.WalkDir(fsys, fsysPath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } - outpath := filepath.Join(dest, strings.TrimPrefix(path, embedPath)) + outpath := filepath.Join(dest, strings.TrimPrefix(path, fsysPath)) if d.IsDir() { os.MkdirAll(outpath, 0755)