Skip to content

Commit

Permalink
Try #3185:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Jun 1, 2023
2 parents 451729a + 8571c81 commit a8d18a7
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 1 deletion.
1 change: 1 addition & 0 deletions newsfragments/3185.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix conversion of classes implementing `__complex__` to `Complex` when using `abi3` or PyPy.
140 changes: 140 additions & 0 deletions src/conversions/num_complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ macro_rules! complex_conversion {

#[cfg(any(Py_LIMITED_API, PyPy))]
unsafe {
let obj = if obj.is_instance_of::<PyComplex>() {
obj
} else if let Some(method) =
obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
{
method.call0()?
} else {
// `obj` might still implement `__float__` or `__index__`, which will be
// handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
// errors if those methods don't exist / raise exceptions.
obj
};
let ptr = obj.as_ptr();
let real = ffi::PyComplex_RealAsDouble(ptr);
if real == -1.0 {
Expand All @@ -172,6 +184,7 @@ complex_conversion!(f64);
#[cfg(test)]
mod tests {
use super::*;
use crate::types::PyModule;

#[test]
fn from_complex() {
Expand All @@ -197,4 +210,131 @@ mod tests {
assert!(obj.extract::<Complex<f64>>(py).is_err());
});
}
#[test]
fn from_python_magic() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class A:
def __complex__(self): return 3.0+1.2j
class B:
def __float__(self): return 3.0
class C:
def __index__(self): return 3
"#,
"test.py",
"test",
)
.unwrap();
let from_complex = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
from_complex.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
let from_float = module.getattr("B").unwrap().call0().unwrap();
assert_eq!(
from_float.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
// Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
#[cfg(Py_3_8)]
{
let from_index = module.getattr("C").unwrap().call0().unwrap();
assert_eq!(
from_index.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
}
})
}
#[test]
fn from_python_inherited_magic() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class First: pass
class ComplexMixin:
def __complex__(self): return 3.0+1.2j
class FloatMixin:
def __float__(self): return 3.0
class IndexMixin:
def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
"#,
"test.py",
"test",
)
.unwrap();
let from_complex = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
from_complex.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
let from_float = module.getattr("B").unwrap().call0().unwrap();
assert_eq!(
from_float.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
#[cfg(Py_3_8)]
{
let from_index = module.getattr("C").unwrap().call0().unwrap();
assert_eq!(
from_index.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 0.0)
);
}
})
}
#[test]
fn from_python_noncallable_descriptor_magic() {
// Functions and lambdas implement the descriptor protocol in a way that makes
// `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only
// way the descriptor protocol might be implemented.
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class A:
@property
def __complex__(self):
return lambda: 3.0+1.2j
"#,
"test.py",
"test",
)
.unwrap();
let obj = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
obj.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
})
}
#[test]
fn from_python_nondescriptor_magic() {
// Magic methods don't need to implement the descriptor protocol, if they're callable.
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class MyComplex:
def __call__(self): return 3.0+1.2j
class A:
__complex__ = MyComplex()
"#,
"test.py",
"test",
)
.unwrap();
let obj = module.getattr("A").unwrap().call0().unwrap();
assert_eq!(
obj.extract::<Complex<f64>>().unwrap(),
Complex::new(3.0, 1.2)
);
})
}
}
124 changes: 123 additions & 1 deletion src/types/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,55 @@ impl PyAny {
}
}

/// Retrieve an attribute value, skipping the instance dictionary during the lookup but still
/// binding the object to the instance.
///
/// This is useful when trying to resolve Python's "magic" methods like `__getitem__`, which
/// are looked up starting from the type object. This returns an `Option` as it is not
/// typically a direct error for the special lookup to fail, as magic methods are optional in
/// many situations in which they might be called.
///
/// To avoid repeated temporary allocations of Python strings, the [`intern!`] macro can be used
/// to intern `attr_name`.
#[allow(dead_code)] // Currently only used with num-complex+abi3, so dead without that.
pub(crate) fn lookup_special<N>(&self, attr_name: N) -> PyResult<Option<&PyAny>>
where
N: IntoPy<Py<PyString>>,
{
let py = self.py();
let self_type = self.get_type();
let attr = if let Ok(attr) = self_type.getattr(attr_name) {
attr
} else {
return Ok(None);
};

// Manually resolve descriptor protocol.
unsafe {
if cfg!(Py_3_10)
|| ffi::PyType_HasFeature(attr.get_type_ptr(), ffi::Py_TPFLAGS_HEAPTYPE) != 0
{
// This is the preferred faster path, but does not work on static types (generally,
// types defined in extension modules) before Python 3.10.
let descr_get_ptr = ffi::PyType_GetSlot(attr.get_type_ptr(), ffi::Py_tp_descr_get);
if descr_get_ptr.is_null() {
return Ok(Some(attr));
}
let descr_get: ffi::descrgetfunc = std::mem::transmute(descr_get_ptr);
let ret = descr_get(attr.as_ptr(), self.as_ptr(), self_type.as_ptr());
if ret.is_null() {
Err(PyErr::fetch(py))
} else {
Ok(Some(py.from_owned_ptr(ret)))
}
} else if let Ok(descr_get) = attr.get_type().getattr(crate::intern!(py, "__get__")) {
descr_get.call1((attr, self, self_type)).map(Some)
} else {
Ok(Some(attr))
}
}
}

/// Sets an attribute value.
///
/// This is equivalent to the Python expression `self.attr_name = value`.
Expand Down Expand Up @@ -974,9 +1023,82 @@ impl PyAny {
#[cfg(test)]
mod tests {
use crate::{
types::{IntoPyDict, PyBool, PyList, PyLong, PyModule},
types::{IntoPyDict, PyAny, PyBool, PyList, PyLong, PyModule},
Python, ToPyObject,
};

#[test]
fn test_lookup_special() {
Python::with_gil(|py| {
let module = PyModule::from_code(
py,
r#"
class CustomCallable:
def __call__(self):
return 1
class SimpleInt:
def __int__(self):
return 1
class InheritedInt(SimpleInt): pass
class NoInt: pass
class NoDescriptorInt:
__int__ = CustomCallable()
class InstanceOverrideInt:
def __int__(self):
return 1
instance_override = InstanceOverrideInt()
instance_override.__int__ = lambda self: 2
class ErrorInDescriptorInt:
@property
def __int__(self):
raise ValueError("uh-oh!")
class NonHeapNonDescriptorInt:
# A static-typed callable that doesn't implement `__get__`. These are pretty hard to come by.
__int__ = int
"#,
"test.py",
"test",
)
.unwrap();

let int = crate::intern!(py, "__int__");
let eval_int =
|obj: &PyAny| obj.lookup_special(int)?.unwrap().call0()?.extract::<u32>();

let simple = module.getattr("SimpleInt").unwrap().call0().unwrap();
assert_eq!(eval_int(simple).unwrap(), 1);
let inherited = module.getattr("InheritedInt").unwrap().call0().unwrap();
assert_eq!(eval_int(inherited).unwrap(), 1);
let no_descriptor = module.getattr("NoDescriptorInt").unwrap().call0().unwrap();
assert_eq!(eval_int(no_descriptor).unwrap(), 1);
let missing = module.getattr("NoInt").unwrap().call0().unwrap();
assert!(missing.lookup_special(int).unwrap().is_none());
// Note the instance override should _not_ call the instance method that returns 2,
// because that's not how special lookups are meant to work.
let instance_override = module.getattr("instance_override").unwrap();
assert_eq!(eval_int(instance_override).unwrap(), 1);
let descriptor_error = module
.getattr("ErrorInDescriptorInt")
.unwrap()
.call0()
.unwrap();
assert!(descriptor_error.lookup_special(int).is_err());
let nonheap_nondescriptor = module
.getattr("NonHeapNonDescriptorInt")
.unwrap()
.call0()
.unwrap();
assert_eq!(eval_int(nonheap_nondescriptor).unwrap(), 0);
})
}

#[test]
fn test_call_for_non_existing_method() {
Python::with_gil(|py| {
Expand Down

0 comments on commit a8d18a7

Please sign in to comment.