Skip to content

Commit

Permalink
fix(c/driver): be explicit about columns in ingestion (apache#1238)
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm authored Nov 1, 2023
1 parent e6b6e83 commit f0ae519
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 15 deletions.
8 changes: 8 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
return ddl;
}

/// Produce the PostgreSQL CREATE TABLE statement for the primary-key
/// ingestion test table: a BIGSERIAL primary key column "id" plus a BIGINT
/// column "value".
std::optional<std::string> PrimaryKeyIngestTableDdl(
    std::string_view name) const override {
  std::string statement;
  statement.append("CREATE TABLE ")
      .append(name)
      .append(" (id BIGSERIAL PRIMARY KEY, value BIGINT)");
  return statement;
}

std::optional<std::string> CompositePrimaryKeyTableDdl(
std::string_view name) const override {
std::string ddl = "CREATE TABLE ";
Expand Down
26 changes: 20 additions & 6 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,8 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) {
AdbcStatusCode PostgresStatement::CreateBulkTable(
const std::string& current_schema, const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
std::string* escaped_table, struct AdbcError* error) {
std::string* escaped_table, std::string* escaped_field_list,
struct AdbcError* error) {
PGconn* conn = connection_->conn();

if (!ingest_.db_schema.empty() && ingest_.temporary) {
Expand Down Expand Up @@ -944,10 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(

switch (ingest_.mode) {
case IngestMode::kCreate:
case IngestMode::kAppend:
// Nothing to do
break;
case IngestMode::kAppend:
return ADBC_STATUS_OK;
case IngestMode::kReplace: {
std::string drop = "DROP TABLE IF EXISTS " + *escaped_table;
PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0,
Expand All @@ -972,7 +972,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += " (";

for (size_t i = 0; i < source_schema_fields.size(); i++) {
if (i > 0) create += ", ";
if (i > 0) {
create += ", ";
*escaped_field_list += ", ";
}

const char* unescaped = source_schema.children[i]->name;
char* escaped = PQescapeIdentifier(conn, unescaped, std::strlen(unescaped));
Expand All @@ -982,6 +985,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
return ADBC_STATUS_INTERNAL;
}
create += escaped;
*escaped_field_list += escaped;
PQfreemem(escaped);

switch (source_schema_fields[i].type) {
Expand Down Expand Up @@ -1034,6 +1038,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
}
}

if (ingest_.mode == IngestMode::kAppend) {
return ADBC_STATUS_OK;
}

create += ")";
SetError(error, "%s%s", "[libpq] ", create.c_str());
PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0,
Expand Down Expand Up @@ -1203,15 +1211,21 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
BindStream bind_stream(std::move(bind_));
std::memset(&bind_, 0, sizeof(bind_));
std::string escaped_table;
std::string escaped_field_list;
RAISE_ADBC(bind_stream.Begin(
[&]() -> AdbcStatusCode {
return CreateBulkTable(current_schema, bind_stream.bind_schema.value,
bind_stream.bind_schema_fields, &escaped_table, error);
bind_stream.bind_schema_fields, &escaped_table,
&escaped_field_list, error);
},
error));
RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));

std::string query = "COPY " + escaped_table + " FROM STDIN WITH (FORMAT binary)";
std::string query = "COPY ";
query += escaped_table;
query += " (";
query += escaped_field_list;
query += ") FROM STDIN WITH (FORMAT binary)";
PGresult* result = PQexec(connection_->conn(), query.c_str());
if (PQresultStatus(result) != PGRES_COPY_IN) {
AdbcStatusCode code =
Expand Down
3 changes: 2 additions & 1 deletion c/driver/postgresql/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class PostgresStatement {
AdbcStatusCode CreateBulkTable(
const std::string& current_schema, const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
std::string* escaped_table, struct AdbcError* error);
std::string* escaped_table, std::string* escaped_field_list,
struct AdbcError* error);
AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error);
AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error);
AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
Expand Down
40 changes: 32 additions & 8 deletions c/driver/sqlite/sqlite.c
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
goto cleanup;
}

sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table);
sqlite3_str_appendf(insert_query, "INSERT INTO %s (", table);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
Expand All @@ -1154,6 +1154,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

sqlite3_str_appendf(insert_query, "%s", ", ");
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendf(create_query, "\"%w\"", stmt->binder.schema.children[i]->name);
Expand All @@ -1163,6 +1171,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
goto cleanup;
}

sqlite3_str_appendf(insert_query, "\"%w\"", stmt->binder.schema.children[i]->name);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

int status =
ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error);
if (status != 0) {
Expand Down Expand Up @@ -1199,13 +1214,6 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
default:
break;
}

sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendchar(create_query, 1, ')');
Expand All @@ -1215,6 +1223,22 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
goto cleanup;
}

sqlite3_str_appendall(insert_query, ") VALUES (");
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

for (int i = 0; i < stmt->binder.schema.n_children; i++) {
sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendchar(insert_query, 1, ')');
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
Expand Down
8 changes: 8 additions & 0 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
return ddl;
}

/// Produce the SQLite CREATE TABLE statement for the primary-key ingestion
/// test table: an INTEGER PRIMARY KEY column "id" plus a BIGINT column
/// "value".
std::optional<std::string> PrimaryKeyIngestTableDdl(
    std::string_view name) const override {
  constexpr std::string_view kPrefix = "CREATE TABLE ";
  constexpr std::string_view kColumns = " (id INTEGER PRIMARY KEY, value BIGINT)";

  std::string statement;
  statement.reserve(kPrefix.size() + name.size() + kColumns.size());
  statement += kPrefix;
  statement += name;
  statement += kColumns;
  return statement;
}

std::optional<std::string> CompositePrimaryKeyTableDdl(
std::string_view name) const override {
std::string ddl = "CREATE TABLE ";
Expand Down
109 changes: 109 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2803,6 +2803,115 @@ void StatementTest::TestSqlIngestTemporaryExclusive() {
}
}

void StatementTest::TestSqlIngestPrimaryKey() {
std::string name = "pkeytest";
auto ddl = quirks()->PrimaryKeyIngestTableDdl(name);
if (!ddl) {
GTEST_SKIP();
}
ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error));

// Create table
{
Handle<struct AdbcStatement> statement;
StreamReader reader;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, ddl->c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
}

// Ingest without the primary key
{
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_THAT(MakeSchema(&schema.value, {{"value", NANOARROW_TYPE_INT64}}),
IsOkErrno());
ASSERT_THAT((MakeBatch<int64_t>(&schema.value, &array.value, &na_error,
{42, -42, std::nullopt})),
IsOkErrno());

Handle<struct AdbcStatement> statement;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
name.c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE,
ADBC_INGEST_OPTION_MODE_APPEND, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
}

// Ingest with the primary key
{
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_THAT(MakeSchema(&schema.value,
{
{"id", NANOARROW_TYPE_INT64},
{"value", NANOARROW_TYPE_INT64},
}),
IsOkErrno());
ASSERT_THAT((MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
{4, 5, 6}, {1, 0, -1})),
IsOkErrno());

Handle<struct AdbcStatement> statement;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
name.c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE,
ADBC_INGEST_OPTION_MODE_APPEND, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
}

// Get the data
{
Handle<struct AdbcStatement> statement;
StreamReader reader;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement.value, "SELECT * FROM pkeytest ORDER BY id ASC", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, nullptr,
&error),
IsOkStatus(&error));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_EQ(2, reader.schema->n_children);
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_NE(nullptr, reader.array->release);
ASSERT_EQ(6, reader.array->length);
ASSERT_EQ(2, reader.array->n_children);

// Different databases start numbering at 0 or 1 for the primary key
// column, so can't compare it
// TODO(https://github.com/apache/arrow-adbc/issues/938): if the test
// helpers converted data to plain C++ values we could do a more
// sophisticated assertion
ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[1],
{42, -42, std::nullopt, 1, 0, -1}));
}
}

void StatementTest::TestSqlPartitionedInts() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
Expand Down
15 changes: 15 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ class DriverQuirks {
return std::nullopt;
}

/// \brief Return the DDL that creates a table with a primary key, or
/// nullopt when the driver under test does not support one.
///
/// The table is used to test ingestion into a table whose primary key
/// auto-increments, so the ingested data should not have to contain the
/// key column.
///
/// Expected columns:
/// - "id": an auto-incrementing primary key compatible with int64
/// - "value": a column with Arrow type int64
///
/// This base implementation reports "not supported".
virtual std::optional<std::string> PrimaryKeyIngestTableDdl(
    std::string_view name) const {
  return {};  // empty optional, i.e. std::nullopt
}

/// \brief Get the statement to create a table with a composite primary key,
/// or nullopt if not supported.
///
Expand Down Expand Up @@ -347,6 +360,7 @@ class StatementTest {
void TestSqlIngestTemporaryAppend();
void TestSqlIngestTemporaryReplace();
void TestSqlIngestTemporaryExclusive();
void TestSqlIngestPrimaryKey();

void TestSqlPartitionedInts();

Expand Down Expand Up @@ -444,6 +458,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTemporaryAppend) { TestSqlIngestTemporaryAppend(); } \
TEST_F(FIXTURE, SqlIngestTemporaryReplace) { TestSqlIngestTemporaryReplace(); } \
TEST_F(FIXTURE, SqlIngestTemporaryExclusive) { TestSqlIngestTemporaryExclusive(); } \
TEST_F(FIXTURE, SqlIngestPrimaryKey) { TestSqlIngestPrimaryKey(); } \
TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); } \
TEST_F(FIXTURE, SqlPrepareGetParameterSchema) { TestSqlPrepareGetParameterSchema(); } \
TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams(); } \
Expand Down

0 comments on commit f0ae519

Please sign in to comment.