Skip to content

Commit

Permalink
Use PyIterator type for more idiomatic shim code.
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi committed Sep 18, 2024
1 parent 9282b20 commit 73322ca
Showing 1 changed file with 21 additions and 24 deletions.
45 changes: 21 additions & 24 deletions pytrustfall/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{collections::BTreeMap, sync::Arc};

use pyo3::{exceptions::PyStopIteration, prelude::*, wrap_pyfunction};
use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyIterator, wrap_pyfunction};
use trustfall_core::{
frontend::{error::FrontendError, parse},
interpreter::{
Expand Down Expand Up @@ -120,8 +120,8 @@ impl AdapterShim {
}
}

fn make_iterator<'py>(value: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
value.call_method0("__iter__")
fn make_iterator<'py>(value: &Bound<'py, PyAny>, origin: &'static str) -> Bound<'py, PyIterator> {
value.iter().unwrap_or_else(|e| panic!("{origin} is not an iterable: {e}"))
}

#[pyclass(unsendable)]
Expand Down Expand Up @@ -196,17 +196,17 @@ impl Adapter<'static> for AdapterShim {
.map(|(k, v)| (k.to_string(), FieldValue::from(v.clone()).into_py(py)))
.collect();

// TODO: use `intern!()` macro to intern the fixed method names for efficiency
let py_iterable = self
.adapter
.call_method_bound(
py,
"resolve_starting_vertices",
pyo3::intern!(py, "resolve_starting_vertices"),
(edge_name.as_ref(), parameter_data),
None,
)
.unwrap();
let iter = make_iterator(py_iterable.bind(py)).unwrap();

let iter = make_iterator(py_iterable.bind(py), "resolve_starting_vertices()");
Box::new(PythonVertexIterator::new(iter.unbind()))
})
}
Expand All @@ -231,9 +231,7 @@ impl Adapter<'static> for AdapterShim {
.unwrap();

let iter = PythonResolvePropertyIterator::new(
make_iterator(py_iterable.bind(py))
.expect("failed to use py_iterable as an iterator")
.unbind(),
make_iterator(py_iterable.bind(py), "resolve_property()").unbind(),
);

Box::new(iter.map(|(opaque, value)| {
Expand Down Expand Up @@ -272,9 +270,7 @@ impl Adapter<'static> for AdapterShim {
.unwrap();

let iter = PythonResolveNeighborsIterator::new(
make_iterator(py_iterable.bind(py))
.expect("failed to use py_iterable as an iterator")
.unbind(),
make_iterator(py_iterable.bind(py), "resolve_neighbors()").unbind(),
);
Box::new(iter.map(|(opaque, neighbors)| {
// SAFETY: This `Opaque` was constructed just a few lines ago
Expand Down Expand Up @@ -306,9 +302,7 @@ impl Adapter<'static> for AdapterShim {
.unwrap();

let iter = PythonResolveCoercionIterator::new(
make_iterator(py_iterable.bind(py))
.expect("failed to use py_iterable as an iterator")
.unbind(),
make_iterator(py_iterable.bind(py), "resolve_coercion()").unbind(),
);
Box::new(iter.map(|(opaque, value)| {
// SAFETY: This `Opaque` was constructed just a few lines ago
Expand All @@ -322,11 +316,11 @@ impl Adapter<'static> for AdapterShim {
}

struct PythonVertexIterator {
underlying: Py<PyAny>,
underlying: Py<PyIterator>,
}

impl PythonVertexIterator {
fn new(underlying: Py<PyAny>) -> Self {
fn new(underlying: Py<PyIterator>) -> Self {
Self { underlying }
}
}
Expand All @@ -351,11 +345,11 @@ impl Iterator for PythonVertexIterator {
}

struct PythonResolvePropertyIterator {
underlying: Py<PyAny>,
underlying: Py<PyIterator>,
}

impl PythonResolvePropertyIterator {
fn new(underlying: Py<PyAny>) -> Self {
fn new(underlying: Py<PyIterator>) -> Self {
Self { underlying }
}
}
Expand Down Expand Up @@ -399,11 +393,11 @@ impl Iterator for PythonResolvePropertyIterator {
}

struct PythonResolveNeighborsIterator {
underlying: Py<PyAny>,
underlying: Py<PyIterator>,
}

impl PythonResolveNeighborsIterator {
fn new(underlying: Py<PyAny>) -> Self {
fn new(underlying: Py<PyIterator>) -> Self {
Self { underlying }
}
}
Expand All @@ -426,7 +420,10 @@ impl Iterator for PythonResolveNeighborsIterator {

// Allow returning iterables (e.g. []), not just iterators.
// Iterators return self when __iter__() is called.
let neighbors_iter = make_iterator(neighbors_iterable.bind(py)).unwrap();
let neighbors_iter = make_iterator(
neighbors_iterable.bind(py),
"resolve_neighbors() yielded tuple's second element",
);

let neighbors: VertexIterator<'static, Arc<Py<PyAny>>> =
Box::new(PythonVertexIterator::new(neighbors_iter.unbind()));
Expand All @@ -447,11 +444,11 @@ impl Iterator for PythonResolveNeighborsIterator {
}

struct PythonResolveCoercionIterator {
underlying: Py<PyAny>,
underlying: Py<PyIterator>,
}

impl PythonResolveCoercionIterator {
fn new(underlying: Py<PyAny>) -> Self {
fn new(underlying: Py<PyIterator>) -> Self {
Self { underlying }
}
}
Expand Down

0 comments on commit 73322ca

Please sign in to comment.