Skip to content

Commit

Permalink
Address 3rd round comments
Browse files Browse the repository at this point in the history
Signed-off-by: Yongming Ding <dyongming@vmware.com>
  • Loading branch information
Yongming Ding committed Dec 2, 2022
1 parent 10e73e5 commit adc238c
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 23 deletions.
7 changes: 5 additions & 2 deletions snowflake/Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
GO ?= go
BINDIR := $(CURDIR)/bin

all: bin
all: udfs bin

.PHONY: udfs
udfs:
make -C udfs/udfs/

.PHONY: bin
bin:
make -C udfs/udfs/
$(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake

.PHONY: test
Expand Down
14 changes: 7 additions & 7 deletions snowflake/pkg/infra/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (

"antrea.io/theia/snowflake/database"
sf "antrea.io/theia/snowflake/pkg/snowflake"
utils "antrea.io/theia/snowflake/pkg/utils"
fileutils "antrea.io/theia/snowflake/pkg/utils/file"
"antrea.io/theia/snowflake/udfs"
)

Expand Down Expand Up @@ -85,7 +85,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error
if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil {
return err
}
if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
return err
}

Expand Down Expand Up @@ -118,7 +118,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str
return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch)
}
url := fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target)
if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
return err
}

Expand Down Expand Up @@ -282,7 +282,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) {
warehouseName := m.warehouseName
if !destroy {
logger.Info("Copying database migrations to disk")
if err := utils.WriteEmbedDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil {
if err := fileutils.WriteFSDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil {
return nil, err
}
logger.Info("Copied database migrations to disk")
Expand Down Expand Up @@ -445,7 +445,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
}

// Download and stage Kubernetes python client for policy recommendation udf
k8sPythonClientFilePath, err := utils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName)
k8sPythonClientFilePath, err := fileutils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName)
if err != nil {
return err
}
Expand All @@ -463,7 +463,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa

logger.Info("Copying UDFs to disk")
udfsDirPath := filepath.Join(workdir, udfsDir)
if err := utils.WriteEmbedDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil {
if err := fileutils.WriteFSDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil {
return err
}
logger.Info("Copied UDFs to disk")
Expand Down Expand Up @@ -521,7 +521,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
return fmt.Errorf("version placeholder '%s' not found in SQL file", udfVersionPlaceholder)
}
query = strings.ReplaceAll(query, udfVersionPlaceholder, version)
_, err = sfClient.ExecMultiStatementQuery(ctx, query, false)
err = sfClient.ExecMultiStatement(ctx, query)
if err != nil {
return fmt.Errorf("error when creating UDF: %w", err)
}
Expand Down
20 changes: 11 additions & 9 deletions snowflake/pkg/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (c *client) UseDatabase(ctx context.Context, name string) error {

func (c *client) UseSchema(ctx context.Context, name string) error {
query := fmt.Sprintf("USE SCHEMA %s", name)
c.logger.Info("Snowflake query", "query", query)
c.logger.V(2).Info("Snowflake query", "query", query)
_, err := c.db.ExecContext(ctx, query)
return err
}
Expand All @@ -127,14 +127,16 @@ func (c *client) StageFile(ctx context.Context, path string, stage string) error
return err
}

func (c *client) ExecMultiStatementQuery(ctx context.Context, query string, result bool) (*sql.Rows, error) {
func (c *client) ExecMultiStatement(ctx context.Context, query string) error {
multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0)
c.logger.Info("Snowflake query", "query", query)
if !result {
_, err := c.db.ExecContext(multi_statement_context, query)
return nil, err
} else {
rows, err := c.db.QueryContext(multi_statement_context, query)
return rows, err
}
_, err := c.db.ExecContext(multi_statement_context, query)
return err
}

func (c *client) QueryMultiStatement(ctx context.Context, query string) (*sql.Rows, error) {
multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0)
c.logger.Info("Snowflake query", "query", query)
rows, err := c.db.QueryContext(multi_statement_context, query)
return rows, err
}
2 changes: 1 addition & 1 deletion snowflake/pkg/udfs/udfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName
return nil, err
}

rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true)
rows, err := sfClient.QueryMultiStatement(ctx, query)
if err != nil {
return nil, fmt.Errorf("error when running UDF: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package utils
package file

import (
"archive/tar"
Expand Down Expand Up @@ -102,17 +102,17 @@ func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir s
return nil
}

func WriteEmbedDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, embedPath string, dest string) error {
func WriteFSDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, fsysPath string, dest string) error {
if err := os.MkdirAll(dest, 0755); err != nil {
return err
}

return fs.WalkDir(fsys, embedPath, func(path string, d fs.DirEntry, err error) error {
return fs.WalkDir(fsys, fsysPath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}

outpath := filepath.Join(dest, strings.TrimPrefix(path, embedPath))
outpath := filepath.Join(dest, strings.TrimPrefix(path, fsysPath))

if d.IsDir() {
os.MkdirAll(outpath, 0755)
Expand Down

0 comments on commit adc238c

Please sign in to comment.