Skip to content

Commit

Permalink
Only reuse the pointer object when it matches the _type_ of the con…
Browse files Browse the repository at this point in the history
…tainer

Closes python#107940. Also, solves a related yet undiscovered issue where an array of
pointers reuses the array's memory for the pointer objects.
  • Loading branch information
ambv committed Aug 24, 2023
1 parent a071ecb commit 16d39d2
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 31 deletions.
71 changes: 68 additions & 3 deletions Lib/test/test_ctypes/test_cast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
import unittest
from ctypes import (Structure, Union, POINTER, cast, sizeof, addressof,
c_void_p, c_char_p, c_wchar_p,
c_byte, c_short, c_int)
from ctypes import (Structure, Union, pointer, POINTER, sizeof, addressof,
c_void_p, c_char_p, c_wchar_p, cast,
c_byte, c_short, c_int, c_int16)


class Test(unittest.TestCase):
Expand Down Expand Up @@ -95,6 +95,71 @@ class MyUnion(Union):
_fields_ = [("a", c_int)]
self.assertRaises(TypeError, cast, array, MyUnion)

def test_pointer_identity(self):
class Struct(Structure):
_fields_ = [('a', c_int16)]
Struct3 = 3 * Struct
c_array = (2 * Struct3)(
Struct3(Struct(a=1), Struct(a=2), Struct(a=3)),
Struct3(Struct(a=4), Struct(a=5), Struct(a=6))
)
self.assertEqual(c_array[0][0].a, 1)
self.assertEqual(c_array[0][1].a, 2)
self.assertEqual(c_array[0][2].a, 3)
self.assertEqual(c_array[1][0].a, 4)
self.assertEqual(c_array[1][1].a, 5)
self.assertEqual(c_array[1][2].a, 6)
p_obj = cast(pointer(c_array), POINTER(pointer(c_array)._type_))
obj = p_obj.contents
self.assertEqual(obj[0][0].a, 1)
self.assertEqual(obj[0][1].a, 2)
self.assertEqual(obj[0][2].a, 3)
self.assertEqual(obj[1][0].a, 4)
self.assertEqual(obj[1][1].a, 5)
self.assertEqual(obj[1][2].a, 6)
p_obj = cast(pointer(c_array[0]), POINTER(pointer(c_array)._type_))
obj = p_obj.contents
self.assertEqual(obj[0][0].a, 1)
self.assertEqual(obj[0][1].a, 2)
self.assertEqual(obj[0][2].a, 3)
self.assertEqual(obj[1][0].a, 4)
self.assertEqual(obj[1][1].a, 5)
self.assertEqual(obj[1][2].a, 6)
StructPointer = POINTER(Struct)
s1 = Struct(a=10)
s2 = Struct(a=20)
s3 = Struct(a=30)
pointer_array = (3 * StructPointer)(pointer(s1), pointer(s2), pointer(s3))
self.assertEqual(pointer_array[0][0].a, 10)
self.assertEqual(pointer_array[1][0].a, 20)
self.assertEqual(pointer_array[2][0].a, 30)
self.assertEqual(pointer_array[0].contents.a, 10)
self.assertEqual(pointer_array[1].contents.a, 20)
self.assertEqual(pointer_array[2].contents.a, 30)
p_obj = cast(pointer(pointer_array[0]), POINTER(pointer(pointer_array)._type_))
obj = p_obj.contents
self.assertEqual(obj[0][0].a, 10)
self.assertEqual(obj[1][0].a, 20)
self.assertEqual(obj[2][0].a, 30)
self.assertEqual(obj[0].contents.a, 10)
self.assertEqual(obj[1].contents.a, 20)
self.assertEqual(obj[2].contents.a, 30)
class StructWithPointers(Structure):
_fields_ = [("s1", POINTER(Struct)), ("s2", POINTER(Struct))]
struct = StructWithPointers(s1=pointer(s1), s2=pointer(s2))
p_obj = pointer(struct)
obj = p_obj.contents
self.assertEqual(obj.s1[0].a, 10)
self.assertEqual(obj.s2[0].a, 20)
self.assertEqual(obj.s1.contents.a, 10)
self.assertEqual(obj.s2.contents.a, 20)
p_obj = cast(pointer(struct), POINTER(pointer(pointer_array)._type_))
obj = p_obj.contents
self.assertEqual(obj[0][0].a, 10)
self.assertEqual(obj[1][0].a, 20)
self.assertEqual(obj[0].contents.a, 10)
self.assertEqual(obj[1].contents.a, 20)


if __name__ == "__main__":
unittest.main()
56 changes: 28 additions & 28 deletions Modules/_ctypes/_ctypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -5139,6 +5139,8 @@ static PyObject *
Pointer_get_contents(CDataObject *self, void *closure)
{
StgDictObject *stgdict;
PyObject *ptr2ptr;
CDataObject *p2p;

if (*(void **)self->b_ptr == NULL) {
PyErr_SetString(PyExc_ValueError,
Expand All @@ -5148,38 +5150,36 @@ Pointer_get_contents(CDataObject *self, void *closure)

stgdict = PyObject_stgdict((PyObject *)self);
assert(stgdict); /* Cannot be NULL for pointer instances */
assert(stgdict->proto);

PyObject *keep = GetKeepedObjects(self);
if (keep != NULL) {
// check if it's a pointer to a pointer:
// pointers will have '0' key in the _objects
int ptr_probe = PyDict_ContainsString(keep, "0");
if (ptr_probe < 0) {
if (self->b_objects != NULL && PyDict_CheckExact(self->b_objects)) {
// Pointer_set_contents uses KeepRef(self, 1, value); we retrieve that
ptr2ptr = PyDict_GetItemString(self->b_objects, "1");
if (ptr2ptr == NULL) {
PyErr_SetString(PyExc_ValueError,
"Unexpected NULL pointer in _objects");
return NULL;
}
if (ptr_probe) {
PyObject *item;
if (PyDict_GetItemStringRef(keep, "1", &item) < 0) {
return NULL;
}
if (item == NULL) {
PyErr_SetString(PyExc_ValueError,
"Unexpected NULL pointer in _objects");
return NULL;
}
#ifndef NDEBUG
CDataObject *ptr2ptr = (CDataObject *)item;
// Don't construct a new object,
// return existing one instead to preserve refcount.
// Double-check that we are returning the same thing.
// if our base pointer is cast from another type,
// its `_type_` proto will be incompatible with the
// type of the object stored in `b_objects["1"]` because
// `_objects` is shared between casts and the original.
int res = PyObject_IsInstance(ptr2ptr, stgdict->proto);
if (res == -1) {
return NULL;
}
if (res) {
// It's not a cast: don't construct a new object,
// return existing one instead to preserve refcount
p2p = (CDataObject*) ptr2ptr;
assert(
*(void**) self->b_ptr == ptr2ptr->b_ptr ||
*(void**) self->b_value.c == ptr2ptr->b_ptr ||
*(void**) self->b_ptr == ptr2ptr->b_value.c ||
*(void**) self->b_value.c == ptr2ptr->b_value.c
);
#endif
return item;
*(void**) self->b_ptr == p2p->b_ptr ||
*(void**) self->b_value.c == p2p->b_ptr ||
*(void**) self->b_ptr == p2p->b_value.c ||
*(void**) self->b_value.c == p2p->b_value.c
); // double-check that we are returning the same thing
Py_INCREF(ptr2ptr);
return ptr2ptr;
}
}

Expand Down

0 comments on commit 16d39d2

Please sign in to comment.