Skip to content

Commit

Permalink
Also use OnceCell for LazyStaticType implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jun 15, 2020
1 parent 0b8b19c commit a29702e
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 55 deletions.
9 changes: 7 additions & 2 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,11 @@ impl MyClass {
}
```

Note that defining a class attribute of the same type as the class will make the class unusable
attempting to use the class will cause a panic reading `Recursive evaluation of type_object`.
If having the attribute on instances is acceptable, create a `#[getter]`, which can also use a
`OnceCell` for caching.

## Callable objects

To specify a custom `__call__` method for a custom class, the method needs to be annotated with
Expand Down Expand Up @@ -921,10 +926,10 @@ unsafe impl pyo3::PyTypeInfo for MyClass {
const FLAGS: usize = 0;

#[inline]
fn type_object() -> &'static pyo3::ffi::PyTypeObject {
fn type_object_raw(py: pyo3::Python) -> &'static pyo3::ffi::PyTypeObject {
use pyo3::type_object::LazyStaticType;
static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
TYPE_OBJECT.get_or_init::<Self>()
TYPE_OBJECT.get_or_init::<Self>(py)
}
}

Expand Down
4 changes: 2 additions & 2 deletions pyo3-derive-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,10 @@ fn impl_class(
const FLAGS: usize = #(#flags)|* | #extended;

#[inline]
fn type_object() -> &'static pyo3::ffi::PyTypeObject {
fn type_object_raw(py: pyo3::Python) -> &'static pyo3::ffi::PyTypeObject {
use pyo3::type_object::LazyStaticType;
static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
TYPE_OBJECT.get_or_init::<Self>()
TYPE_OBJECT.get_or_init::<Self>(py)
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/freelist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ impl<T> PyClassAlloc for T
where
T: PyTypeInfo + PyClassWithFreeList,
{
unsafe fn alloc(_py: Python) -> *mut Self::Layout {
unsafe fn alloc(py: Python) -> *mut Self::Layout {
if let Some(obj) = <Self as PyClassWithFreeList>::get_free_list().pop() {
ffi::PyObject_Init(obj, <Self as PyTypeInfo>::type_object() as *const _ as _);
ffi::PyObject_Init(obj, Self::type_object_raw(py) as *const _ as _);
obj as _
} else {
crate::pyclass::default_alloc::<Self>() as _
crate::pyclass::default_alloc::<Self>(py) as _
}
}

Expand All @@ -90,7 +90,7 @@ where
}

if let Some(obj) = <Self as PyClassWithFreeList>::get_free_list().insert(obj) {
match Self::type_object().tp_free {
match Self::type_object_raw(py).tp_free {
Some(free) => free(obj as *mut c_void),
None => tp_free_fallback(obj),
}
Expand Down
14 changes: 7 additions & 7 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use std::os::raw::c_void;
use std::ptr;

#[inline]
pub(crate) unsafe fn default_alloc<T: PyTypeInfo>() -> *mut ffi::PyObject {
let type_obj = T::type_object();
pub(crate) unsafe fn default_alloc<T: PyTypeInfo>(py: Python) -> *mut ffi::PyObject {
let type_obj = T::type_object_raw(py);
// if the class derives native types(e.g., PyDict), call special new
if T::FLAGS & type_flags::EXTENDED != 0 && T::BaseLayout::IS_NATIVE_TYPE {
let base_tp = <T::BaseType as PyTypeInfo>::type_object();
let base_tp = T::BaseType::type_object_raw(py);
if let Some(base_new) = base_tp.tp_new {
return base_new(type_obj as *const _ as _, ptr::null_mut(), ptr::null_mut());
}
Expand All @@ -29,8 +29,8 @@ pub trait PyClassAlloc: PyTypeInfo + Sized {
///
/// # Safety
/// This function must return a valid pointer to the Python heap.
unsafe fn alloc(_py: Python) -> *mut Self::Layout {
default_alloc::<Self>() as _
unsafe fn alloc(py: Python) -> *mut Self::Layout {
default_alloc::<Self>(py) as _
}

/// Deallocate `#[pyclass]` on the Python heap.
Expand All @@ -44,7 +44,7 @@ pub trait PyClassAlloc: PyTypeInfo + Sized {
return;
}

match Self::type_object().tp_free {
match Self::type_object_raw(py).tp_free {
Some(free) => free(obj as *mut c_void),
None => tp_free_fallback(obj),
}
Expand Down Expand Up @@ -101,7 +101,7 @@ where
s => CString::new(s)?.into_raw(),
};

type_object.tp_base = <T::BaseType as PyTypeInfo>::type_object() as *const _ as _;
type_object.tp_base = <T::BaseType as PyTypeInfo>::type_object_raw(py) as *const _ as _;

type_object.tp_name = match module_name {
Some(module_name) => CString::new(format!("{}.{}", module_name, T::NAME))?.into_raw(),
Expand Down
65 changes: 44 additions & 21 deletions src/type_object.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// Copyright (c) 2017-present PyO3 Project and Contributors
//! Python type object information
use crate::once_cell::OnceCell;
use crate::pyclass::{initialize_type_object, PyClass};
use crate::pyclass_init::PyObjectInit;
use crate::types::{PyAny, PyType};
use crate::{ffi, AsPyPointer, Python};
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::{ffi, AsPyPointer, Python, PyNativeType};
use parking_lot::{Mutex, const_mutex};
use std::collections::HashSet;
use std::thread::{self, ThreadId};

/// `T: PyLayout<U>` represents that `T` is a concrete representaion of `U` in Python heap.
/// E.g., `PyCell` is a concrete representaion of all `pyclass`es, and `ffi::PyObject`
Expand Down Expand Up @@ -100,18 +102,18 @@ pub unsafe trait PyTypeInfo: Sized {
type AsRefTarget: crate::PyNativeType;

/// PyTypeObject instance for this type.
fn type_object() -> &'static ffi::PyTypeObject;
fn type_object_raw(py: Python) -> &'static ffi::PyTypeObject;

/// Check if `*mut ffi::PyObject` is instance of this type
fn is_instance(object: &PyAny) -> bool {
unsafe {
ffi::PyObject_TypeCheck(object.as_ptr(), Self::type_object() as *const _ as _) != 0
ffi::PyObject_TypeCheck(object.as_ptr(), Self::type_object(object.py()) as *const _ as _) != 0
}
}

/// Check if `*mut ffi::PyObject` is exact instance of this type
fn is_exact_instance(object: &PyAny) -> bool {
unsafe { (*object.as_ptr()).ob_type == Self::type_object() as *const _ as _ }
unsafe { (*object.as_ptr()).ob_type == Self::type_object(object.py()) as *const _ as _ }
}
}

Expand All @@ -131,39 +133,60 @@ where
T: PyTypeInfo,
{
fn type_object(py: Python) -> &PyType {
unsafe { py.from_borrowed_ptr(<Self as PyTypeInfo>::type_object() as *const _ as _) }
unsafe { py.from_borrowed_ptr(<Self as PyTypeInfo>::type_object_raw(py) as *const _ as _) }
}
}

/// Lazy type object for PyClass
#[doc(hidden)]
pub struct LazyStaticType {
value: UnsafeCell<ffi::PyTypeObject>,
initialized: AtomicBool,
// Boxed because Python expects the type object to have a stable address.
value: OnceCell<Box<ffi::PyTypeObject>>,
// Threads which have begun initialization. Used for reentrant initialization detection.
initializing_threads: Mutex<Option<HashSet<ThreadId>>>
}

impl LazyStaticType {
pub const fn new() -> Self {
LazyStaticType {
value: UnsafeCell::new(ffi::PyTypeObject_INIT),
initialized: AtomicBool::new(false),
value: OnceCell::new(),
initializing_threads: const_mutex(None)
}
}

pub fn get_or_init<T: PyClass>(&self) -> &ffi::PyTypeObject {
if !self
.initialized
.compare_and_swap(false, true, Ordering::Acquire)
{
let gil = Python::acquire_gil();
let py = gil.python();
initialize_type_object::<T>(py, T::MODULE, unsafe { &mut *self.value.get() })
pub fn get_or_init<T: PyClass>(&self, py: Python) -> &ffi::PyTypeObject {
self.value.get_or_init(py, || {
{
// Code evaluated at class init time, such as class attributes, might lead to
// recursive initalization of the type object if the class attribute is the same
// type as the class.
//
// That could lead to all sorts of unsafety such as using incomplete type objects
// to initialize class instances, so recursive initialization is prevented.
let thread_not_already_initializing = self.initializing_threads.lock()
.get_or_insert_with(HashSet::new)
.insert(thread::current().id());

if !thread_not_already_initializing {
panic!("Recursive initialization of type_object for {}", T::NAME);
}
}

// Okay, not recursive initialization - can proceed safely.
let mut type_object = Box::new(ffi::PyTypeObject_INIT);

initialize_type_object::<T>(py, T::MODULE, type_object.as_mut())
.unwrap_or_else(|e| {
e.print(py);
panic!("An error occurred while initializing class {}", T::NAME)
});
}
unsafe { &*self.value.get() }

// Initialization successfully complete, can clear the thread list.
// (No futher calls to get_or_init() will try to init, on any thread.)
*self.initializing_threads.lock() = None;

type_object
}).as_ref()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ macro_rules! pyobject_native_type_convert(
const MODULE: Option<&'static str> = $module;

#[inline]
fn type_object() -> &'static $crate::ffi::PyTypeObject {
fn type_object_raw(_py: Python) -> &'static $crate::ffi::PyTypeObject {
unsafe{ &$typeobject }
}

Expand Down
33 changes: 15 additions & 18 deletions tests/test_class_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,12 @@ impl Foo {
"bar".to_string()
}

#[classattr]
fn foo() -> Foo {
Foo { x: 1 }
}

#[classattr]
fn bar() -> Bar {
Bar { x: 2 }
}
}

#[pymethods]
impl Bar {
#[classattr]
fn foo() -> Foo {
Foo { x: 3 }
}
}

#[test]
fn class_attributes() {
let gil = Python::acquire_gil();
Expand All @@ -67,13 +54,23 @@ fn class_attributes_are_immutable() {
py_expect_exception!(py, foo_obj, "foo_obj.a = 6", TypeError);
}

#[pyclass]
struct SelfClassAttribute {
#[pyo3(get)]
x: i32
}

#[pymethods]
impl SelfClassAttribute {
#[classattr]
const SELF: SelfClassAttribute = SelfClassAttribute { x: 1 };
}

#[test]
#[should_panic(expected = "Recursive initialization of type_object for SelfClassAttribute")]
fn recursive_class_attributes() {
let gil = Python::acquire_gil();
let py = gil.python();
let foo_obj = py.get_type::<Foo>();
let bar_obj = py.get_type::<Bar>();
py_assert!(py, foo_obj, "foo_obj.foo.x == 1");
py_assert!(py, foo_obj, "foo_obj.bar.x == 2");
py_assert!(py, bar_obj, "bar_obj.foo.x == 3");

py.get_type::<SelfClassAttribute>();
}

0 comments on commit a29702e

Please sign in to comment.