Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc/driver/snowflake): improve GetObjects performance and semantics #2254

Merged
merged 14 commits into from
Oct 17, 2024
1 change: 1 addition & 0 deletions c/driver/flightsql/sqlite_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks {
bool supports_get_objects() const override { return true; }
bool supports_partitioned_data() const override { return true; }
bool supports_dynamic_parameter_binding() const override { return true; }
std::string catalog() const { return "main"; }
};

class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest {
Expand Down
12 changes: 11 additions & 1 deletion c/driver/snowflake/snowflake_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
adbc_validation::Handle<struct AdbcStatement> statement;
CHECK_OK(AdbcStatementNew(connection, &statement.value, error));

std::string create = "CREATE TABLE \"";
std::string create = "CREATE OR REPLACE TABLE \"";
create += name;
create += "\" (int64s INT, strings TEXT)";
CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error));
Expand Down Expand Up @@ -131,7 +131,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
return NANOARROW_TYPE_DOUBLE;
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
case NANOARROW_TYPE_LIST:
case NANOARROW_TYPE_LARGE_LIST:
return NANOARROW_TYPE_STRING;
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_LARGE_BINARY:
case NANOARROW_TYPE_FIXED_SIZE_BINARY:
return NANOARROW_TYPE_BINARY;
default:
return ingest_type;
}
Expand All @@ -149,7 +155,11 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
bool supports_ingest_view_types() const override { return false; }
bool supports_ingest_float16() const override { return false; }

std::string db_schema() const override { return schema_; }
std::string catalog() const override { return "ADBC_TESTING"; }

const char* uri_;
bool skip_{false};
Expand Down
6 changes: 6 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ class DriverQuirks {
/// column matching.
virtual bool supports_error_on_incompatible_schema() const { return true; }

/// \brief Whether ingestion supports StringView/BinaryView types
virtual bool supports_ingest_view_types() const { return true; }

/// \brief Whether ingestion supports Float16
virtual bool supports_ingest_float16() const { return true; }

/// \brief Default catalog to use for tests
virtual std::string catalog() const { return ""; }

Expand Down
27 changes: 15 additions & 12 deletions c/validation/adbc_validation_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,27 +744,30 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {

struct TestCase {
std::optional<std::string> filter;
std::vector<std::string> column_names;
std::vector<int32_t> ordinal_positions;
// the pair is column name & ordinal position of the column
std::vector<std::pair<std::string, int32_t>> columns;
zeroshade marked this conversation as resolved.
Show resolved Hide resolved
};

std::vector<TestCase> test_cases;
test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}});
test_cases.push_back({"in%", {"int64s"}, {1}});
test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}});
test_cases.push_back({"in%", {{"int64s", 1}}});

const std::string catalog = quirks()->catalog();

for (const auto& test_case : test_cases) {
std::string scope = "Filter: ";
scope += test_case.filter ? *test_case.filter : "(no filter)";
SCOPED_TRACE(scope);

StreamReader reader;
std::vector<std::pair<std::string, int32_t>> columns;
std::vector<std::string> column_names;
std::vector<int32_t> ordinal_positions;

ASSERT_THAT(
AdbcConnectionGetObjects(
&connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr,
test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
&connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr,
nullptr, test_case.filter.has_value() ? test_case.filter->c_str() : nullptr,
&reader.stream.value, &error),
IsOkStatus(&error));
ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
Expand Down Expand Up @@ -834,10 +837,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
std::string temp(name.data, name.size_bytes);
std::transform(temp.begin(), temp.end(), temp.begin(),
[](unsigned char c) { return std::tolower(c); });
column_names.push_back(std::move(temp));
ordinal_positions.push_back(
static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
table_columns->children[1], columns_index)));
columns.emplace_back(std::move(temp),
static_cast<int32_t>(ArrowArrayViewGetIntUnsafe(
table_columns->children[1], columns_index)));
}
}
}
Expand All @@ -847,8 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() {
} while (reader.array->release);

ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata";
ASSERT_EQ(test_case.column_names, column_names);
ASSERT_EQ(test_case.ordinal_positions, ordinal_positions);
// Metadata columns are not guaranteed to be returned in any particular
// order; just validate that all the expected elements are present.
ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns));
}
}

Expand Down
16 changes: 14 additions & 2 deletions c/validation/adbc_validation_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ void StatementTest::TestSqlIngestInt64() {
}

// Verify ingestion of Arrow half-float (float16) data, skipping drivers
// whose backend has no float16 support (per the driver's quirks object).
void StatementTest::TestSqlIngestFloat16() {
  // Guard clause: bail out early on drivers that cannot ingest float16.
  if (!quirks()->supports_ingest_float16()) GTEST_SKIP();

  ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<float>(NANOARROW_TYPE_HALF_FLOAT));
}

Expand All @@ -268,6 +272,10 @@ void StatementTest::TestSqlIngestLargeString() {
}

void StatementTest::TestSqlIngestStringView() {
if (!quirks()->supports_ingest_view_types()) {
GTEST_SKIP();
}

ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", ""},
false));
Expand Down Expand Up @@ -302,6 +310,10 @@ void StatementTest::TestSqlIngestFixedSizeBinary() {
}

void StatementTest::TestSqlIngestBinaryView() {
if (!quirks()->supports_ingest_view_types()) {
GTEST_SKIP();
}

ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
NANOARROW_TYPE_LARGE_BINARY,
{std::nullopt, std::vector<std::byte>{},
Expand Down Expand Up @@ -2218,15 +2230,15 @@ void StatementTest::TestSqlBind() {

ASSERT_THAT(
AdbcStatementSetSqlQuery(
&statement, "SELECT * FROM bindtest ORDER BY \"col1\" ASC NULLS FIRST", &error),
&statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we perhaps need a quirk for escaping column names?

(I also wouldn't be opposed to trying to make these tests more data-driven...I should go find time to sketch it out)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more about consistency. Our CREATE TABLE query earlier in this function doesn't quote the column names, so our select statement needs to also not quote the names. Almost everywhere else we quote the columns. We just need to be consistent.

That said, I agree it would be awesome for these tests to be more data-driven.

IsOkStatus(&error));
{
StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1)));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_NO_FATAL_FAILURE(reader.Next());
Expand Down
51 changes: 32 additions & 19 deletions go/adbc/driver/internal/driverbase/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,17 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,

bufferSize := len(catalogs)
addCatalogCh := make(chan GetObjectsInfo, bufferSize)
for _, cat := range catalogs {
addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
}

close(addCatalogCh)
errCh := make(chan error, 1)
go func() {
defer close(addCatalogCh)
for _, cat := range catalogs {
addCatalogCh <- GetObjectsInfo{CatalogName: Nullable(cat)}
}
}()

if depth == adbc.ObjectDepthCatalogs {
return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh)
close(errCh)
return BuildGetObjectsRecordReader(cnxn.Base().Alloc, addCatalogCh, errCh)
}

g, ctxG := errgroup.WithContext(ctx)
Expand Down Expand Up @@ -386,7 +389,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,
g.Go(func() error { defer close(addDbSchemasCh); return gSchemas.Wait() })

if depth == adbc.ObjectDepthDBSchemas {
rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh)
rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addDbSchemasCh, errCh)
return rdr, errors.Join(err, g.Wait())
}

Expand Down Expand Up @@ -432,7 +435,7 @@ func (cnxn *connection) GetObjects(ctx context.Context, depth adbc.ObjectDepth,

g.Go(func() error { defer close(addTablesCh); return gTables.Wait() })

rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh)
rdr, err := BuildGetObjectsRecordReader(cnxn.Base().Alloc, addTablesCh, errCh)
return rdr, errors.Join(err, g.Wait())
}

Expand Down Expand Up @@ -621,20 +624,20 @@ type ColumnInfo struct {
type TableInfo struct {
TableName string `json:"table_name"`
TableType string `json:"table_type"`
TableColumns []ColumnInfo `json:"table_columns,omitempty"`
TableConstraints []ConstraintInfo `json:"table_constraints,omitempty"`
TableColumns []ColumnInfo `json:"table_columns"`
TableConstraints []ConstraintInfo `json:"table_constraints"`
}

// DBSchemaInfo is a structured representation of adbc.DBSchemaSchema
type DBSchemaInfo struct {
DbSchemaName *string `json:"db_schema_name,omitempty"`
DbSchemaTables []TableInfo `json:"db_schema_tables,omitempty"`
DbSchemaTables []TableInfo `json:"db_schema_tables"`
}

// GetObjectsInfo is a structured representation of adbc.GetObjectsSchema
type GetObjectsInfo struct {
CatalogName *string `json:"catalog_name,omitempty"`
CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas,omitempty"`
CatalogDbSchemas []DBSchemaInfo `json:"catalog_db_schemas"`
}

// Scan implements sql.Scanner.
Expand All @@ -659,23 +662,33 @@ func (g *GetObjectsInfo) Scan(src any) error {
// BuildGetObjectsRecordReader constructs a RecordReader for the GetObjects ADBC method.
// It accepts a channel of GetObjectsInfo to allow concurrent retrieval of metadata and
// serialization to Arrow record.
func BuildGetObjectsRecordReader(mem memory.Allocator, in chan GetObjectsInfo) (array.RecordReader, error) {
func BuildGetObjectsRecordReader(mem memory.Allocator, in <-chan GetObjectsInfo, errCh <-chan error) (array.RecordReader, error) {
bldr := array.NewRecordBuilder(mem, adbc.GetObjectsSchema)
defer bldr.Release()

for catalog := range in {
b, err := json.Marshal(catalog)
if err != nil {
return nil, err
}
CATALOGLOOP:
for {
select {
case catalog, ok := <-in:
if !ok {
break CATALOGLOOP
}
b, err := json.Marshal(catalog)
if err != nil {
return nil, err
}

if err := json.Unmarshal(b, bldr); err != nil {
if err := json.Unmarshal(b, bldr); err != nil {
return nil, err
}
case err := <-errCh:
return nil, err
}
}

rec := bldr.NewRecord()
defer rec.Release()

return array.NewRecordReader(adbc.GetObjectsSchema, []arrow.Record{rec})
}

Expand Down
Loading
Loading