Skip to content

Commit

Permalink
Merge pull request #3632 from messense/extract-frozen-set
Browse files Browse the repository at this point in the history
Add support for extracting Rust set types from `frozenset`
  • Loading branch information
davidhewitt committed Dec 7, 2023
2 parents 601d957 + c21a84d commit 24d9113
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
1 change: 1 addition & 0 deletions newsfragments/3632.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for extracting Rust set types from `frozenset`.
18 changes: 15 additions & 3 deletions src/conversions/hashbrown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! The required hashbrown version may vary based on the version of PyO3.
use crate::{
types::set::new_from_iter,
types::{IntoPyDict, PyDict, PySet},
types::{IntoPyDict, PyDict, PyFrozenSet, PySet},
FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
};
use std::{cmp, hash};
Expand Down Expand Up @@ -93,8 +93,16 @@ where
S: hash::BuildHasher + Default,
{
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let set: &PySet = ob.downcast()?;
set.iter().map(K::extract).collect()
match ob.downcast::<PySet>() {
Ok(set) => set.iter().map(K::extract).collect(),
Err(err) => {
if let Ok(frozen_set) = ob.downcast::<PyFrozenSet>() {
frozen_set.iter().map(K::extract).collect()
} else {
Err(PyErr::from(err))
}
}
}
}
}

Expand Down Expand Up @@ -173,6 +181,10 @@ mod tests {
let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: hashbrown::HashSet<usize> = set.extract().unwrap();
assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect());

let set = PyFrozenSet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: hashbrown::HashSet<usize> = set.extract().unwrap();
assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect());
});
}

Expand Down
39 changes: 32 additions & 7 deletions src/conversions/std/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::{cmp, collections, hash};
#[cfg(feature = "experimental-inspect")]
use crate::inspect::types::TypeInfo;
use crate::{
types::set::new_from_iter, types::PySet, FromPyObject, IntoPy, PyAny, PyObject, PyResult,
Python, ToPyObject,
types::set::new_from_iter,
types::{PyFrozenSet, PySet},
FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
};

impl<T, S> ToPyObject for collections::HashSet<T, S>
Expand Down Expand Up @@ -53,8 +54,16 @@ where
S: hash::BuildHasher + Default,
{
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let set: &PySet = ob.downcast()?;
set.iter().map(K::extract).collect()
match ob.downcast::<PySet>() {
Ok(set) => set.iter().map(K::extract).collect(),
Err(err) => {
if let Ok(frozen_set) = ob.downcast::<PyFrozenSet>() {
frozen_set.iter().map(K::extract).collect()
} else {
Err(PyErr::from(err))
}
}
}
}

#[cfg(feature = "experimental-inspect")]
Expand Down Expand Up @@ -84,8 +93,16 @@ where
K: FromPyObject<'source> + cmp::Ord,
{
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let set: &PySet = ob.downcast()?;
set.iter().map(K::extract).collect()
match ob.downcast::<PySet>() {
Ok(set) => set.iter().map(K::extract).collect(),
Err(err) => {
if let Ok(frozen_set) = ob.downcast::<PyFrozenSet>() {
frozen_set.iter().map(K::extract).collect()
} else {
Err(PyErr::from(err))
}
}
}
}

#[cfg(feature = "experimental-inspect")]
Expand All @@ -96,7 +113,7 @@ where

#[cfg(test)]
mod tests {
use super::PySet;
use super::{PyFrozenSet, PySet};
use crate::{IntoPy, PyObject, Python, ToPyObject};
use std::collections::{BTreeSet, HashSet};

Expand All @@ -106,6 +123,10 @@ mod tests {
let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: HashSet<usize> = set.extract().unwrap();
assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect());

let set = PyFrozenSet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: HashSet<usize> = set.extract().unwrap();
assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect());
});
}

Expand All @@ -115,6 +136,10 @@ mod tests {
let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: BTreeSet<usize> = set.extract().unwrap();
assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect());

let set = PyFrozenSet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: BTreeSet<usize> = set.extract().unwrap();
assert_eq!(hash_set, [1, 2, 3, 4, 5].iter().copied().collect());
});
}

Expand Down

0 comments on commit 24d9113

Please sign in to comment.