Skip to content

Commit

Permalink
support ordering magic methods for #[pyclass]
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jun 3, 2023
1 parent fa949ff commit a0c47aa
Show file tree
Hide file tree
Showing 15 changed files with 610 additions and 62 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ harness = false
name = "bench_call"
harness = false

[[bench]]
name = "bench_comparisons"
harness = false

[[bench]]
name = "bench_err"
harness = false
Expand Down
70 changes: 70 additions & 0 deletions benches/bench_comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use criterion::{criterion_group, criterion_main, Bencher, Criterion};

use pyo3::{prelude::*, pyclass::CompareOp, Python};

#[pyclass]
struct OrderedDunderMethods(i64);

#[pymethods]
impl OrderedDunderMethods {
fn __lt__(&self, other: &Self) -> bool {
self.0 < other.0
}

fn __le__(&self, other: &Self) -> bool {
self.0 <= other.0
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __ne__(&self, other: &Self) -> bool {
self.0 != other.0
}

fn __gt__(&self, other: &Self) -> bool {
self.0 > other.0
}

fn __ge__(&self, other: &Self) -> bool {
self.0 >= other.0
}
}

#[pyclass]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct OrderedRichcmp(i64);

#[pymethods]
impl OrderedRichcmp {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.cmp(other))
}
}

fn bench_ordered_dunder_methods(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
let obj1 = Py::new(py, OrderedDunderMethods(0)).unwrap().into_ref(py);
let obj2 = Py::new(py, OrderedDunderMethods(1)).unwrap().into_ref(py);

b.iter(|| obj2.gt(obj1).unwrap());
});
}

fn bench_ordered_richcmp(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
let obj1 = Py::new(py, OrderedRichcmp(0)).unwrap().into_ref(py);
let obj2 = Py::new(py, OrderedRichcmp(1)).unwrap().into_ref(py);

b.iter(|| obj2.gt(obj1).unwrap());
});
}

fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("ordered_dunder_methods", bench_ordered_dunder_methods);
c.bench_function("ordered_richcmp", bench_ordered_richcmp);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
19 changes: 18 additions & 1 deletion guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,28 @@ given signatures should be interpreted as follows:
```
</details>

- `__lt__(<self>, object) -> object`
- `__le__(<self>, object) -> object`
- `__eq__(<self>, object) -> object`
- `__ne__(<self>, object) -> object`
- `__gt__(<self>, object) -> object`
- `__ge__(<self>, object) -> object`

The implementations of Python's rich-comparison operators `<`, `<=`, `==`, `!=`, `>` and `>=` respectively.

_Note that implementing any of these methods will cause Python not to generate a default `__hash__` implementation, so consider also implementing `__hash__`._
<details>
<summary>Return type</summary>
The return type will normally be `bool` or `PyResult<bool>`, however any Python object can be returned.
</details>

- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`

Overloads Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`).
Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
The `CompareOp` argument indicates the comparison operation being performed.

_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._

_Note that implementing `__richcmp__` will cause Python not to generate a default `__hash__` implementation, so consider implementing `__hash__` when implementing `__richcmp__`._
<details>
<summary>Return type</summary>
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3203.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__` and `__ge__` in `#[pymethods]`
57 changes: 33 additions & 24 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,41 +235,50 @@ fn add_shared_proto_slots(
mut implemented_proto_fragments: HashSet<String>,
) {
macro_rules! try_add_shared_slot {
($first:literal, $second:literal, $slot:ident) => {{
let first_implemented = implemented_proto_fragments.remove($first);
let second_implemented = implemented_proto_fragments.remove($second);
if first_implemented || second_implemented {
($slot:ident, $($fragments:literal),*) => {{
let mut implemented = false;
$(implemented |= implemented_proto_fragments.remove($fragments));*;
if implemented {
proto_impls.push(quote! { _pyo3::impl_::pyclass::$slot!(#ty) })
}
}};
}

try_add_shared_slot!(
generate_pyclass_getattro_slot,
"__getattribute__",
"__getattr__",
generate_pyclass_getattro_slot
"__getattr__"
);
try_add_shared_slot!("__setattr__", "__delattr__", generate_pyclass_setattr_slot);
try_add_shared_slot!("__set__", "__delete__", generate_pyclass_setdescr_slot);
try_add_shared_slot!("__setitem__", "__delitem__", generate_pyclass_setitem_slot);
try_add_shared_slot!("__add__", "__radd__", generate_pyclass_add_slot);
try_add_shared_slot!("__sub__", "__rsub__", generate_pyclass_sub_slot);
try_add_shared_slot!("__mul__", "__rmul__", generate_pyclass_mul_slot);
try_add_shared_slot!("__mod__", "__rmod__", generate_pyclass_mod_slot);
try_add_shared_slot!("__divmod__", "__rdivmod__", generate_pyclass_divmod_slot);
try_add_shared_slot!("__lshift__", "__rlshift__", generate_pyclass_lshift_slot);
try_add_shared_slot!("__rshift__", "__rrshift__", generate_pyclass_rshift_slot);
try_add_shared_slot!("__and__", "__rand__", generate_pyclass_and_slot);
try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot);
try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot);
try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot);
try_add_shared_slot!("__truediv__", "__rtruediv__", generate_pyclass_truediv_slot);
try_add_shared_slot!(generate_pyclass_setattr_slot, "__setattr__", "__delattr__");
try_add_shared_slot!(generate_pyclass_setdescr_slot, "__set__", "__delete__");
try_add_shared_slot!(generate_pyclass_setitem_slot, "__setitem__", "__delitem__");
try_add_shared_slot!(generate_pyclass_add_slot, "__add__", "__radd__");
try_add_shared_slot!(generate_pyclass_sub_slot, "__sub__", "__rsub__");
try_add_shared_slot!(generate_pyclass_mul_slot, "__mul__", "__rmul__");
try_add_shared_slot!(generate_pyclass_mod_slot, "__mod__", "__rmod__");
try_add_shared_slot!(generate_pyclass_divmod_slot, "__divmod__", "__rdivmod__");
try_add_shared_slot!(generate_pyclass_lshift_slot, "__lshift__", "__rlshift__");
try_add_shared_slot!(generate_pyclass_rshift_slot, "__rshift__", "__rrshift__");
try_add_shared_slot!(generate_pyclass_and_slot, "__and__", "__rand__");
try_add_shared_slot!(generate_pyclass_or_slot, "__or__", "__ror__");
try_add_shared_slot!(generate_pyclass_xor_slot, "__xor__", "__rxor__");
try_add_shared_slot!(generate_pyclass_matmul_slot, "__matmul__", "__rmatmul__");
try_add_shared_slot!(generate_pyclass_truediv_slot, "__truediv__", "__rtruediv__");
try_add_shared_slot!(
generate_pyclass_floordiv_slot,
"__floordiv__",
"__rfloordiv__",
generate_pyclass_floordiv_slot
"__rfloordiv__"
);
try_add_shared_slot!(generate_pyclass_pow_slot, "__pow__", "__rpow__");
try_add_shared_slot!(
generate_pyclass_richcompare_slot,
"__lt__",
"__le__",
"__eq__",
"__ne__",
"__gt__",
"__ge__"
);
try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot);

// if this assertion trips, a slot fragment has been implemented which has not been added in the
// list above
Expand Down
25 changes: 25 additions & 0 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ impl PyMethodKind {
"__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)),
"__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)),
"__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)),
"__lt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LT__)),
"__le__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LE__)),
"__eq__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__EQ__)),
"__ne__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__NE__)),
"__gt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GT__)),
"__ge__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GE__)),
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
Expand Down Expand Up @@ -1300,6 +1306,25 @@ const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object,
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);

const __LT__: SlotFragmentDef = SlotFragmentDef::new("__lt__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __LE__: SlotFragmentDef = SlotFragmentDef::new("__le__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __EQ__: SlotFragmentDef = SlotFragmentDef::new("__eq__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __NE__: SlotFragmentDef = SlotFragmentDef::new("__ne__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __GT__: SlotFragmentDef = SlotFragmentDef::new("__gt__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __GE__: SlotFragmentDef = SlotFragmentDef::new("__ge__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);

fn extract_proto_arguments(
py: &syn::Ident,
spec: &FnSpec<'_>,
Expand Down
1 change: 1 addition & 0 deletions pytests/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ hypothesis>=3.55
pytest>=6.0
pytest-benchmark>=3.4
psutil>=5.6
typing_extensions>=4.0.0
111 changes: 111 additions & 0 deletions pytests/src/comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use pyo3::prelude::*;
use pyo3::{types::PyModule, Python};

#[pyclass]
struct Eq(i64);

#[pymethods]
impl Eq {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __ne__(&self, other: &Self) -> bool {
self.0 != other.0
}
}

#[pyclass]
struct EqDefaultNe(i64);

#[pymethods]
impl EqDefaultNe {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}
}

#[pyclass]
struct Ordered(i64);

#[pymethods]
impl Ordered {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __lt__(&self, other: &Self) -> bool {
self.0 < other.0
}

fn __le__(&self, other: &Self) -> bool {
self.0 <= other.0
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __ne__(&self, other: &Self) -> bool {
self.0 != other.0
}

fn __gt__(&self, other: &Self) -> bool {
self.0 > other.0
}

fn __ge__(&self, other: &Self) -> bool {
self.0 >= other.0
}
}

#[pyclass]
struct OrderedDefaultNe(i64);

#[pymethods]
impl OrderedDefaultNe {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __lt__(&self, other: &Self) -> bool {
self.0 < other.0
}

fn __le__(&self, other: &Self) -> bool {
self.0 <= other.0
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __gt__(&self, other: &Self) -> bool {
self.0 > other.0
}

fn __ge__(&self, other: &Self) -> bool {
self.0 >= other.0
}
}

#[pymodule]
pub fn comparisons(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<Eq>()?;
m.add_class::<EqDefaultNe>()?;
m.add_class::<Ordered>()?;
m.add_class::<OrderedDefaultNe>()?;
Ok(())
}
3 changes: 3 additions & 0 deletions pytests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pyo3::types::PyDict;
use pyo3::wrap_pymodule;

pub mod buf_and_str;
pub mod comparisons;
pub mod datetime;
pub mod deprecated_pyfunctions;
pub mod dict_iter;
Expand All @@ -19,6 +20,7 @@ pub mod subclassing;
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?;
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
m.add_wrapped(wrap_pymodule!(
Expand All @@ -40,6 +42,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let sys = PyModule::import(py, "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
sys_modules.set_item(
"pyo3_pytests.deprecated_pyfunctions",
Expand Down
Loading

0 comments on commit a0c47aa

Please sign in to comment.