Skip to content

Commit

Permalink
Add option to disable memoryview cast
Browse files Browse the repository at this point in the history
  • Loading branch information
calebj committed Jan 8, 2025
1 parent 0a0450f commit ec72f5f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
15 changes: 8 additions & 7 deletions geoalchemy2/admin/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,15 @@ def reflect_geometry_column(inspector, table, column_info):
)


def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany, **kwargs):
"""Event handler to cast the parameters properly."""
if isinstance(parameters, (tuple, list)):
parameters = tuple(x.tobytes() if isinstance(x, memoryview) else x for x in parameters)
elif isinstance(parameters, dict):
for k in parameters:
if isinstance(parameters[k], memoryview):
parameters[k] = parameters[k].tobytes()
if kwargs.get("convert_memoryview", True):
if isinstance(parameters, (tuple, list)):
parameters = tuple(x.tobytes() if isinstance(x, memoryview) else x for x in parameters)
elif isinstance(parameters, dict):
for k in parameters:
if isinstance(parameters[k], memoryview):
parameters[k] = parameters[k].tobytes()

return statement, parameters

Expand Down
16 changes: 13 additions & 3 deletions geoalchemy2/admin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,19 @@ def __init__(self, url, kwargs):
if journal_mode is not None:
self.params["connect"]["sqlite"]["journal_mode"] = journal_mode

convert_memoryview = url.query.get("geoalchemy2_execute_mysql_convert_memoryview", None)
if convert_memoryview is not None:
self.params["before_cursor_execute"]["mysql"]["convert_memoryview"] = self.str_to_bool(convert_memoryview)

@staticmethod
def str_to_bool(string):
"""Cast string to bool."""
return True if str(string).lower() in ["true", "1", "yes"] else False
def str_to_bool(argument):
"""Cast argument to bool."""
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
return True
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
return False
raise ValueError(argument)

def update_url(self, url):
"""Update the URL to one that no longer includes specific parameters."""
Expand All @@ -57,6 +66,7 @@ def update_url(self, url):
"geoalchemy2_connect_sqlite_transaction",
"geoalchemy2_connect_sqlite_init_mode",
"geoalchemy2_connect_sqlite_journal_mode",
"geoalchemy2_execute_mysql_convert_memoryview",
],
)

Expand Down
14 changes: 13 additions & 1 deletion tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_geo_engine_init():
"geoalchemy2_connect_sqlite_transaction": "true",
"geoalchemy2_connect_sqlite_init_mode": "WGS84",
"geoalchemy2_connect_sqlite_journal_mode": "OFF",
"geoalchemy2_execute_mysql_convert_memoryview": "off",
},
)
plugin = GeoEngine(url, {})
Expand All @@ -28,6 +29,10 @@ def test_geo_engine_init():
"journal_mode": "OFF",
}

assert plugin.params["before_cursor_execute"]["mysql"] == {
"convert_memoryview": False,
}


@pytest.mark.parametrize(
"value,expected",
Expand All @@ -40,14 +45,19 @@ def test_geo_engine_init():
("False", False),
("0", False),
("no", False),
("anything_else", False),
],
)
def test_str_to_bool(value, expected):
"""Test string to boolean conversion."""
assert GeoEngine.str_to_bool(value) == expected


def test_invalid_str_to_bool():
"""Test unknown parameter in boolean conversion."""
with pytest.raises(ValueError):
GeoEngine.str_to_bool("anything_else")


def test_update_url():
"""Test URL parameter cleanup."""
url = URL.create(
Expand All @@ -56,6 +66,7 @@ def test_update_url():
"geoalchemy2_connect_sqlite_transaction": "true",
"geoalchemy2_connect_sqlite_init_mode": "WGS84",
"geoalchemy2_connect_sqlite_journal_mode": "OFF",
"geoalchemy2_execute_mysql_convert_memoryview": "yes",
"other_param": "value",
},
)
Expand All @@ -66,6 +77,7 @@ def test_update_url():
assert "geoalchemy2_connect_sqlite_transaction" not in updated_url.query
assert "geoalchemy2_connect_sqlite_init_mode" not in updated_url.query
assert "geoalchemy2_connect_sqlite_journal_mode" not in updated_url.query
assert "geoalchemy2_execute_mysql_convert_memoryview" not in updated_url.query

# Check that other parameters are preserved
assert updated_url.query["other_param"] == "value"

0 comments on commit ec72f5f

Please sign in to comment.