diff --git a/src/cursor.cpp b/src/cursor.cpp index 9a8f9916..e123f289 100644 --- a/src/cursor.cpp +++ b/src/cursor.cpp @@ -371,15 +371,6 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap) } } - if (cur->bind_byte_cap < 0) { - PyErr_SetString(ProgrammingError, "Cursor attribute bind_byte_cap must be non negative."); - return 0; - } - if (cur->bind_cell_cap < 0) { - PyErr_SetString(ProgrammingError, "Cursor attribute bind_cell_cap must be non negative."); - return 0; - } - // number of rows to be fetched at a time cur->fetch_buffer_width = total_buf_size; if (bind_all) { @@ -396,14 +387,15 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap) } else { cur->fetch_buffer_length = 1; } + cur->fetch_buffer_length_used = 0; // single large buffer using row-wise layout with row status array at the end - void* buf = PyMem_Malloc((cur->fetch_buffer_width + sizeof(SQLUSMALLINT)) * cur->fetch_buffer_length); - if (!buf) { + cur->fetch_buffer = PyMem_Malloc((cur->fetch_buffer_width + sizeof(SQLUSMALLINT)) * cur->fetch_buffer_length); + if (!cur->fetch_buffer) { PyErr_NoMemory(); return false; } - cur->row_status_array = (SQLUSMALLINT*)((uintptr_t)buf + cur->fetch_buffer_width * cur->fetch_buffer_length); + cur->row_status_array = (SQLUSMALLINT*)((uintptr_t)cur->fetch_buffer + cur->fetch_buffer_width * cur->fetch_buffer_length); cur->current_row = 0; SQLRETURN ret_bind = SQL_SUCCESS, ret_attr; @@ -417,10 +409,6 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap) if (!SQL_SUCCEEDED(ret_attr)) { goto skip; } - ret_attr = SQLSetStmtAttr(cur->hstmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)cur->fetch_buffer_length, 0); - if (!SQL_SUCCEEDED(ret_attr)) { - goto skip; - } // TODO use for error checking ret_attr = SQLSetStmtAttr(cur->hstmt, SQL_ATTR_ROW_STATUS_PTR, cur->row_status_array, 0); if (!SQL_SUCCEEDED(ret_attr)) { @@ -440,9 +428,9 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap) cur->hstmt, (SQLUSMALLINT)(iCol+1), cInfo->c_type, - (void*)((uintptr_t)buf + cInfo->buf_offset), + (void*)((uintptr_t)cur->fetch_buffer + cInfo->buf_offset), cInfo->buf_size, - (SQLLEN*)((uintptr_t)buf + cInfo->buf_offset - sizeof(SQLLEN)) + (SQLLEN*)((uintptr_t)cur->fetch_buffer + cInfo->buf_offset - sizeof(SQLLEN)) ); if (!SQL_SUCCEEDED(ret_bind)) { break; @@ -461,13 +449,9 @@ static bool BindCols(Cursor* cur, int cCols, int fetch_rows_cap) } if (!SQL_SUCCEEDED(ret_bind)) { - PyMem_Free(buf); return RaiseErrorFromHandle(cur->cnxn, "SQLBindCol", cur->cnxn->hdbc, cur->hstmt); } - assert(cur->fetch_buffer == 0); - cur->fetch_buffer = buf; - return true; } @@ -477,29 +461,21 @@ static bool PrepareFetch(Cursor* cur, int n_rows) // Returns false on exception, true otherwise. // Need to do this because the API allows changing this after executing a statement. - PyObject* converted_types = 0; bool native_uuid = UseNativeUUID(); bool converted_types_changed = false; if (cur->cnxn->map_sqltype_to_converter) { - converted_types = PyDict_Copy(cur->cnxn->map_sqltype_to_converter); - if (!converted_types) - { - return false; - } if (cur->bound_converted_types) { - switch (PyObject_RichCompareBool(cur->bound_converted_types, converted_types, Py_EQ)) + switch (PyObject_RichCompareBool(cur->bound_converted_types, cur->cnxn->map_sqltype_to_converter, Py_EQ)) { case -1: // error - Py_DECREF(converted_types); return false; case 0: // not equal converted_types_changed = true; break; case 1: // equal - Py_DECREF(converted_types); break; } } @@ -513,11 +489,22 @@ static bool PrepareFetch(Cursor* cur, int n_rows) converted_types_changed = true; } - if (cur->bound_native_uuid != native_uuid || converted_types_changed || !cur->fetch_buffer) + if ( + !cur->fetch_buffer || cur->bound_native_uuid != native_uuid || converted_types_changed || + cur->ctype_of_char_enc != cur->cnxn->sqlchar_enc.ctype || cur->ctype_of_wchar_enc != cur->cnxn->sqlwchar_enc.ctype + ) { - Py_XDECREF(cur->bound_converted_types); + Py_CLEAR(cur->bound_converted_types); + PyObject* converted_types = 0; + if (cur->cnxn->map_sqltype_to_converter) { + converted_types = PyDict_Copy(cur->cnxn->map_sqltype_to_converter); + if (!converted_types) + return false; + } cur->bound_converted_types = converted_types; cur->bound_native_uuid = native_uuid; + cur->ctype_of_char_enc = cur->cnxn->sqlchar_enc.ctype; + cur->ctype_of_wchar_enc = cur->cnxn->sqlwchar_enc.ctype; if (cur->description != Py_None) { @@ -525,6 +512,7 @@ static bool PrepareFetch(Cursor* cur, int n_rows) BindColsFree(cur); if (!BindCols(cur, cCols, n_rows)) { + BindColsFree(cur); return false; } } @@ -552,7 +540,7 @@ static bool free_results(Cursor* self, int flags) } BindColsFree(self); - Py_XDECREF(self->bound_converted_types); + Py_CLEAR(self->bound_converted_types); if (self->colinfos) { @@ -1378,7 +1366,7 @@ static PyObject* Cursor_setinputsizes(PyObject* self, PyObject* sizes) Py_RETURN_NONE; } -static PyObject* Cursor_fetch(Cursor* cur) +static PyObject* Cursor_fetch(Cursor* cur, Py_ssize_t max) { // Internal function to fetch a single row and construct a Row object from it. Used by all of the fetching // functions. @@ -1387,15 +1375,33 @@ static PyObject* Cursor_fetch(Cursor* cur) // exception is set and zero is returned. (To differentiate between the last two, use PyErr_Occurred.) SQLRETURN ret = 0; + SQLRETURN ret_attr = 0; Py_ssize_t field_count, i; PyObject** apValues; // One fetch per cycle. if (cur->current_row == 0) { Py_BEGIN_ALLOW_THREADS - ret = SQLFetch(cur->hstmt); + // Make sure no more rows are fetched than are requested by fetchone/fetchmany. + // Otherwise rows might get lost if buffers need to be rebound between fetches. + long fetch_buffer_length_used; + if (max >= 0 && (long)max < cur->fetch_buffer_length) { + fetch_buffer_length_used = (long)max; + } else { + fetch_buffer_length_used = cur->fetch_buffer_length; + } + if (cur->fetch_buffer_length_used != fetch_buffer_length_used) { + cur->fetch_buffer_length_used = fetch_buffer_length_used; + ret_attr = SQLSetStmtAttr(cur->hstmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)cur->fetch_buffer_length_used, 0); + } + if (SQL_SUCCEEDED(ret_attr)) + ret = SQLFetch(cur->hstmt); Py_END_ALLOW_THREADS + if (!SQL_SUCCEEDED(ret_attr)) { + return RaiseErrorFromHandle(cur->cnxn, "SQLSetStmtAttr", cur->cnxn->hdbc, cur->hstmt); + } + if (cur->cnxn->hdbc == SQL_NULL_HANDLE) { // The connection was closed by another thread in the ALLOW_THREADS block above. @@ -1408,7 +1414,7 @@ static PyObject* Cursor_fetch(Cursor* cur) if (!SQL_SUCCEEDED(ret)) return RaiseErrorFromHandle(cur->cnxn, "SQLFetch", cur->cnxn->hdbc, cur->hstmt); } else { - if (cur->current_row >= cur->rows_fetched) { + if (cur->current_row >= (long)cur->rows_fetched) { return 0; } } @@ -1432,7 +1438,7 @@ static PyObject* Cursor_fetch(Cursor* cur) apValues[i] = value; } - cur->current_row = (cur->current_row + 1) % cur->fetch_buffer_length; + cur->current_row = (cur->current_row + 1) % cur->fetch_buffer_length_used; return (PyObject*)Row_InternalNew(cur->description, cur->map_name_to_index, field_count, apValues); } @@ -1454,7 +1460,7 @@ static PyObject* Cursor_fetchlist(Cursor* cur, Py_ssize_t max) while (max == -1 || max > 0) { - row = Cursor_fetch(cur); + row = Cursor_fetch(cur, max); if (!row) { @@ -1496,7 +1502,7 @@ static PyObject* Cursor_iternext(PyObject* self) if (!cursor || !PrepareFetch(cursor, 1)) return 0; - result = Cursor_fetch(cursor); + result = Cursor_fetch(cursor, 1); return result; } @@ -1509,7 +1515,7 @@ static PyObject* Cursor_fetchval(PyObject* self, PyObject* args) if (!cursor || !PrepareFetch(cursor, 1)) return 0; - Object row(Cursor_fetch(cursor)); + Object row(Cursor_fetch(cursor, 1)); if (!row) { @@ -1530,7 +1536,7 @@ static PyObject* Cursor_fetchone(PyObject* self, PyObject* args) if (!cursor || !PrepareFetch(cursor, 1)) return 0; - row = Cursor_fetch(cursor); + row = Cursor_fetch(cursor, 1); if (!row) { @@ -2472,8 +2478,8 @@ static PyMemberDef Cursor_members[] = {"connection", T_OBJECT_EX, offsetof(Cursor, cnxn), READONLY, connection_doc }, {"fast_executemany",T_BOOL, offsetof(Cursor, fastexecmany), 0, fastexecmany_doc }, {"messages", T_OBJECT_EX, offsetof(Cursor, messages), READONLY, messages_doc }, - {"bound_columns_count", T_UINT, offsetof(Cursor, bound_columns_count), READONLY, bound_columns_count_doc }, - {"bound_buffer_rows", T_ULONG, offsetof(Cursor, fetch_buffer_length), READONLY, bound_buffer_rows_doc }, + {"bound_columns_count", T_INT, offsetof(Cursor, bound_columns_count), READONLY, bound_columns_count_doc }, + {"bound_buffer_rows", T_LONG, offsetof(Cursor, fetch_buffer_length), READONLY, bound_buffer_rows_doc }, {"bind_cell_cap", T_LONG, offsetof(Cursor, bind_cell_cap), 0, bind_cell_cap_doc }, {"bind_byte_cap", T_LONG, offsetof(Cursor, bind_byte_cap), 0, bind_byte_cap_doc }, { 0 } @@ -2764,9 +2770,12 @@ Cursor_New(Connection* cnxn) cur->fetch_buffer = 0; cur->bound_converted_types = 0; cur->bound_native_uuid = 0; + cur->ctype_of_char_enc = SQL_C_CHAR; + cur->ctype_of_wchar_enc = SQL_C_WCHAR; cur->bind_cell_cap = 10; cur->bind_byte_cap = 20 * 1024 * 1024; - cur->bound_columns_count = 0; + cur->bound_columns_count = -1; // should indicate that it's not initialized yet + cur->fetch_buffer_length = -1; Py_INCREF(cnxn); Py_INCREF(cur->description); diff --git a/src/cursor.h b/src/cursor.h index 3762cb61..9050225c 100644 --- a/src/cursor.h +++ b/src/cursor.h @@ -40,6 +40,7 @@ struct ColumnInfo bool is_bound; bool can_bind; bool always_alloc; + // No need to do refcounting no the converter, since at least cur->bound_converted_types will have one. PyObject* converter; TextEnc* enc; }; @@ -171,15 +172,20 @@ struct Cursor void* fetch_buffer; long fetch_buffer_width; long fetch_buffer_length; - long rows_fetched; + long fetch_buffer_length_used; + SQLULEN rows_fetched; SQLUSMALLINT* row_status_array; long current_row; // Track the configuration at the time of using SQLBindCol. bool bound_native_uuid; PyObject* bound_converted_types; + // Only track the ctype of cur->cnxn->sql(w)char_enc. Changing any other attribute of the encoding + // would not change the binding process. + SQLSMALLINT ctype_of_wchar_enc; + SQLSMALLINT ctype_of_char_enc; - unsigned int bound_columns_count; + int bound_columns_count; long bind_cell_cap; long bind_byte_cap; }; diff --git a/tests/sqlserver_test.py b/tests/sqlserver_test.py index 991f4fcb..e75cf683 100755 --- a/tests/sqlserver_test.py +++ b/tests/sqlserver_test.py @@ -1258,36 +1258,56 @@ def convert(value): return value cnxn = connect() - cursor = cnxn.cursor() uidstr = 'CB4BB7F2-3AD9-4ED7-ABB8-7C704D75335C' uid = uuid.UUID(uidstr) uidbytes = b'\xf2\xb7K\xcb\xd9:\xd7N\xab\xb8|pMu3\\' - cursor.execute("drop table if exists t1") - cursor.execute("create table t1(g uniqueidentifier)") - for i in range(4): - cursor.execute(f"insert into t1 values (?)", (uid,)) - - cursor.execute("select g from t1") - - pyodbc.native_uuid = False - v, = cursor.fetchone() - assert v == uidstr - - cnxn.add_output_converter(pyodbc.SQL_GUID, convert) - v, = cursor.fetchone() - assert v == uidbytes - cnxn.remove_output_converter(pyodbc.SQL_GUID) + with cnxn: + cursor = cnxn.cursor() + cursor.execute("drop table if exists t1") + cursor.execute("create table t1(g uniqueidentifier)") + for i in range(6): + cursor.execute("insert into t1 values (?)", (uid,)) + + cursor.execute("select g from t1") + + pyodbc.native_uuid = False + v, = cursor.fetchone() + assert v == uidstr + + cnxn.add_output_converter(pyodbc.SQL_GUID, convert) + v, = cursor.fetchone() + assert v == uidbytes + cnxn.remove_output_converter(pyodbc.SQL_GUID) + + pyodbc.native_uuid = True + v, = cursor.fetchone() + assert v == uid + + pyodbc.native_uuid = False + v, = cursor.fetchone() + assert v == uidstr + + cnxn.setdecoding(pyodbc.SQL_CHAR, encoding='utf-16-le') + v, = cursor.fetchone() # fetches into SQL_C_WCHAR buffer + assert v == uidstr + cnxn.setdecoding(pyodbc.SQL_CHAR, encoding='cp1252') + v, = cursor.fetchone() # fetches into SQL_C_CHAR buffer + assert v == uidstr + + cursor.close() + cursor = cnxn.cursor() + cursor.execute("select 1 union select 2 union select 3 union select 4") - pyodbc.native_uuid = True - v, = cursor.fetchone() - assert v == uid + cursor.fetchmany(2) # should fetch 2 rows per SQLFetch call with rows left over + v, = cursor.fetchone() # even though 2 rows are allocated, only one should be used so we can rebind witout discarding data + assert v == 3 + pyodbc.native_uuid = not pyodbc.native_uuid # force rebind + v, = cursor.fetchone() + assert v == 4 - pyodbc.native_uuid = False - v, = cursor.fetchone() - assert v == uidstr - pyodbc.native_uuid = True + pyodbc.native_uuid = True def test_too_large(cursor: pyodbc.Cursor):