Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ordering magic methods for #[pyclass] #3203

Merged
merged 1 commit into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ harness = false
name = "bench_call"
harness = false

[[bench]]
name = "bench_comparisons"
harness = false

[[bench]]
name = "bench_err"
harness = false
Expand Down
70 changes: 70 additions & 0 deletions benches/bench_comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use criterion::{criterion_group, criterion_main, Bencher, Criterion};

use pyo3::{prelude::*, pyclass::CompareOp, Python};

#[pyclass]
struct OrderedDunderMethods(i64);

#[pymethods]
impl OrderedDunderMethods {
fn __lt__(&self, other: &Self) -> bool {
self.0 < other.0
}

fn __le__(&self, other: &Self) -> bool {
self.0 <= other.0
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __ne__(&self, other: &Self) -> bool {
self.0 != other.0
}

fn __gt__(&self, other: &Self) -> bool {
self.0 > other.0
}

fn __ge__(&self, other: &Self) -> bool {
self.0 >= other.0
}
}

#[pyclass]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct OrderedRichcmp(i64);

#[pymethods]
impl OrderedRichcmp {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.cmp(other))
}
}

fn bench_ordered_dunder_methods(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
let obj1 = Py::new(py, OrderedDunderMethods(0)).unwrap().into_ref(py);
let obj2 = Py::new(py, OrderedDunderMethods(1)).unwrap().into_ref(py);

b.iter(|| obj2.gt(obj1).unwrap());
adamreichold marked this conversation as resolved.
Show resolved Hide resolved
});
}

fn bench_ordered_richcmp(b: &mut Bencher<'_>) {
Python::with_gil(|py| {
let obj1 = Py::new(py, OrderedRichcmp(0)).unwrap().into_ref(py);
let obj2 = Py::new(py, OrderedRichcmp(1)).unwrap().into_ref(py);

b.iter(|| obj2.gt(obj1).unwrap());
});
}

fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("ordered_dunder_methods", bench_ordered_dunder_methods);
c.bench_function("ordered_richcmp", bench_ordered_richcmp);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
19 changes: 18 additions & 1 deletion guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,28 @@ given signatures should be interpreted as follows:
```
</details>

- `__lt__(<self>, object) -> object`
- `__le__(<self>, object) -> object`
- `__eq__(<self>, object) -> object`
- `__ne__(<self>, object) -> object`
- `__gt__(<self>, object) -> object`
- `__ge__(<self>, object) -> object`

The implementations of Python's "rich comparison" operators `<`, `<=`, `==`, `!=`, `>` and `>=` respectively.

_Note that implementing any of these methods will cause Python not to generate a default `__hash__` implementation, so consider also implementing `__hash__`._
<details>
<summary>Return type</summary>
The return type will normally be `bool` or `PyResult<bool>`, however any Python object can be returned.
</details>

- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`

Overloads Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`).
Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
The `CompareOp` argument indicates the comparison operation being performed.

_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._

_Note that implementing `__richcmp__` will cause Python not to generate a default `__hash__` implementation, so consider implementing `__hash__` when implementing `__richcmp__`._
<details>
<summary>Return type</summary>
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3203.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__` and `__ge__` in `#[pymethods]`
57 changes: 33 additions & 24 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,41 +235,50 @@ fn add_shared_proto_slots(
mut implemented_proto_fragments: HashSet<String>,
) {
macro_rules! try_add_shared_slot {
($first:literal, $second:literal, $slot:ident) => {{
let first_implemented = implemented_proto_fragments.remove($first);
let second_implemented = implemented_proto_fragments.remove($second);
if first_implemented || second_implemented {
($slot:ident, $($fragments:literal),*) => {{
let mut implemented = false;
$(implemented |= implemented_proto_fragments.remove($fragments));*;
if implemented {
proto_impls.push(quote! { _pyo3::impl_::pyclass::$slot!(#ty) })
}
}};
}

try_add_shared_slot!(
generate_pyclass_getattro_slot,
"__getattribute__",
"__getattr__",
generate_pyclass_getattro_slot
"__getattr__"
);
try_add_shared_slot!("__setattr__", "__delattr__", generate_pyclass_setattr_slot);
try_add_shared_slot!("__set__", "__delete__", generate_pyclass_setdescr_slot);
try_add_shared_slot!("__setitem__", "__delitem__", generate_pyclass_setitem_slot);
try_add_shared_slot!("__add__", "__radd__", generate_pyclass_add_slot);
try_add_shared_slot!("__sub__", "__rsub__", generate_pyclass_sub_slot);
try_add_shared_slot!("__mul__", "__rmul__", generate_pyclass_mul_slot);
try_add_shared_slot!("__mod__", "__rmod__", generate_pyclass_mod_slot);
try_add_shared_slot!("__divmod__", "__rdivmod__", generate_pyclass_divmod_slot);
try_add_shared_slot!("__lshift__", "__rlshift__", generate_pyclass_lshift_slot);
try_add_shared_slot!("__rshift__", "__rrshift__", generate_pyclass_rshift_slot);
try_add_shared_slot!("__and__", "__rand__", generate_pyclass_and_slot);
try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot);
try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot);
try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot);
try_add_shared_slot!("__truediv__", "__rtruediv__", generate_pyclass_truediv_slot);
try_add_shared_slot!(generate_pyclass_setattr_slot, "__setattr__", "__delattr__");
try_add_shared_slot!(generate_pyclass_setdescr_slot, "__set__", "__delete__");
try_add_shared_slot!(generate_pyclass_setitem_slot, "__setitem__", "__delitem__");
try_add_shared_slot!(generate_pyclass_add_slot, "__add__", "__radd__");
try_add_shared_slot!(generate_pyclass_sub_slot, "__sub__", "__rsub__");
try_add_shared_slot!(generate_pyclass_mul_slot, "__mul__", "__rmul__");
try_add_shared_slot!(generate_pyclass_mod_slot, "__mod__", "__rmod__");
try_add_shared_slot!(generate_pyclass_divmod_slot, "__divmod__", "__rdivmod__");
try_add_shared_slot!(generate_pyclass_lshift_slot, "__lshift__", "__rlshift__");
try_add_shared_slot!(generate_pyclass_rshift_slot, "__rshift__", "__rrshift__");
try_add_shared_slot!(generate_pyclass_and_slot, "__and__", "__rand__");
try_add_shared_slot!(generate_pyclass_or_slot, "__or__", "__ror__");
try_add_shared_slot!(generate_pyclass_xor_slot, "__xor__", "__rxor__");
try_add_shared_slot!(generate_pyclass_matmul_slot, "__matmul__", "__rmatmul__");
try_add_shared_slot!(generate_pyclass_truediv_slot, "__truediv__", "__rtruediv__");
try_add_shared_slot!(
generate_pyclass_floordiv_slot,
"__floordiv__",
"__rfloordiv__",
generate_pyclass_floordiv_slot
"__rfloordiv__"
);
try_add_shared_slot!(generate_pyclass_pow_slot, "__pow__", "__rpow__");
try_add_shared_slot!(
generate_pyclass_richcompare_slot,
"__lt__",
"__le__",
"__eq__",
"__ne__",
"__gt__",
"__ge__"
);
try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot);

// if this assertion trips, a slot fragment has been implemented which has not been added in the
// list above
Expand Down
25 changes: 25 additions & 0 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ impl PyMethodKind {
"__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)),
"__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)),
"__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)),
"__lt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LT__)),
"__le__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LE__)),
"__eq__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__EQ__)),
"__ne__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__NE__)),
"__gt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GT__)),
"__ge__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GE__)),
// Some tricky protocols which don't fit the pattern of the rest
"__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call),
"__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse),
Expand Down Expand Up @@ -1300,6 +1306,25 @@ const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object,
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);

const __LT__: SlotFragmentDef = SlotFragmentDef::new("__lt__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __LE__: SlotFragmentDef = SlotFragmentDef::new("__le__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __EQ__: SlotFragmentDef = SlotFragmentDef::new("__eq__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __NE__: SlotFragmentDef = SlotFragmentDef::new("__ne__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __GT__: SlotFragmentDef = SlotFragmentDef::new("__gt__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);
const __GE__: SlotFragmentDef = SlotFragmentDef::new("__ge__", &[Ty::Object])
.extract_error_mode(ExtractErrorMode::NotImplemented)
.ret_ty(Ty::Object);

fn extract_proto_arguments(
py: &syn::Ident,
spec: &FnSpec<'_>,
Expand Down
1 change: 1 addition & 0 deletions pytests/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ hypothesis>=3.55
pytest>=6.0
pytest-benchmark>=3.4
psutil>=5.6
typing_extensions>=4.0.0
111 changes: 111 additions & 0 deletions pytests/src/comparisons.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use pyo3::prelude::*;
use pyo3::{types::PyModule, Python};

#[pyclass]
struct Eq(i64);

#[pymethods]
impl Eq {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __ne__(&self, other: &Self) -> bool {
self.0 != other.0
}
}

#[pyclass]
struct EqDefaultNe(i64);

#[pymethods]
impl EqDefaultNe {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}
}

#[pyclass]
struct Ordered(i64);

#[pymethods]
impl Ordered {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __lt__(&self, other: &Self) -> bool {
self.0 < other.0
}

fn __le__(&self, other: &Self) -> bool {
self.0 <= other.0
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __ne__(&self, other: &Self) -> bool {
self.0 != other.0
}

fn __gt__(&self, other: &Self) -> bool {
self.0 > other.0
}

fn __ge__(&self, other: &Self) -> bool {
self.0 >= other.0
}
}

#[pyclass]
struct OrderedDefaultNe(i64);

#[pymethods]
impl OrderedDefaultNe {
#[new]
fn new(value: i64) -> Self {
Self(value)
}

fn __lt__(&self, other: &Self) -> bool {
self.0 < other.0
}

fn __le__(&self, other: &Self) -> bool {
self.0 <= other.0
}

fn __eq__(&self, other: &Self) -> bool {
self.0 == other.0
}

fn __gt__(&self, other: &Self) -> bool {
self.0 > other.0
}

fn __ge__(&self, other: &Self) -> bool {
self.0 >= other.0
}
}

#[pymodule]
pub fn comparisons(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<Eq>()?;
m.add_class::<EqDefaultNe>()?;
m.add_class::<Ordered>()?;
m.add_class::<OrderedDefaultNe>()?;
Ok(())
}
3 changes: 3 additions & 0 deletions pytests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pyo3::types::PyDict;
use pyo3::wrap_pymodule;

pub mod buf_and_str;
pub mod comparisons;
pub mod datetime;
pub mod deprecated_pyfunctions;
pub mod dict_iter;
Expand All @@ -19,6 +20,7 @@ pub mod subclassing;
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?;
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
m.add_wrapped(wrap_pymodule!(
Expand All @@ -40,6 +42,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let sys = PyModule::import(py, "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
sys_modules.set_item(
"pyo3_pytests.deprecated_pyfunctions",
Expand Down
Loading