From 8f0b7935eaefbcfc2e38fff95a10c10e79a833b0 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 14 Oct 2024 16:25:54 -0400 Subject: [PATCH 01/14] feat(go/adbc/driver/snowflake): improve GetObjects performance and semantics --- .../driver/internal/driverbase/connection.go | 42 ++- go/adbc/driver/snowflake/connection.go | 306 ++++++++++++------ .../queries/get_objects_catalogs.sql | 25 -- .../queries/get_objects_dbschemas.sql | 32 +- .../snowflake/queries/get_objects_tables.sql | 94 +++--- 5 files changed, 305 insertions(+), 194 deletions(-) delete mode 100644 go/adbc/driver/snowflake/queries/get_objects_catalogs.sql diff --git a/go/adbc/driver/internal/driverbase/connection.go b/go/adbc/driver/internal/driverbase/connection.go index 6e78816351..37433e0ced 100644 --- a/go/adbc/driver/internal/driverbase/connection.go +++ b/go/adbc/driver/internal/driverbase/connection.go @@ -349,14 +349,17 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth, bufferSize := len(catalogs) addCatalogCh := make(chan GetObjectsInfo, bufferSize) - for _, cat := range catalogs { - addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)} - } - - close(addCatalogCh) + errCh := make(chan error, 1) + go func() { + defer close(addCatalogCh) + for _, cat := range catalogs { + addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)} + } + }() if depth == adbc.ObjectDepthCatalogs { - return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh) + close(errCh) + return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh, errCh) } g, ctxG := errgroup.WithContext(ctx) @@ -386,7 +389,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth, g.Go(func() error { defer close(addDbSchemasCh); return gSchemas.Wait() }) if depth == adbc.ObjectDepthDBSchemas { - rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh) + rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh, errCh) return rdr, errors.Join(err, g.Wait()) } @@ -432,7 +435,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth, g.Go(func() error { defer close(addTablesCh); return gTables.Wait() }) - rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh) + rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh, errCh) return rdr, errors.Join(err, g.Wait()) } @@ -659,17 +662,26 @@ func (g *GetObjectsInfo) Scan(src any) error { // BuildGetObjectsRecordReader constructs a RecordReader for the GetObjects ADBC method. // It accepts a channel of GetObjectsInfo to allow concurrent retrieval of metadata and // serialization to Arrow record. -func BuildGetObjectsRecordReader(mem memory.Allocator, in chan GetObjectsInfo) (array.RecordReader, error) { +func BuildGetObjectsRecordReader(mem memory.Allocator, in <-chan GetObjectsInfo, errCh <-chan error) (array.RecordReader, error) { bldr := array.NewRecordBuilder(mem, adbc.GetObjectsSchema) defer bldr.Release() - for catalog := range in { - b, err := json.Marshal(catalog) - if err != nil { - return nil, err - } +CATALOGLOOP: + for { + select { + case catalog, ok := <-in: + if !ok { + break CATALOGLOOP + } + b, err := json.Marshal(catalog) + if err != nil { + return nil, err + } - if err := json.Unmarshal(b, bldr); err != nil { + if err := json.Unmarshal(b, bldr); err != nil { + return nil, err + } + case err := <-errCh: return nil, err } } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index a8361a3653..77d3c2cd5f 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -24,6 +24,7 @@ import ( "embed" "fmt" "io" + "io/fs" "path" "strconv" "strings" @@ -42,7 +43,6 @@ const ( defaultPrefetchConcurrency = 10 queryTemplateGetObjectsAll = "get_objects_all.sql" - queryTemplateGetObjectsCatalogs = "get_objects_catalogs.sql" queryTemplateGetObjectsDbSchemas = "get_objects_dbschemas.sql" queryTemplateGetObjectsTables = "get_objects_tables.sql" queryTemplateGetObjectsTerseCatalogs = "get_objects_terse_catalogs.sql" @@ -73,9 +73,51 @@ type connectionImpl struct { useHighPrecision bool } +func escapeSingleQuoteForLike(arg string) string { + if len(arg) == 0 { + return arg + } + + idx := strings.IndexByte(arg, '\'') + if idx == -1 { + return arg + } + + var b strings.Builder + b.Grow(len(arg)) + + for { + before, after, found := strings.Cut(arg, `'`) + b.WriteString(before) + if !found { + return b.String() + } + + if before[len(before)-1] != '\\' { + b.WriteByte('\\') + } + b.WriteByte('\'') + arg = after + } +} + +func getQueryID(ctx context.Context, query string, driverConn any) (string, error) { + rows, err := driverConn.(driver.QueryerContext).QueryContext(ctx, query, nil) + if err != nil { + return "", err + } + + return rows.(gosnowflake.SnowflakeRows).GetQueryID(), rows.Close() +} + +func isWildcardStr(ident string) bool { + return strings.ContainsAny(ident, "_%") +} + func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { var ( pkQueryID, fkQueryID, uniqueQueryID, terseDbQueryID string + showSchemaQueryID, tableQueryID string ) conn, err := c.sqldb.Conn(ctx) @@ -85,82 +127,165 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, defer conn.Close() gQueryIDs, gQueryIDsCtx := errgroup.WithContext(ctx) + queryFile := queryTemplateGetObjectsAll switch depth { case adbc.ObjectDepthCatalogs: - if catalog == nil { - queryFile = queryTemplateGetObjectsTerseCatalogs - // if the catalog is null, show the terse databases - // which doesn't require a database context - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) error { - rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW TERSE DATABASES", nil) - if err != nil { - return err - } + queryFile = queryTemplateGetObjectsTerseCatalogs + gQueryIDs.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + query := "SHOW TERSE /* ADBC:getObjectsCatalogs */ DATABASES" + if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" + } + query += " IN ACCOUNT" - terseDbQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID() - return rows.Close() - }) + terseDbQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + return }) - } else { - queryFile = queryTemplateGetObjectsCatalogs - } + }) case adbc.ObjectDepthDBSchemas: queryFile = queryTemplateGetObjectsDbSchemas + gQueryIDs.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ SCHEMAS" + if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" + } + if catalog == nil || isWildcardStr(*catalog) { + query += " IN ACCOUNT" + } else { + query += " IN DATABASE \"" + quoteTblName(*catalog) + "\"" + } + + showSchemaQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + return + }) + }) + + gQueryIDs.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ DATABASES" + if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" + } + query += " IN ACCOUNT" + + terseDbQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + return + }) + }) case adbc.ObjectDepthTables: queryFile = queryTemplateGetObjectsTables fallthrough default: + var suffix string + if catalog == nil { + suffix = " IN ACCOUNT" + } else { + escapedCatalog := quoteTblName(*catalog) + if dbSchema == nil || isWildcardStr(*dbSchema) { + suffix = " IN DATABASE \"" + escapedCatalog + "\"" + } else { + escapedSchema := quoteTblName(*dbSchema) + if tableName == nil || isWildcardStr(*tableName) { + suffix = " IN SCHEMA \"" + escapedCatalog + "\".\"" + escapedSchema + "\"" + } else { + escapedTable := quoteTblName(*tableName) + suffix = " IN TABLE \"" + escapedCatalog + "\".\"" + escapedSchema + "\".\"" + escapedTable + "\"" + } + } + } + // Detailed constraint info not available in information_schema // Need to dispatch SHOW queries and use conn.Raw to extract the queryID for reuse in GetObjects query gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) error { - rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW PRIMARY KEYS", nil) - if err != nil { - return err + return conn.Raw(func(driverConn any) (err error) { + pkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW PRIMARY KEYS /* ADBC:getObjectsTables */"+suffix, driverConn) + return err + }) + }) + + gQueryIDs.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + fkQueryID, err = getQueryID(gQueryIDsCtx, "SHOW IMPORTED KEYS /* ADBC:getObjectsTables */"+suffix, driverConn) + return err + }) + }) + + gQueryIDs.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + uniqueQueryID, err = getQueryID(gQueryIDsCtx, "SHOW UNIQUE KEYS /* ADBC:getObjectsTables */"+suffix, driverConn) + return err + }) + }) + + gQueryIDs.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ SCHEMAS" + if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" + } + if catalog == nil || isWildcardStr(*catalog) { + query += " IN ACCOUNT" + } else { + query += " IN DATABASE \"" + quoteTblName(*catalog) + "\"" } - pkQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID() - return rows.Close() + showSchemaQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + return }) }) gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) error { - rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW IMPORTED KEYS", nil) - if err != nil { - return err + return conn.Raw(func(driverConn any) (err error) { + query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ DATABASES" + if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" } + query += " IN ACCOUNT" - fkQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID() - return rows.Close() + terseDbQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + return }) }) gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) error { - rows, err := driverConn.(driver.QueryerContext).QueryContext(gQueryIDsCtx, "SHOW UNIQUE KEYS", nil) - if err != nil { - return err + return conn.Raw(func(driverConn any) (err error) { + objType := "objects" + if len(tableType) == 1 { + if strings.EqualFold("VIEW", tableType[0]) { + objType = "views" + } else if strings.EqualFold("TABLE", tableType[0]) { + objType = "tables" + } } - uniqueQueryID = rows.(gosnowflake.SnowflakeRows).GetQueryID() - return rows.Close() + query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType + if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'" + } + if catalog == nil || isWildcardStr(*catalog) { + query += " IN ACCOUNT" + } else { + escapedCatalog := quoteTblName(*catalog) + if dbSchema == nil || isWildcardStr(*dbSchema) { + query += " IN DATABASE \"" + escapedCatalog + "\"" + } else { + query += " IN SCHEMA \"" + escapedCatalog + "\".\"" + quoteTblName(*dbSchema) + "\"" + } + } + + tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + return }) }) } - f, err := queryTemplates.Open(path.Join("queries", queryFile)) + queryBytes, err := fs.ReadFile(queryTemplates, path.Join("queries", queryFile)) if err != nil { return nil, err } - defer f.Close() - - var bldr strings.Builder - if _, err := io.Copy(&bldr, f); err != nil { - return nil, err - } // Need constraint subqueries to complete before we can query GetObjects if err := gQueryIDs.Wait(); err != nil { @@ -180,76 +305,71 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, sql.Named("FK_QUERY_ID", fkQueryID), sql.Named("UNIQUE_QUERY_ID", uniqueQueryID), sql.Named("SHOW_DB_QUERY_ID", terseDbQueryID), - } - - // the connection that is used is not the same connection context where the database may have been set - // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate - if !isNilOrEmpty(catalog) { - _, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil) - if e != nil { - return nil, errToAdbcErr(adbc.StatusIO, e) + sql.Named("SHOW_SCHEMA_QUERY_ID", showSchemaQueryID), + sql.Named("SHOW_TABLE_QUERY_ID", tableQueryID), + } + + // currently only the Columns / all case still requires a current database/schema + // to be propagated. The rest of the cases all solely use SHOW queries for the metadata + // just as done by the snowflake JDBC driver. In those cases we don't need to propagate + // the current session database/schema. + if depth == adbc.ObjectDepthColumns || depth == adbc.ObjectDepthAll { + // the connection that is used is not the same connection context where the database may have been set + // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate + if !isNilOrEmpty(catalog) { + _, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil) + if e != nil { + return nil, errToAdbcErr(adbc.StatusIO, e) + } } - } - // the connection that is used is not the same connection context where the schema may have been set - // if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate - if !isNilOrEmpty(dbSchema) { - _, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil) - if e2 != nil { - return nil, errToAdbcErr(adbc.StatusIO, e2) + // the connection that is used is not the same connection context where the schema may have been set + // if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate + if !isNilOrEmpty(dbSchema) { + _, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil) + if e2 != nil { + return nil, errToAdbcErr(adbc.StatusIO, e2) + } } } - query := bldr.String() + query := string(queryBytes) rows, err := conn.QueryContext(ctx, query, args...) if err != nil { return nil, errToAdbcErr(adbc.StatusIO, err) } defer rows.Close() - catalogCh := make(chan driverbase.GetObjectsInfo, 1) - readerCh := make(chan array.RecordReader) + catalogCh := make(chan driverbase.GetObjectsInfo, 5) errCh := make(chan error) go func() { - rdr, err := driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh) - if err != nil { - errCh <- err - } - - readerCh <- rdr - close(readerCh) - }() - - for rows.Next() { - var getObjectsCatalog driverbase.GetObjectsInfo - if err := rows.Scan(&getObjectsCatalog); err != nil { - return nil, errToAdbcErr(adbc.StatusInvalidData, err) - } + defer close(catalogCh) + for rows.Next() { + var getObjectsCatalog driverbase.GetObjectsInfo + if err := rows.Scan(&getObjectsCatalog); err != nil { + errCh <- errToAdbcErr(adbc.StatusInvalidData, err) + return + } - // A few columns need additional processing outside of Snowflake - for i, sch := range getObjectsCatalog.CatalogDbSchemas { - for j, tab := range sch.DbSchemaTables { - for k, col := range tab.TableColumns { - field := c.toArrowField(col) - xdbcDataType := driverbase.ToXdbcDataType(field.Type) + // A few columns need additional processing outside of Snowflake + for i, sch := range getObjectsCatalog.CatalogDbSchemas { + for j, tab := range sch.DbSchemaTables { + for k, col := range tab.TableColumns { + field := c.toArrowField(col) + xdbcDataType := driverbase.ToXdbcDataType(field.Type) - getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID())) - getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType)) + getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID())) + getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType)) + } } } - } - catalogCh <- getObjectsCatalog - } - close(catalogCh) + catalogCh <- getObjectsCatalog + } + }() - select { - case rdr := <-readerCh: - return rdr, nil - case err := <-errCh: - return nil, err - } + return driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh, errCh) } func isNilOrEmpty(str *string) bool { @@ -266,7 +386,7 @@ func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc // ListTableTypes implements driverbase.TableTypeLister. func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) { - return []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}, nil + return []string{"TABLE", "VIEW"}, nil } // GetCurrentCatalog implements driverbase.CurrentNamespacer. diff --git a/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql b/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql deleted file mode 100644 index ec2cef5157..0000000000 --- a/go/adbc/driver/snowflake/queries/get_objects_catalogs.sql +++ /dev/null @@ -1,25 +0,0 @@ --- Licensed to the Apache Software Foundation (ASF) under one --- or more contributor license agreements. See the NOTICE file --- distributed with this work for additional information --- regarding copyright ownership. The ASF licenses this file --- to you under the Apache License, Version 2.0 (the --- "License"); you may not use this file except in compliance --- with the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, --- software distributed under the License is distributed on an --- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --- KIND, either express or implied. See the License for the --- specific language governing permissions and limitations --- under the License. - -SELECT - { - 'catalog_name': database_name, - 'catalog_db_schemas': null - } get_objects -FROM - information_schema.databases -WHERE database_name ILIKE :CATALOG; diff --git a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql index 360a6d0837..627d11c321 100644 --- a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql +++ b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql @@ -16,23 +16,27 @@ -- under the License. WITH db_schemas AS ( - SELECT - catalog_name, - schema_name, - FROM information_schema.schemata - WHERE catalog_name ILIKE :CATALOG AND schema_name ILIKE :DB_SCHEMA + SELECT + "database_name" as "catalog_name", + "name" as "schema_name" + FROM table(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID)) + WHERE "database_name" ILIKE :CATALOG +), db_info AS ( + SELECT "name" AS "database_name" + FROM table(RESULT_SCAN(:SHOW_DB_QUERY_ID)) + WHERE "name" ILIKE :CATALOG ) -SELECT +SELECT { - 'catalog_name': database_name, + 'catalog_name': "database_name", 'catalog_db_schemas': ARRAY_AGG({ - 'db_schema_name': schema_name, + 'db_schema_name': "schema_name", 'db_schema_tables': null }) } get_objects -FROM - information_schema.databases -LEFT JOIN db_schemas -ON database_name = catalog_name -WHERE database_name ILIKE :CATALOG -GROUP BY database_name; +FROM + db_info +LEFT JOIN db_schemas +ON "database_name" = "catalog_name" +WHERE "database_name" ILIKE :CATALOG +GROUP BY "database_name"; diff --git a/go/adbc/driver/snowflake/queries/get_objects_tables.sql b/go/adbc/driver/snowflake/queries/get_objects_tables.sql index b3b16ff515..ec284947c4 100644 --- a/go/adbc/driver/snowflake/queries/get_objects_tables.sql +++ b/go/adbc/driver/snowflake/queries/get_objects_tables.sql @@ -17,37 +17,37 @@ WITH pk_constraints AS ( SELECT - "database_name" table_catalog, - "schema_name" table_schema, - "table_name" table_name, - "constraint_name" constraint_name, - 'PRIMARY KEY' constraint_type, + "database_name" "table_catalog", + "schema_name" "table_schema", + "table_name" "table_name", + "constraint_name" "constraint_name", + 'PRIMARY KEY' "constraint_type", ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names, [] constraint_column_usage, - FROM TABLE(RESULT_SCAN(:PK_QUERY_ID)) - WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE - GROUP BY table_catalog, table_schema, table_name, "constraint_name" + FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-4))) + WHERE "table_catalog" ILIKE :CATALOG AND "table_schema" ILIKE :DB_SCHEMA AND "table_name" ILIKE :TABLE + GROUP BY "table_catalog", "table_schema", "table_name", "constraint_name" ), unique_constraints AS ( SELECT - "database_name" table_catalog, - "schema_name" table_schema, - "table_name" table_name, - "constraint_name" constraint_name, - 'UNIQUE' constraint_type, + "database_name" "table_catalog", + "schema_name" "table_schema", + "table_name" "table_name", + "constraint_name" "constraint_name", + 'UNIQUE' "constraint_type", ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names, [] constraint_column_usage, FROM TABLE(RESULT_SCAN(:UNIQUE_QUERY_ID)) - WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE - GROUP BY table_catalog, table_schema, table_name, "constraint_name" + WHERE "table_catalog" ILIKE :CATALOG AND "table_schema" ILIKE :DB_SCHEMA AND "table_name" ILIKE :TABLE + GROUP BY "table_catalog", "table_schema", "table_name", "constraint_name" ), fk_constraints AS ( SELECT - "fk_database_name" table_catalog, - "fk_schema_name" table_schema, - "fk_table_name" table_name, - "fk_name" constraint_name, - 'FOREIGN KEY' constraint_type, + "fk_database_name" "table_catalog", + "fk_schema_name" "table_schema", + "fk_table_name" "table_name", + "fk_name" "constraint_name", + 'FOREIGN KEY' "constraint_type", ARRAY_AGG("fk_column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names, ARRAY_AGG({ 'fk_catalog': "pk_database_name", @@ -56,17 +56,17 @@ fk_constraints AS ( 'fk_column_name': "pk_column_name" }) WITHIN GROUP (ORDER BY "key_sequence") constraint_column_usage, FROM TABLE(RESULT_SCAN(:FK_QUERY_ID)) - WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE - GROUP BY table_catalog, table_schema, table_name, constraint_name + WHERE "table_catalog" ILIKE :CATALOG AND "table_schema" ILIKE :DB_SCHEMA AND "table_name" ILIKE :TABLE + GROUP BY "table_catalog", "table_schema", "table_name", "constraint_name" ), constraints AS ( SELECT - table_catalog, - table_schema, - table_name, + "table_catalog", + "table_schema", + "table_name", ARRAY_AGG({ - 'constraint_name': constraint_name, - 'constraint_type': constraint_type, + 'constraint_name': "constraint_name", + 'constraint_type': "constraint_type", 'constraint_column_names': constraint_column_names, 'constraint_column_usage': constraint_column_usage }) table_constraints, @@ -77,45 +77,45 @@ constraints AS ( UNION ALL SELECT * FROM fk_constraints ) - GROUP BY table_catalog, table_schema, table_name + GROUP BY "table_catalog", "table_schema", "table_name" ), tables AS ( SELECT - table_catalog catalog_name, - table_schema schema_name, + "database_name" "catalog_name", + "schema_name" "schema_name", ARRAY_AGG({ - 'table_name': table_name, - 'table_type': table_type, + 'table_name': "name", + 'table_type': "kind", 'table_constraints': table_constraints, 'table_columns': null }) db_schema_tables -FROM information_schema.tables +FROM TABLE(RESULT_SCAN(:SHOW_TABLE_QUERY_ID)) LEFT JOIN constraints -USING (table_catalog, table_schema, table_name) -WHERE table_catalog ILIKE :CATALOG AND table_schema ILIKE :DB_SCHEMA AND table_name ILIKE :TABLE -GROUP BY table_catalog, table_schema +ON "database_name" = "table_catalog" AND "schema_name" = "table_schema" AND "name" = "table_name" +WHERE "database_name" ILIKE :CATALOG AND "schema_name" ILIKE :DB_SCHEMA AND "name" ILIKE :TABLE +GROUP BY "database_name", "schema_name" ), db_schemas AS ( SELECT - catalog_name, - schema_name, + "database_name" "catalog_name", + "name" "schema_name", db_schema_tables, - FROM information_schema.schemata + FROM TABLE(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID)) LEFT JOIN tables - USING (catalog_name, schema_name) - WHERE catalog_name ILIKE :CATALOG AND schema_name ILIKE :DB_SCHEMA + ON "database_name" = "catalog_name" AND "name" = tables."schema_name" + WHERE "database_name" ILIKE :CATALOG AND "name" ILIKE :DB_SCHEMA ) SELECT { - 'catalog_name': database_name, + 'catalog_name': "name", 'catalog_db_schemas': ARRAY_AGG({ - 'db_schema_name': schema_name, + 'db_schema_name': db_schemas."schema_name", 'db_schema_tables': db_schema_tables }) } get_objects FROM - information_schema.databases + TABLE(RESULT_SCAN(:SHOW_DB_QUERY_ID)) LEFT JOIN db_schemas -ON database_name = catalog_name -WHERE database_name ILIKE :CATALOG -GROUP BY database_name; +ON "name" = "catalog_name" +WHERE "name" ILIKE :CATALOG +GROUP BY "name"; From 39e1ed273e102aad8fc67afb76ec9a87c8968610 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 14 Oct 2024 16:35:19 -0400 Subject: [PATCH 02/14] trim whitespace --- .../snowflake/queries/get_objects_dbschemas.sql | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql index 627d11c321..bc454866af 100644 --- a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql +++ b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql @@ -16,17 +16,17 @@ -- under the License. WITH db_schemas AS ( - SELECT + SELECT "database_name" as "catalog_name", "name" as "schema_name" FROM table(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID)) WHERE "database_name" ILIKE :CATALOG ), db_info AS ( - SELECT "name" AS "database_name" + SELECT "name" AS "database_name" FROM table(RESULT_SCAN(:SHOW_DB_QUERY_ID)) WHERE "name" ILIKE :CATALOG ) -SELECT +SELECT { 'catalog_name': "database_name", 'catalog_db_schemas': ARRAY_AGG({ @@ -34,9 +34,9 @@ SELECT 'db_schema_tables': null }) } get_objects -FROM - db_info -LEFT JOIN db_schemas +FROM + db_info +LEFT JOIN db_schemas ON "database_name" = "catalog_name" WHERE "database_name" ILIKE :CATALOG GROUP BY "database_name"; From 9acf8a8ab373cc9e249337d3677bfd3d7973b3c9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 12:04:05 -0400 Subject: [PATCH 03/14] updates to make tests work --- c/driver/snowflake/snowflake_test.cc | 5 +- c/validation/adbc_validation_connection.cc | 9 +- c/validation/adbc_validation_statement.cc | 4 +- .../driver/internal/driverbase/connection.go | 9 +- go/adbc/driver/snowflake/connection.go | 160 ++++++++++-------- .../snowflake/queries/get_objects_all.sql | 18 +- .../queries/get_objects_dbschemas.sql | 4 +- .../snowflake/queries/get_objects_tables.sql | 76 +-------- 8 files changed, 123 insertions(+), 162 deletions(-) diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index 60003353da..90735b3ffe 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -131,7 +131,9 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: - return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: + return NANOARROW_TYPE_STRING; default: return ingest_type; } @@ -150,6 +152,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { bool supports_error_on_incompatible_schema() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } std::string db_schema() const override { return schema_; } + std::string catalog() const override { return "ADBC_TESTING"; } const char* uri_; bool skip_{false}; diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc index a885fa2c86..6ef4302137 100644 --- a/c/validation/adbc_validation_connection.cc +++ b/c/validation/adbc_validation_connection.cc @@ -701,9 +701,8 @@ void ConnectionTest::TestMetadataGetObjectsTablesTypes() { db_schemas_index < ArrowArrayViewListChildOffset(catalog_db_schemas_list, row + 1); db_schemas_index++) { - ASSERT_FALSE(ArrowArrayViewIsNull(db_schema_tables_list, db_schemas_index)) - << "Row " << row << " should have non-null db_schema_tables"; - + + // db_schema_tables should either be null or an empty list for (int64_t tables_index = ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index); tables_index < @@ -752,6 +751,8 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}}); test_cases.push_back({"in%", {"int64s"}, {1}}); + const std::string catalog = quirks()->catalog(); + for (const auto& test_case : test_cases) { std::string scope = "Filter: "; scope += test_case.filter ? *test_case.filter : "(no filter)"; @@ -763,7 +764,7 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { ASSERT_THAT( AdbcConnectionGetObjects( - &connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr, + &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr, nullptr, test_case.filter.has_value() ? test_case.filter->c_str() : nullptr, &reader.stream.value, &error), IsOkStatus(&error)); diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index 07ab0b22af..19166d8524 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -2218,7 +2218,7 @@ void StatementTest::TestSqlBind() { ASSERT_THAT( AdbcStatementSetSqlQuery( - &statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error), + &statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error), IsOkStatus(&error)); { StreamReader reader; @@ -2226,7 +2226,7 @@ void StatementTest::TestSqlBind() { &reader.rows_affected, &error), IsOkStatus(&error)); ASSERT_THAT(reader.rows_affected, - ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1))); + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); ASSERT_NO_FATAL_FAILURE(reader.Next()); diff --git a/go/adbc/driver/internal/driverbase/connection.go b/go/adbc/driver/internal/driverbase/connection.go index 37433e0ced..b09f74e301 100644 --- a/go/adbc/driver/internal/driverbase/connection.go +++ b/go/adbc/driver/internal/driverbase/connection.go @@ -624,20 +624,20 @@ type ColumnInfo struct { type TableInfo struct { TableName string `json:"table_name"` TableType string `json:"table_type"` - TableColumns []ColumnInfo `json:"table_columns,omitempty"` - TableConstraints []ConstraintInfo `json:"table_constraints,omitempty"` + TableColumns []ColumnInfo `json:"table_columns"` + TableConstraints []ConstraintInfo `json:"table_constraints"` } // DBSchemaInfo is a structured representation of adbc.DBSchemaSchema type DBSchemaInfo struct { DbSchemaName *string `json:"db_schema_name,omitempty"` - DbSchemaTables []TableInfo `json:"db_schema_tables,omitempty"` + DbSchemaTables []TableInfo `json:"db_schema_tables"` } // GetObjectsInfo is a structured representation of adbc.GetObjectsSchema type GetObjectsInfo struct { CatalogName *string `json:"catalog_name,omitempty"` - CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas,omitempty"` + CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas"` } // Scan implements sql.Scanner. @@ -688,6 +688,7 @@ CATALOGLOOP: rec := bldr.NewRecord() defer rec.Release() + return array.NewRecordReader(adbc.GetObjectsSchema, []arrow.Record{rec}) } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 77d3c2cd5f..90ae67f4a8 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -26,6 +26,7 @@ import ( "io" "io/fs" "path" + "runtime" "strconv" "strings" "time" @@ -110,6 +111,41 @@ func getQueryID(ctx context.Context, query string, driverConn any) (string, erro return rows.(gosnowflake.SnowflakeRows).GetQueryID(), rows.Close() } +const ( + objSchemas = "SCHEMAS" + objDatabases = "DATABASES" +) + +func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema *string, outQueryID *string) { + grp.Go(func() error { + return conn.Raw(func(driverConn any) (err error) { + query := "SHOW TERSE /* ADBC:getObjects */ " + objType + switch objType { + case objDatabases: + if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" + } + query += " IN ACCOUNT" + case objSchemas: + if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" + } + + if catalog == nil || isWildcardStr(*catalog) { + query += " IN ACCOUNT" + } else { + query += " IN DATABASE " + quoteTblName(*catalog) + } + default: + return fmt.Errorf("unimplemented object type") + } + + *outQueryID, err = getQueryID(ctx, query, driverConn) + return + }) + }) +} + func isWildcardStr(ident string) bool { return strings.ContainsAny(ident, "_%") } @@ -126,73 +162,84 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, } defer conn.Close() + var hasViews, hasTables bool + for _, t := range tableType { + if strings.EqualFold("VIEW", t) { + hasViews = true + } else if strings.EqualFold("TABLE", t) { + hasTables = true + } + } + + if len(tableType) > 0 && depth >= adbc.ObjectDepthTables && !hasViews && !hasTables { + depth = adbc.ObjectDepthDBSchemas + } gQueryIDs, gQueryIDsCtx := errgroup.WithContext(ctx) queryFile := queryTemplateGetObjectsAll switch depth { case adbc.ObjectDepthCatalogs: queryFile = queryTemplateGetObjectsTerseCatalogs - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - query := "SHOW TERSE /* ADBC:getObjectsCatalogs */ DATABASES" - if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" - } - query += " IN ACCOUNT" - - terseDbQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, + catalog, dbSchema, &terseDbQueryID) case adbc.ObjectDepthDBSchemas: queryFile = queryTemplateGetObjectsDbSchemas + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, + catalog, dbSchema, &showSchemaQueryID) + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, + catalog, dbSchema, &terseDbQueryID) + case adbc.ObjectDepthTables: + queryFile = queryTemplateGetObjectsTables + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, + catalog, dbSchema, &showSchemaQueryID) + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, + catalog, dbSchema, &terseDbQueryID) gQueryIDs.Go(func() error { return conn.Raw(func(driverConn any) (err error) { - query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ SCHEMAS" - if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" + objType := "objects" + if len(tableType) == 1 { + if strings.EqualFold("VIEW", tableType[0]) { + objType = "views" + } else if strings.EqualFold("TABLE", tableType[0]) { + objType = "tables" + } + } + + query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType + if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'" } if catalog == nil || isWildcardStr(*catalog) { query += " IN ACCOUNT" } else { - query += " IN DATABASE \"" + quoteTblName(*catalog) + "\"" - } - - showSchemaQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) - - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ DATABASES" - if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" + escapedCatalog := quoteTblName(*catalog) + if dbSchema == nil || isWildcardStr(*dbSchema) { + query += " IN DATABASE " + escapedCatalog + } else { + query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) + } } - query += " IN ACCOUNT" - terseDbQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) + tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) return }) }) - case adbc.ObjectDepthTables: - queryFile = queryTemplateGetObjectsTables - fallthrough + // fallthrough default: var suffix string - if catalog == nil { + if catalog == nil || isWildcardStr(*catalog) { suffix = " IN ACCOUNT" } else { escapedCatalog := quoteTblName(*catalog) if dbSchema == nil || isWildcardStr(*dbSchema) { - suffix = " IN DATABASE \"" + escapedCatalog + "\"" + suffix = " IN DATABASE " + escapedCatalog } else { escapedSchema := quoteTblName(*dbSchema) if tableName == nil || isWildcardStr(*tableName) { - suffix = " IN SCHEMA \"" + escapedCatalog + "\".\"" + escapedSchema + "\"" + suffix = " IN SCHEMA " + escapedCatalog + "." + escapedSchema } else { escapedTable := quoteTblName(*tableName) - suffix = " IN TABLE \"" + escapedCatalog + "\".\"" + escapedSchema + "\".\"" + escapedTable + "\"" + suffix = " IN TABLE " + escapedCatalog + "." + escapedSchema + "." + escapedTable } } } @@ -220,35 +267,10 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, }) }) - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ SCHEMAS" - if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" - } - if catalog == nil || isWildcardStr(*catalog) { - query += " IN ACCOUNT" - } else { - query += " IN DATABASE \"" + quoteTblName(*catalog) + "\"" - } - - showSchemaQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) - - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - query := "SHOW TERSE /* ADBC:getObjectsDBSchemas */ DATABASES" - if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" - } - query += " IN ACCOUNT" - - terseDbQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, + catalog, dbSchema, &terseDbQueryID) + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, + catalog, dbSchema, &showSchemaQueryID) gQueryIDs.Go(func() error { return conn.Raw(func(driverConn any) (err error) { @@ -270,9 +292,9 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, } else { escapedCatalog := quoteTblName(*catalog) if dbSchema == nil || isWildcardStr(*dbSchema) { - query += " IN DATABASE \"" + escapedCatalog + "\"" + query += " IN DATABASE " + escapedCatalog } else { - query += " IN SCHEMA \"" + escapedCatalog + "\".\"" + quoteTblName(*dbSchema) + "\"" + query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) } } @@ -340,7 +362,7 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, } defer rows.Close() - catalogCh := make(chan driverbase.GetObjectsInfo, 5) + catalogCh := make(chan driverbase.GetObjectsInfo, runtime.NumCPU()) errCh := make(chan error) go func() { diff --git a/go/adbc/driver/snowflake/queries/get_objects_all.sql b/go/adbc/driver/snowflake/queries/get_objects_all.sql index 45b807f15e..7fc10f2e24 100644 --- a/go/adbc/driver/snowflake/queries/get_objects_all.sql +++ b/go/adbc/driver/snowflake/queries/get_objects_all.sql @@ -86,12 +86,12 @@ constraints AS ( table_catalog, table_schema, table_name, - ARRAY_AGG({ + ARRAY_AGG(NULLIF({ 'constraint_name': constraint_name, 'constraint_type': constraint_type, 'constraint_column_names': constraint_column_names, 'constraint_column_usage': constraint_column_usage - }) table_constraints, + }, {})) table_constraints, FROM ( SELECT * FROM pk_constraints UNION ALL @@ -105,12 +105,12 @@ tables AS ( SELECT table_catalog catalog_name, table_schema schema_name, - ARRAY_AGG({ + ARRAY_AGG(NULLIF({ 'table_name': table_name, 'table_type': table_type, - 'table_columns': table_columns, - 'table_constraints': table_constraints - }) db_schema_tables + 'table_columns': COALESCE(table_columns, []), + 'table_constraints': COALESCE(table_constraints, []) + }, {})) db_schema_tables FROM information_schema.tables LEFT JOIN columns USING (table_catalog, table_schema, table_name) @@ -123,7 +123,7 @@ db_schemas AS ( SELECT catalog_name, schema_name, - db_schema_tables, + COALESCE(db_schema_tables, []) db_schema_tables, FROM information_schema.schemata LEFT JOIN tables USING (catalog_name, schema_name) @@ -132,10 +132,10 @@ db_schemas AS ( SELECT { 'catalog_name': database_name, - 'catalog_db_schemas': ARRAY_AGG({ + 'catalog_db_schemas': ARRAY_AGG(NULLIF({ 'db_schema_name': schema_name, 'db_schema_tables': db_schema_tables - }) + }, {})) } get_objects FROM information_schema.databases diff --git a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql index bc454866af..872118f7c7 100644 --- a/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql +++ b/go/adbc/driver/snowflake/queries/get_objects_dbschemas.sql @@ -29,10 +29,10 @@ WITH db_schemas AS ( SELECT { 'catalog_name': "database_name", - 'catalog_db_schemas': ARRAY_AGG({ + 'catalog_db_schemas': ARRAY_AGG(NULLIF({ 'db_schema_name': "schema_name", 'db_schema_tables': null - }) + }, {})) } get_objects FROM db_info diff --git a/go/adbc/driver/snowflake/queries/get_objects_tables.sql b/go/adbc/driver/snowflake/queries/get_objects_tables.sql index ec284947c4..9d6ce36ed8 100644 --- a/go/adbc/driver/snowflake/queries/get_objects_tables.sql +++ b/go/adbc/driver/snowflake/queries/get_objects_tables.sql @@ -15,83 +15,17 @@ -- specific language governing permissions and limitations -- under the License. -WITH pk_constraints AS ( - SELECT - "database_name" "table_catalog", - "schema_name" "table_schema", - "table_name" "table_name", - "constraint_name" "constraint_name", - 'PRIMARY KEY' "constraint_type", - ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names, - [] constraint_column_usage, - FROM TABLE(RESULT_SCAN(LAST_QUERY_ID(-4))) - WHERE "table_catalog" ILIKE :CATALOG AND "table_schema" ILIKE :DB_SCHEMA AND "table_name" ILIKE :TABLE - GROUP BY "table_catalog", "table_schema", "table_name", "constraint_name" -), -unique_constraints AS ( - SELECT - "database_name" "table_catalog", - "schema_name" "table_schema", - "table_name" "table_name", - "constraint_name" "constraint_name", - 'UNIQUE' "constraint_type", - ARRAY_AGG("column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names, - [] constraint_column_usage, - FROM TABLE(RESULT_SCAN(:UNIQUE_QUERY_ID)) - WHERE "table_catalog" ILIKE :CATALOG AND "table_schema" ILIKE :DB_SCHEMA AND "table_name" ILIKE :TABLE - GROUP BY "table_catalog", "table_schema", "table_name", "constraint_name" -), -fk_constraints AS ( - SELECT - "fk_database_name" "table_catalog", - "fk_schema_name" "table_schema", - "fk_table_name" "table_name", - "fk_name" "constraint_name", - 'FOREIGN KEY' "constraint_type", - ARRAY_AGG("fk_column_name") WITHIN GROUP (ORDER BY "key_sequence") constraint_column_names, - ARRAY_AGG({ - 'fk_catalog': "pk_database_name", - 'fk_db_schema': "pk_schema_name", - 'fk_table': "pk_table_name", - 'fk_column_name': "pk_column_name" - }) WITHIN GROUP (ORDER BY "key_sequence") constraint_column_usage, - FROM TABLE(RESULT_SCAN(:FK_QUERY_ID)) - WHERE "table_catalog" ILIKE :CATALOG AND "table_schema" ILIKE :DB_SCHEMA AND "table_name" ILIKE :TABLE - GROUP BY "table_catalog", "table_schema", "table_name", "constraint_name" -), -constraints AS ( - SELECT - "table_catalog", - "table_schema", - "table_name", - ARRAY_AGG({ - 'constraint_name': "constraint_name", - 'constraint_type': "constraint_type", - 'constraint_column_names': constraint_column_names, - 'constraint_column_usage': constraint_column_usage - }) table_constraints, - FROM ( - SELECT * FROM pk_constraints - UNION ALL - SELECT * FROM unique_constraints - UNION ALL - SELECT * FROM fk_constraints - ) - GROUP BY "table_catalog", "table_schema", "table_name" -), -tables AS ( +WITH tables AS ( SELECT "database_name" "catalog_name", "schema_name" "schema_name", ARRAY_AGG({ 'table_name': "name", 'table_type': "kind", - 'table_constraints': table_constraints, + 'table_constraints': null, 'table_columns': null }) db_schema_tables FROM TABLE(RESULT_SCAN(:SHOW_TABLE_QUERY_ID)) -LEFT JOIN constraints -ON "database_name" = "table_catalog" AND "schema_name" = "table_schema" AND "name" = "table_name" WHERE "database_name" ILIKE :CATALOG AND "schema_name" ILIKE :DB_SCHEMA AND "name" ILIKE :TABLE GROUP BY "database_name", "schema_name" ), @@ -99,7 +33,7 @@ db_schemas AS ( SELECT "database_name" "catalog_name", "name" "schema_name", - db_schema_tables, + COALESCE(db_schema_tables, []) db_schema_tables, FROM TABLE(RESULT_SCAN(:SHOW_SCHEMA_QUERY_ID)) LEFT JOIN tables ON "database_name" = "catalog_name" AND "name" = tables."schema_name" @@ -108,10 +42,10 @@ db_schemas AS ( SELECT { 'catalog_name': "name", - 'catalog_db_schemas': ARRAY_AGG({ + 'catalog_db_schemas': ARRAY_AGG(NULLIF({ 'db_schema_name': db_schemas."schema_name", 'db_schema_tables': db_schema_tables - }) + }, {})) } get_objects FROM TABLE(RESULT_SCAN(:SHOW_DB_QUERY_ID)) From 71dd6bc151cb5ff659f041262d707513bff8230b Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 12:25:10 -0400 Subject: [PATCH 04/14] fix lint and flakey test --- c/driver/snowflake/snowflake_test.cc | 2 +- c/validation/adbc_validation_connection.cc | 27 +++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index 90735b3ffe..67d3cbb3fd 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -133,7 +133,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_LARGE_STRING: case NANOARROW_TYPE_LIST: case NANOARROW_TYPE_LARGE_LIST: - return NANOARROW_TYPE_STRING; + return NANOARROW_TYPE_STRING; default: return ingest_type; } diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc index 6ef4302137..9cef88c6d5 100644 --- a/c/validation/adbc_validation_connection.cc +++ b/c/validation/adbc_validation_connection.cc @@ -701,7 +701,6 @@ void ConnectionTest::TestMetadataGetObjectsTablesTypes() { db_schemas_index < ArrowArrayViewListChildOffset(catalog_db_schemas_list, row + 1); db_schemas_index++) { - // db_schema_tables should either be null or an empty list for (int64_t tables_index = ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index); @@ -743,13 +742,12 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { struct TestCase { std::optional filter; - std::vector column_names; - std::vector ordinal_positions; + std::vector> columns; }; std::vector test_cases; - test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}}); - test_cases.push_back({"in%", {"int64s"}, {1}}); + test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}}); + test_cases.push_back({"in%", {{"int64s", 1}}}); const std::string catalog = quirks()->catalog(); @@ -759,13 +757,14 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { SCOPED_TRACE(scope); StreamReader reader; + std::vector> columns; std::vector column_names; std::vector ordinal_positions; ASSERT_THAT( AdbcConnectionGetObjects( - &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr, nullptr, - test_case.filter.has_value() ? test_case.filter->c_str() : nullptr, + &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr, + nullptr, test_case.filter.has_value() ? test_case.filter->c_str() : nullptr, &reader.stream.value, &error), IsOkStatus(&error)); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); @@ -835,10 +834,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { std::string temp(name.data, name.size_bytes); std::transform(temp.begin(), temp.end(), temp.begin(), [](unsigned char c) { return std::tolower(c); }); - column_names.push_back(std::move(temp)); - ordinal_positions.push_back( - static_cast(ArrowArrayViewGetIntUnsafe( - table_columns->children[1], columns_index))); + columns.emplace_back(std::move(temp), + static_cast(ArrowArrayViewGetIntUnsafe( + table_columns->children[1], columns_index))); } } } @@ -848,8 +846,11 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { } while (reader.array->release); ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata"; - ASSERT_EQ(test_case.column_names, column_names); - ASSERT_EQ(test_case.ordinal_positions, ordinal_positions); + // metadata columns do not guarantee the order they are returned in, we can + // avoid the test being flakey by sorting the column names we found + std::sort(columns.begin(), columns.end(), + [](const auto& a, const auto& b) -> bool { return a.first < b.first; }); + ASSERT_EQ(test_case.columns, columns); } } From 8d581d542221fc6f5078d1829a1a8613fe5d7dc8 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 12:53:47 -0400 Subject: [PATCH 05/14] add catalog for test --- c/driver/flightsql/sqlite_flightsql_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc index 454ea02977..40601e2803 100644 --- a/c/driver/flightsql/sqlite_flightsql_test.cc +++ b/c/driver/flightsql/sqlite_flightsql_test.cc @@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { bool supports_get_objects() const override { return true; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return true; } + std::string catalog() const { return "main"; } }; class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest { From 1ad0c098bfa8017e35a4dc62b1e9ef941857f109 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 13:04:20 -0400 Subject: [PATCH 06/14] reduce duplication --- go/adbc/driver/snowflake/connection.go | 132 +++++++++++-------------- 1 file changed, 55 insertions(+), 77 deletions(-) diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 90ae67f4a8..c77b454367 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -114,28 +114,47 @@ func getQueryID(ctx context.Context, query string, driverConn any) (string, erro const ( objSchemas = "SCHEMAS" objDatabases = "DATABASES" + objViews = "VIEWS" + objTables = "TABLES" + objObjects = "OBJECTS" ) -func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema *string, outQueryID *string) { +func addLike(query string, pattern *string) string { + if pattern != nil && len(*pattern) > 0 && *pattern != "%" && *pattern != ".*" { + query += " LIKE '" + escapeSingleQuoteForLike(*pattern) + "'" + } + return query +} + +func goGetQueryID(ctx context.Context, conn *sql.Conn, grp *errgroup.Group, objType string, catalog, dbSchema, tableName *string, outQueryID *string) { grp.Go(func() error { return conn.Raw(func(driverConn any) (err error) { query := "SHOW TERSE /* ADBC:getObjects */ " + objType switch objType { case objDatabases: - if catalog != nil && len(*catalog) > 0 && *catalog != "%" && *catalog != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*catalog) + "'" - } + query = addLike(query, catalog) query += " IN ACCOUNT" case objSchemas: - if dbSchema != nil && len(*dbSchema) > 0 && *dbSchema != "%" && *dbSchema != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*dbSchema) + "'" - } + query = addLike(query, dbSchema) if catalog == nil || isWildcardStr(*catalog) { query += " IN ACCOUNT" } else { query += " IN DATABASE " + quoteTblName(*catalog) } + case objViews, objTables, objObjects: + query = addLike(query, tableName) + + if catalog == nil || isWildcardStr(*catalog) { + query += " IN ACCOUNT" + } else { + escapedCatalog := quoteTblName(*catalog) + if dbSchema == nil || isWildcardStr(*dbSchema) { + query += " IN DATABASE " + escapedCatalog + } else { + query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) + } + } default: return fmt.Errorf("unimplemented object type") } @@ -150,7 +169,7 @@ func isWildcardStr(ident string) bool { return strings.ContainsAny(ident, "_%") } -func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (array.RecordReader, error) { +func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType []string) (array.RecordReader, error) { var ( pkQueryID, fkQueryID, uniqueQueryID, terseDbQueryID string showSchemaQueryID, tableQueryID string @@ -181,50 +200,29 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, case adbc.ObjectDepthCatalogs: queryFile = queryTemplateGetObjectsTerseCatalogs goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) + catalog, dbSchema, tableName, &terseDbQueryID) case adbc.ObjectDepthDBSchemas: queryFile = queryTemplateGetObjectsDbSchemas goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, - catalog, dbSchema, &showSchemaQueryID) + catalog, dbSchema, tableName, &showSchemaQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) + catalog, dbSchema, tableName, &terseDbQueryID) case adbc.ObjectDepthTables: queryFile = queryTemplateGetObjectsTables goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, - catalog, dbSchema, &showSchemaQueryID) + catalog, dbSchema, tableName, &showSchemaQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - objType := "objects" - if len(tableType) == 1 { - if strings.EqualFold("VIEW", tableType[0]) { - objType = "views" - } else if strings.EqualFold("TABLE", tableType[0]) { - objType = "tables" - } - } - - query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType - if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'" - } - if catalog == nil || isWildcardStr(*catalog) { - query += " IN ACCOUNT" - } else { - escapedCatalog := quoteTblName(*catalog) - if dbSchema == nil || isWildcardStr(*dbSchema) { - query += " IN DATABASE " + escapedCatalog - } else { - query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) - } - } - - tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) - // fallthrough + catalog, dbSchema, tableName, &terseDbQueryID) + objType := "objects" + if len(tableType) == 1 { + if strings.EqualFold("VIEW", tableType[0]) { + objType = "views" + } else if strings.EqualFold("TABLE", tableType[0]) { + objType = "tables" + } + } + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, + catalog, dbSchema, tableName, &tableQueryID) default: var suffix string if catalog == nil || isWildcardStr(*catalog) { @@ -268,40 +266,20 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, }) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, - catalog, dbSchema, &terseDbQueryID) + catalog, dbSchema, tableName, &terseDbQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, - catalog, dbSchema, &showSchemaQueryID) - - gQueryIDs.Go(func() error { - return conn.Raw(func(driverConn any) (err error) { - objType := "objects" - if len(tableType) == 1 { - if strings.EqualFold("VIEW", tableType[0]) { - objType = "views" - } else if strings.EqualFold("TABLE", tableType[0]) { - objType = "tables" - } - } - - query := "SHOW TERSE /* ADBC:getObjectsTables */ " + objType - if tableName != nil && len(*tableName) > 0 && *tableName != "%" && *tableName != ".*" { - query += " LIKE '" + escapeSingleQuoteForLike(*tableName) + "'" - } - if catalog == nil || isWildcardStr(*catalog) { - query += " IN ACCOUNT" - } else { - escapedCatalog := quoteTblName(*catalog) - if dbSchema == nil || isWildcardStr(*dbSchema) { - query += " IN DATABASE " + escapedCatalog - } else { - query += " IN SCHEMA " + escapedCatalog + "." + quoteTblName(*dbSchema) - } - } - - tableQueryID, err = getQueryID(gQueryIDsCtx, query, driverConn) - return - }) - }) + catalog, dbSchema, tableName, &showSchemaQueryID) + + objType := "objects" + if len(tableType) == 1 { + if strings.EqualFold("VIEW", tableType[0]) { + objType = "views" + } else if strings.EqualFold("TABLE", tableType[0]) { + objType = "tables" + } + } + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, + catalog, dbSchema, tableName, &tableQueryID) } queryBytes, err := fs.ReadFile(queryTemplates, path.Join("queries", queryFile)) From 78829a88dd8b55e4d3b1cb0e58d11e5b1c7a76d9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 13:35:48 -0400 Subject: [PATCH 07/14] fix get object tables --- go/adbc/driver/snowflake/connection.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index c77b454367..5abf90476c 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -213,12 +213,12 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, &showSchemaQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, catalog, dbSchema, tableName, &terseDbQueryID) - objType := "objects" + objType := objObjects if len(tableType) == 1 { if strings.EqualFold("VIEW", tableType[0]) { - objType = "views" + objType = objViews } else if strings.EqualFold("TABLE", tableType[0]) { - objType = "tables" + objType = objTables } } goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, @@ -270,12 +270,12 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objSchemas, catalog, dbSchema, tableName, &showSchemaQueryID) - objType := "objects" + objType := objObjects if len(tableType) == 1 { if strings.EqualFold("VIEW", tableType[0]) { - objType = "views" + objType = objViews } else if strings.EqualFold("TABLE", tableType[0]) { - objType = "tables" + objType = objTables } } goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, From d276045a99aa8d2f0df5e5837eb8419f8e8395f1 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 13:49:19 -0400 Subject: [PATCH 08/14] use create or replace --- c/driver/snowflake/snowflake_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index 67d3cbb3fd..21322ad7ea 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { adbc_validation::Handle statement; CHECK_OK(AdbcStatementNew(connection, &statement.value, error)); - std::string create = "CREATE TABLE \""; + std::string create = "CREATE OR REPLACE TABLE \""; create += name; create += "\" (int64s INT, strings TEXT)"; CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error)); From a17cd90663df06e368749e9ff173cef0e4961374 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 14:11:43 -0400 Subject: [PATCH 09/14] fix go unit tests --- go/adbc/driver/snowflake/connection.go | 28 ++++++++++++------------- go/adbc/driver/snowflake/driver_test.go | 11 ++++++---- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 5abf90476c..973fdaa291 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -314,22 +314,22 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, // just as done by the snowflake JDBC driver. In those cases we don't need to propagate // the current session database/schema. if depth == adbc.ObjectDepthColumns || depth == adbc.ObjectDepthAll { - // the connection that is used is not the same connection context where the database may have been set - // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate - if !isNilOrEmpty(catalog) { - _, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil) - if e != nil { - return nil, errToAdbcErr(adbc.StatusIO, e) - } + dbname, err := c.GetCurrentCatalog() + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) } - // the connection that is used is not the same connection context where the schema may have been set - // if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate - if !isNilOrEmpty(dbSchema) { - _, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil) - if e2 != nil { - return nil, errToAdbcErr(adbc.StatusIO, e2) - } + schemaname, err := c.GetCurrentDbSchema() + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) + } + + // the connection that is used is not the same connection context where the database may have been set + // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate + multiCtx, _ := gosnowflake.WithMultiStatement(ctx, 2) + _, err = conn.ExecContext(multiCtx, fmt.Sprintf("USE DATABASE %s; USE SCHEMA %s;", quoteTblName(dbname), quoteTblName(schemaname))) + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) } } diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 895015ffd7..c67389ca14 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -1215,15 +1215,15 @@ func (suite *SnowflakeTests) TestSqlIngestMapType() { [ { "col_int64": 1, - "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key1\",\n \"value\": 1\n }\n ]\n}" + "col_map": "{\n \"key1\": 1\n}" }, { "col_int64": 2, - "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key2\",\n \"value\": 2\n }\n ]\n}" + "col_map": "{\n \"key2\": 2\n}" }, { "col_int64": 3, - "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key3\",\n \"value\": 3\n }\n ]\n}" + "col_map": "{\n \"key3\": 3\n}" } ] `))) @@ -2161,6 +2161,9 @@ func (suite *SnowflakeTests) TestGetSetClientConfigFile() { func (suite *SnowflakeTests) TestGetObjectsWithNilCatalog() { // this test demonstrates calling GetObjects with the catalog depth and a nil catalog - _, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil) + rdr, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil) suite.NoError(err) + // test suite validates memory allocator so we need to make sure we call + // release on the result reader + rdr.Release() } From a5a18a4dc50163dc6696a7ea648c3d5fd8ffad48 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 16 Oct 2024 14:16:04 -0400 Subject: [PATCH 10/14] remove unused func --- go/adbc/driver/snowflake/connection.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 973fdaa291..04ed56348a 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -372,10 +372,6 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, return driverbase.BuildGetObjectsRecordReader(c.Alloc, catalogCh, errCh) } -func isNilOrEmpty(str *string) bool { - return str == nil || *str == "" -} - // PrepareDriverInfo implements driverbase.DriverInfoPreparer. func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error { if err := c.ConnectionImplBase.DriverInfo.RegisterInfoCode(adbc.InfoVendorSql, true); err != nil { From 43ceb381ef5fc5e25c78f6f06b9a350da02c168f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 17 Oct 2024 13:02:48 -0400 Subject: [PATCH 11/14] updates from feedback --- c/validation/adbc_validation_connection.cc | 13 +++++++------ go/adbc/driver/snowflake/connection.go | 12 ++++++++++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc index 9cef88c6d5..032f1d328f 100644 --- a/c/validation/adbc_validation_connection.cc +++ b/c/validation/adbc_validation_connection.cc @@ -701,7 +701,9 @@ void ConnectionTest::TestMetadataGetObjectsTablesTypes() { db_schemas_index < ArrowArrayViewListChildOffset(catalog_db_schemas_list, row + 1); db_schemas_index++) { - // db_schema_tables should either be null or an empty list + ASSERT_FALSE(ArrowArrayViewIsNull(db_schema_tables_list, db_schemas_index)) + << "Row " << row << " should have non-null db_schema_tables"; + for (int64_t tables_index = ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index); tables_index < @@ -742,6 +744,7 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { struct TestCase { std::optional filter; + // the pair is column name & ordinal position of the column std::vector> columns; }; @@ -846,11 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { } while (reader.array->release); ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata"; - // metadata columns do not guarantee the order they are returned in, we can - // avoid the test being flakey by sorting the column names we found - std::sort(columns.begin(), columns.end(), - [](const auto& a, const auto& b) -> bool { return a.first < b.first; }); - ASSERT_EQ(test_case.columns, columns); + // metadata columns do not guarantee the order they are returned in, just + // validate all the elements are there. + ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns)); } } diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 04ed56348a..190426c7f9 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -190,11 +190,17 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, } } + // force empty result from SHOW TABLES if tableType list is not empty + // and does not contain TABLE or VIEW in the list. + // we need this because we should have non-null db_schema_tables when + // depth is Tables, Columns or All. + var badTableType = "tabletypedoesnotexist" if len(tableType) > 0 && depth >= adbc.ObjectDepthTables && !hasViews && !hasTables { - depth = adbc.ObjectDepthDBSchemas + tableName = &badTableType + tableType = []string{"TABLE"} } - gQueryIDs, gQueryIDsCtx := errgroup.WithContext(ctx) + gQueryIDs, gQueryIDsCtx := errgroup.WithContext(ctx) queryFile := queryTemplateGetObjectsAll switch depth { case adbc.ObjectDepthCatalogs: @@ -213,6 +219,7 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog, dbSchema, tableName, &showSchemaQueryID) goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objDatabases, catalog, dbSchema, tableName, &terseDbQueryID) + objType := objObjects if len(tableType) == 1 { if strings.EqualFold("VIEW", tableType[0]) { @@ -221,6 +228,7 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, objType = objTables } } + goGetQueryID(gQueryIDsCtx, conn, gQueryIDs, objType, catalog, dbSchema, tableName, &tableQueryID) default: From 73a7fbf12b8c2464cf17dec683ba47ade33e72ed Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 17 Oct 2024 13:32:04 -0400 Subject: [PATCH 12/14] fix handling for binary --- c/driver/snowflake/snowflake_test.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index 21322ad7ea..63767e1426 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -134,6 +134,10 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_LIST: case NANOARROW_TYPE_LARGE_LIST: return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + return NANOARROW_TYPE_BINARY; default: return ingest_type; } From e26883f974ed84965037baf3b5d3addde053a317 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 17 Oct 2024 14:00:13 -0400 Subject: [PATCH 13/14] handle updated tests --- c/driver/snowflake/snowflake_test.cc | 3 +++ c/validation/adbc_validation.h | 6 ++++++ c/validation/adbc_validation_statement.cc | 12 ++++++++++++ go/adbc/driver/snowflake/statement.go | 4 ++-- 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index 63767e1426..262286192a 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -155,6 +155,9 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { bool supports_dynamic_parameter_binding() const override { return true; } bool supports_error_on_incompatible_schema() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } + bool supports_ingest_view_types() const override { return false; } + bool supports_ingest_float16() const override { return false; } + std::string db_schema() const override { return schema_; } std::string catalog() const override { return "ADBC_TESTING"; } diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index fa3c1cdccb..f8ef350cc2 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -238,6 +238,12 @@ class DriverQuirks { /// column matching. virtual bool supports_error_on_incompatible_schema() const { return true; } + /// \brief Whether ingestion supports StringView/BinaryView types + virtual bool supports_ingest_view_types() const { return true; } + + /// \brief Whether ingestion supports Float16 + virtual bool supports_ingest_float16() const { return true; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index 19166d8524..94cee1fba3 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -246,6 +246,10 @@ void StatementTest::TestSqlIngestInt64() { } void StatementTest::TestSqlIngestFloat16() { + if (!quirks()->supports_ingest_float16()) { + GTEST_SKIP(); + } + ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_HALF_FLOAT)); } @@ -268,6 +272,10 @@ void StatementTest::TestSqlIngestLargeString() { } void StatementTest::TestSqlIngestStringView() { + if (!quirks()->supports_ingest_view_types()) { + GTEST_SKIP(); + } + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", "δΎ‹"}, false)); @@ -302,6 +310,10 @@ void StatementTest::TestSqlIngestFixedSizeBinary() { } void StatementTest::TestSqlIngestBinaryView() { + if (!quirks()->supports_ingest_view_types()) { + GTEST_SKIP(); + } + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( NANOARROW_TYPE_LARGE_BINARY, {std::nullopt, std::vector{}, diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index 1fd1f658fe..574e390453 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -321,9 +321,9 @@ func toSnowflakeType(dt arrow.DataType) string { case arrow.DECIMAL, arrow.DECIMAL256: dec := dt.(arrow.DecimalType) return fmt.Sprintf("NUMERIC(%d,%d)", dec.GetPrecision(), dec.GetScale()) - case arrow.STRING, arrow.LARGE_STRING: + case arrow.STRING, arrow.LARGE_STRING, arrow.STRING_VIEW: return "text" - case arrow.BINARY, arrow.LARGE_BINARY: + case arrow.BINARY, arrow.LARGE_BINARY, arrow.BINARY_VIEW: return "binary" case arrow.FIXED_SIZE_BINARY: fsb := dt.(*arrow.FixedSizeBinaryType) From 4e9f6812a909685613921a926ce54492e3206dfa Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 17 Oct 2024 14:10:27 -0400 Subject: [PATCH 14/14] fix linting --- c/validation/adbc_validation_statement.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index 94cee1fba3..150aabf327 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -249,7 +249,7 @@ void StatementTest::TestSqlIngestFloat16() { if (!quirks()->supports_ingest_float16()) { GTEST_SKIP(); } - + ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_HALF_FLOAT)); }