Skip to content

Commit

Permalink
fix(db_api): use sqlparse to split DDL statements (#372)
Browse files Browse the repository at this point in the history
Instead of simple `str.split(";")` method use more smart `sqlparse` package to split DDL statements executed in a form:
```python
cursor.execute("""
    ddl_statement1;
    ddl_statement2;
    ddl_statement3;
""")
```
  • Loading branch information
Ilya Gurov authored Jun 22, 2021
1 parent b7b3c38 commit ed9e124
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
7 changes: 5 additions & 2 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Database cursor for Google Cloud Spanner DB-API."""

import sqlparse

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import AlreadyExists
from google.api_core.exceptions import FailedPrecondition
Expand Down Expand Up @@ -174,9 +176,10 @@ def execute(self, sql, args=None):
try:
classification = parse_utils.classify_stmt(sql)
if classification == parse_utils.STMT_DDL:
for ddl in sql.split(";"):
ddl = ddl.strip()
for ddl in sqlparse.split(sql):
if ddl:
if ddl[-1] == ";":
ddl = ddl[:-1]
self.connection._ddl_statements.append(ddl)
if self.connection.autocommit:
self.connection.run_prior_DDL_statements()
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def classify_stmt(query):

def parse_insert(insert_sql, params):
"""
Parse an INSERT statement an generate a list of tuples of the form:
Parse an INSERT statement and generate a list of tuples of the form:
[
(SQL, params_per_row1),
(SQL, params_per_row2),
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,13 @@ def test_ddls_with_semicolon(self):
EXP_DDLS = [
"CREATE TABLE table_name (row_id INT64) PRIMARY KEY ()",
"DROP INDEX index_name",
(
"CREATE TABLE papers ("
"\n id INT64,"
"\n authors ARRAY<STRING(100)>,"
'\n author_list STRING(MAX) AS (ARRAY_TO_STRING(authors, ";")) stored'
") PRIMARY KEY (id)"
),
"DROP TABLE table_name",
]

Expand All @@ -956,7 +963,12 @@ def test_ddls_with_semicolon(self):
cursor.execute(
"CREATE TABLE table_name (row_id INT64) PRIMARY KEY ();"
"DROP INDEX index_name;\n"
"DROP TABLE table_name;"
"CREATE TABLE papers ("
"\n id INT64,"
"\n authors ARRAY<STRING(100)>,"
'\n author_list STRING(MAX) AS (ARRAY_TO_STRING(authors, ";")) stored'
") PRIMARY KEY (id);"
"DROP TABLE table_name;",
)

self.assertEqual(connection._ddl_statements, EXP_DDLS)

0 comments on commit ed9e124

Please sign in to comment.