Finalize column rebinding
- Track encoding ctypes and trigger rebinding on changes
- Prevent losing rows when switching from fetchmany to fetchone and then triggering a rebind (see the sketch below)
- Add tests for the above
- Avoid unnecessary dict copy when checking if rebind is necessary
- Minor improvements to diagnostic variables and error handling
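
A minimal sketch of the fetchmany/fetchone scenario, mirroring the new test below (this assumes a build with these changes; CONN_STR is a placeholder connection string):

    import pyodbc

    cursor = pyodbc.connect(CONN_STR).cursor()
    cursor.execute("select 1 union select 2 union select 3 union select 4")

    assert [r[0] for r in cursor.fetchmany(2)] == [1, 2]  # one SQLFetch, row array size 2
    assert cursor.fetchone()[0] == 3  # SQLFetch capped at 1 row, although 2 are allocated

    # Toggling any setting that affects binding (native_uuid, output converters,
    # the SQL_CHAR/SQL_WCHAR decoding ctypes) forces a rebind before the next fetch.
    pyodbc.native_uuid = not pyodbc.native_uuid
    assert cursor.fetchone()[0] == 4  # previously, buffered-but-unread rows could be dropped
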
ffelixg committed May 18, 2024
1 parent ff0423f commit b2be6a9
Showing 3 changed files with 105 additions and 70 deletions.
99 changes: 54 additions & 45 deletions src/cursor.cpp
@@ -371,15 +371,6 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap)
}
}

if (cur->bind_byte_cap < 0) {
PyErr_SetString(ProgrammingError, "Cursor attribute bind_byte_cap must be non negative.");
return 0;
}
if (cur->bind_cell_cap < 0) {
PyErr_SetString(ProgrammingError, "Cursor attribute bind_cell_cap must be non negative.");
return 0;
}

// number of rows to be fetched at a time
cur->fetch_buffer_width = total_buf_size;
if (bind_all) {
@@ -396,14 +387,15 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap)
} else {
cur->fetch_buffer_length = 1;
}
cur->fetch_buffer_length_used = 0;

// single large buffer using row-wise layout with row status array at the end
void* buf = PyMem_Malloc((cur->fetch_buffer_width + sizeof(SQLUSMALLINT)) * cur->fetch_buffer_length);
if (!buf) {
cur->fetch_buffer = PyMem_Malloc((cur->fetch_buffer_width + sizeof(SQLUSMALLINT)) * cur->fetch_buffer_length);
if (!cur->fetch_buffer) {
PyErr_NoMemory();
return false;
}
cur->row_status_array = (SQLUSMALLINT*)((uintptr_t)buf + cur->fetch_buffer_width * cur->fetch_buffer_length);
cur->row_status_array = (SQLUSMALLINT*)((uintptr_t)cur->fetch_buffer + cur->fetch_buffer_width * cur->fetch_buffer_length);
cur->current_row = 0;

SQLRETURN ret_bind = SQL_SUCCESS, ret_attr;
@@ -417,10 +409,6 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap)
if (!SQL_SUCCEEDED(ret_attr)) {
goto skip;
}
ret_attr = SQLSetStmtAttr(cur->hstmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)cur->fetch_buffer_length, 0);
if (!SQL_SUCCEEDED(ret_attr)) {
goto skip;
}
// TODO use for error checking
ret_attr = SQLSetStmtAttr(cur->hstmt, SQL_ATTR_ROW_STATUS_PTR, cur->row_status_array, 0);
if (!SQL_SUCCEEDED(ret_attr)) {
@@ -440,9 +428,9 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap)
cur->hstmt,
(SQLUSMALLINT)(iCol+1),
cInfo->c_type,
(void*)((uintptr_t)buf + cInfo->buf_offset),
(void*)((uintptr_t)cur->fetch_buffer + cInfo->buf_offset),
cInfo->buf_size,
(SQLLEN*)((uintptr_t)buf + cInfo->buf_offset - sizeof(SQLLEN))
(SQLLEN*)((uintptr_t)cur->fetch_buffer + cInfo->buf_offset - sizeof(SQLLEN))
);
if (!SQL_SUCCEEDED(ret_bind)) {
break;
@@ -461,13 +449,9 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap)
}

if (!SQL_SUCCEEDED(ret_bind)) {
PyMem_Free(buf);
return RaiseErrorFromHandle(cur->cnxn, "SQLBindCol", cur->cnxn->hdbc, cur->hstmt);
}

assert(cur->fetch_buffer == 0);
cur->fetch_buffer = buf;

return true;
}

@@ -477,29 +461,21 @@ static bool PrepareFetch(Cursor* cur, int n_rows)
// Returns false on exception, true otherwise.
// Need to do this because the API allows changing this after executing a statement.

PyObject* converted_types = 0;
bool native_uuid = UseNativeUUID();
bool converted_types_changed = false;

if (cur->cnxn->map_sqltype_to_converter)
{
converted_types = PyDict_Copy(cur->cnxn->map_sqltype_to_converter);
if (!converted_types)
{
return false;
}
if (cur->bound_converted_types)
{
switch (PyObject_RichCompareBool(cur->bound_converted_types, converted_types, Py_EQ))
switch (PyObject_RichCompareBool(cur->bound_converted_types, cur->cnxn->map_sqltype_to_converter, Py_EQ))
{
case -1: // error
Py_DECREF(converted_types);
return false;
case 0: // not equal
converted_types_changed = true;
break;
case 1: // equal
Py_DECREF(converted_types);
break;
}
}
@@ -513,18 +489,30 @@ static bool PrepareFetch(Cursor* cur, int n_rows)
converted_types_changed = true;
}

if (cur->bound_native_uuid != native_uuid || converted_types_changed || !cur->fetch_buffer)
if (
!cur->fetch_buffer || cur->bound_native_uuid != native_uuid || converted_types_changed ||
cur->ctype_of_char_enc != cur->cnxn->sqlchar_enc.ctype || cur->ctype_of_wchar_enc != cur->cnxn->sqlwchar_enc.ctype
)
{
Py_XDECREF(cur->bound_converted_types);
Py_CLEAR(cur->bound_converted_types);
PyObject* converted_types = 0;
if (cur->cnxn->map_sqltype_to_converter) {
converted_types = PyDict_Copy(cur->cnxn->map_sqltype_to_converter);
if (!converted_types)
return false;
}
cur->bound_converted_types = converted_types;
cur->bound_native_uuid = native_uuid;
cur->ctype_of_char_enc = cur->cnxn->sqlchar_enc.ctype;
cur->ctype_of_wchar_enc = cur->cnxn->sqlwchar_enc.ctype;

if (cur->description != Py_None)
{
int cCols = PyTuple_GET_SIZE(cur->description);
BindColsFree(cur);
if (!BindCols(cur, cCols, n_rows))
{
BindColsFree(cur);
return false;
}
}
@@ -552,7 +540,7 @@ static bool free_results(Cursor* self, int flags)
}

BindColsFree(self);
Py_XDECREF(self->bound_converted_types);
Py_CLEAR(self->bound_converted_types);

if (self->colinfos)
{
@@ -1378,7 +1366,7 @@ static PyObject* Cursor_setinputsizes(PyObject* self, PyObject* sizes)
Py_RETURN_NONE;
}

static PyObject* Cursor_fetch(Cursor* cur)
static PyObject* Cursor_fetch(Cursor* cur, Py_ssize_t max)
{
// Internal function to fetch a single row and construct a Row object from it. Used by all of the fetching
// functions.
@@ -1387,15 +1375,33 @@ static PyObject* Cursor_fetch(Cursor* cur)
// exception is set and zero is returned. (To differentiate between the last two, use PyErr_Occurred.)

SQLRETURN ret = 0;
SQLRETURN ret_attr = 0;
Py_ssize_t field_count, i;
PyObject** apValues;

// One fetch per cycle.
if (cur->current_row == 0) {
Py_BEGIN_ALLOW_THREADS
ret = SQLFetch(cur->hstmt);
// Make sure no more rows are fetched than are requested by fetchone/fetchmany.
// Otherwise rows might get lost if buffers need to be rebound between fetches.
long fetch_buffer_length_used;
if (max >= 0 && (long)max < cur->fetch_buffer_length) {
fetch_buffer_length_used = (long)max;
} else {
fetch_buffer_length_used = cur->fetch_buffer_length;
}
if (cur->fetch_buffer_length_used != fetch_buffer_length_used) {
cur->fetch_buffer_length_used = fetch_buffer_length_used;
ret_attr = SQLSetStmtAttr(cur->hstmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)cur->fetch_buffer_length_used, 0);
}
if (SQL_SUCCEEDED(ret_attr))
ret = SQLFetch(cur->hstmt);
Py_END_ALLOW_THREADS

if (!SQL_SUCCEEDED(ret_attr)) {
return RaiseErrorFromHandle(cur->cnxn, "SQLSetStmtAttr", cur->cnxn->hdbc, cur->hstmt);
}

if (cur->cnxn->hdbc == SQL_NULL_HANDLE)
{
// The connection was closed by another thread in the ALLOW_THREADS block above.
@@ -1408,7 +1414,7 @@ static PyObject* Cursor_fetch(Cursor* cur)
if (!SQL_SUCCEEDED(ret))
return RaiseErrorFromHandle(cur->cnxn, "SQLFetch", cur->cnxn->hdbc, cur->hstmt);
} else {
if (cur->current_row >= cur->rows_fetched) {
if (cur->current_row >= (long)cur->rows_fetched) {
return 0;
}
}
@@ -1432,7 +1438,7 @@

apValues[i] = value;
}
cur->current_row = (cur->current_row + 1) % cur->fetch_buffer_length;
cur->current_row = (cur->current_row + 1) % cur->fetch_buffer_length_used;

return (PyObject*)Row_InternalNew(cur->description, cur->map_name_to_index, field_count, apValues);
}
@@ -1454,7 +1460,7 @@ static PyObject* Cursor_fetchlist(Cursor* cur, Py_ssize_t max)

while (max == -1 || max > 0)
{
row = Cursor_fetch(cur);
row = Cursor_fetch(cur, max);

if (!row)
{
Expand Down Expand Up @@ -1496,7 +1502,7 @@ static PyObject* Cursor_iternext(PyObject* self)
if (!cursor || !PrepareFetch(cursor, 1))
return 0;

result = Cursor_fetch(cursor);
result = Cursor_fetch(cursor, 1);

return result;
}
@@ -1509,7 +1515,7 @@ static PyObject* Cursor_fetchval(PyObject* self, PyObject* args)
if (!cursor || !PrepareFetch(cursor, 1))
return 0;

Object row(Cursor_fetch(cursor));
Object row(Cursor_fetch(cursor, 1));

if (!row)
{
@@ -1530,7 +1536,7 @@ static PyObject* Cursor_fetchone(PyObject* self, PyObject* args)
if (!cursor || !PrepareFetch(cursor, 1))
return 0;

row = Cursor_fetch(cursor);
row = Cursor_fetch(cursor, 1);

if (!row)
{
@@ -2472,8 +2478,8 @@ static PyMemberDef Cursor_members[] =
{"connection", T_OBJECT_EX, offsetof(Cursor, cnxn), READONLY, connection_doc },
{"fast_executemany",T_BOOL, offsetof(Cursor, fastexecmany), 0, fastexecmany_doc },
{"messages", T_OBJECT_EX, offsetof(Cursor, messages), READONLY, messages_doc },
{"bound_columns_count", T_UINT, offsetof(Cursor, bound_columns_count), READONLY, bound_columns_count_doc },
{"bound_buffer_rows", T_ULONG, offsetof(Cursor, fetch_buffer_length), READONLY, bound_buffer_rows_doc },
{"bound_columns_count", T_INT, offsetof(Cursor, bound_columns_count), READONLY, bound_columns_count_doc },
{"bound_buffer_rows", T_LONG, offsetof(Cursor, fetch_buffer_length), READONLY, bound_buffer_rows_doc },
{"bind_cell_cap", T_LONG, offsetof(Cursor, bind_cell_cap), 0, bind_cell_cap_doc },
{"bind_byte_cap", T_LONG, offsetof(Cursor, bind_byte_cap), 0, bind_byte_cap_doc },
{ 0 }
@@ -2764,9 +2770,12 @@ Cursor_New(Connection* cnxn)
cur->fetch_buffer = 0;
cur->bound_converted_types = 0;
cur->bound_native_uuid = 0;
cur->ctype_of_char_enc = SQL_C_CHAR;
cur->ctype_of_wchar_enc = SQL_C_WCHAR;
cur->bind_cell_cap = 10;
cur->bind_byte_cap = 20 * 1024 * 1024;
cur->bound_columns_count = 0;
cur->bound_columns_count = -1; // should indicate that it's not initialized yet
cur->fetch_buffer_length = -1;

Py_INCREF(cnxn);
Py_INCREF(cur->description);
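
The batching rule that Cursor_fetch now applies before each SQLFetch can be modeled in a few lines of Python. This is a simplified sketch, not the C implementation; set_row_array_size stands in for the SQLSetStmtAttr(SQL_ATTR_ROW_ARRAY_SIZE, ...) call:

    class FetchState:
        def __init__(self, buffer_rows: int):
            self.buffer_rows = buffer_rows  # cur->fetch_buffer_length
            self.rows_in_use = 0            # cur->fetch_buffer_length_used

        def before_sqlfetch(self, max_rows: int, set_row_array_size) -> None:
            # Cap the batch at the caller's request (-1 means uncapped, as in
            # fetchall) so a later rebind cannot strand already-fetched rows.
            wanted = max_rows if 0 <= max_rows < self.buffer_rows else self.buffer_rows
            if wanted != self.rows_in_use:  # skip redundant SQLSetStmtAttr calls
                self.rows_in_use = wanted
                set_row_array_size(wanted)
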
10 changes: 8 additions & 2 deletions src/cursor.h
@@ -40,6 +40,7 @@ struct ColumnInfo
bool is_bound;
bool can_bind;
bool always_alloc;
// No need to do refcounting on the converter, since at least cur->bound_converted_types will have one.
PyObject* converter;
TextEnc* enc;
};
@@ -171,15 +172,20 @@ struct Cursor
void* fetch_buffer;
long fetch_buffer_width;
long fetch_buffer_length;
long rows_fetched;
long fetch_buffer_length_used;
SQLULEN rows_fetched;
SQLUSMALLINT* row_status_array;
long current_row;

// Track the configuration at the time of using SQLBindCol.
bool bound_native_uuid;
PyObject* bound_converted_types;
// Only track the ctype of cur->cnxn->sql(w)char_enc. Changing any other attribute of the encoding
// would not change the binding process.
SQLSMALLINT ctype_of_wchar_enc;
SQLSMALLINT ctype_of_char_enc;

unsigned int bound_columns_count;
int bound_columns_count;
long bind_cell_cap;
long bind_byte_cap;
};
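
Taken together, the fields above let PrepareFetch decide cheaply whether rebinding is required. A Python paraphrase of that condition (attribute names flattened for readability; sqlchar_ctype/sqlwchar_ctype stand for cur->cnxn->sqlchar_enc.ctype and sqlwchar_enc.ctype):

    def needs_rebind(cur, native_uuid: bool, converters_changed: bool,
                     sqlchar_ctype: int, sqlwchar_ctype: int) -> bool:
        # Rebind when nothing is bound yet, or when any setting captured at
        # SQLBindCol time has changed since the last bind.
        return (cur.fetch_buffer is None
                or cur.bound_native_uuid != native_uuid
                or converters_changed
                or cur.ctype_of_char_enc != sqlchar_ctype
                or cur.ctype_of_wchar_enc != sqlwchar_ctype)
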
66 changes: 43 additions & 23 deletions tests/sqlserver_test.py
@@ -1258,36 +1258,56 @@ def convert(value):
return value

cnxn = connect()
cursor = cnxn.cursor()

uidstr = 'CB4BB7F2-3AD9-4ED7-ABB8-7C704D75335C'
uid = uuid.UUID(uidstr)
uidbytes = b'\xf2\xb7K\xcb\xd9:\xd7N\xab\xb8|pMu3\\'

cursor.execute("drop table if exists t1")
cursor.execute("create table t1(g uniqueidentifier)")
for i in range(4):
cursor.execute(f"insert into t1 values (?)", (uid,))

cursor.execute("select g from t1")

pyodbc.native_uuid = False
v, = cursor.fetchone()
assert v == uidstr

cnxn.add_output_converter(pyodbc.SQL_GUID, convert)
v, = cursor.fetchone()
assert v == uidbytes
cnxn.remove_output_converter(pyodbc.SQL_GUID)
with cnxn:
cursor = cnxn.cursor()
cursor.execute("drop table if exists t1")
cursor.execute("create table t1(g uniqueidentifier)")
for i in range(6):
cursor.execute("insert into t1 values (?)", (uid,))

cursor.execute("select g from t1")

pyodbc.native_uuid = False
v, = cursor.fetchone()
assert v == uidstr

cnxn.add_output_converter(pyodbc.SQL_GUID, convert)
v, = cursor.fetchone()
assert v == uidbytes
cnxn.remove_output_converter(pyodbc.SQL_GUID)

pyodbc.native_uuid = True
v, = cursor.fetchone()
assert v == uid

pyodbc.native_uuid = False
v, = cursor.fetchone()
assert v == uidstr

cnxn.setdecoding(pyodbc.SQL_CHAR, encoding='utf-16-le')
v, = cursor.fetchone() # fetches into SQL_C_WCHAR buffer
assert v == uidstr
cnxn.setdecoding(pyodbc.SQL_CHAR, encoding='cp1252')
v, = cursor.fetchone() # fetches into SQL_C_CHAR buffer
assert v == uidstr

cursor.close()
cursor = cnxn.cursor()
cursor.execute("select 1 union select 2 union select 3 union select 4")

pyodbc.native_uuid = True
v, = cursor.fetchone()
assert v == uid
cursor.fetchmany(2) # should fetch 2 rows per SQLFetch call with rows left over
v, = cursor.fetchone() # even though 2 rows are allocated, only one should be used so we can rebind without discarding data
assert v == 3
pyodbc.native_uuid = not pyodbc.native_uuid # force rebind
v, = cursor.fetchone()
assert v == 4

pyodbc.native_uuid = False
v, = cursor.fetchone()
assert v == uidstr
pyodbc.native_uuid = True
pyodbc.native_uuid = True


def test_too_large(cursor: pyodbc.Cursor):