Skip to content

Commit

Permalink
Perturbation color cleanup (#38)
Browse files Browse the repository at this point in the history
* Upgrade PyO3 and maturin.

* Add missing type annotation and doc string.

* Add `strip_perturbation_data`.
  • Loading branch information
daemontus authored Nov 28, 2024
1 parent 8f6e646 commit fd16dde
Show file tree
Hide file tree
Showing 35 changed files with 231 additions and 159 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
static-z3 = ["z3/static-link-z3"]

[dependencies]
pyo3 = { version = "0.22.5", features = ["abi3-py37", "extension-module", "num-bigint", "py-clone"] }
pyo3 = { version = "0.23.2", features = ["abi3-py37", "extension-module", "num-bigint", "py-clone"] }
biodivine-lib-param-bn = { version="0.5.13", features=["solver-z3"] }
biodivine-lib-bdd = "0.5.22"
#biodivine-pbn-control = "0.3.1"
Expand All @@ -29,12 +29,11 @@ zip = "2.1.3"
num-bigint = "0.4.6"
num-traits = "0.2.19"
either = "1.13.0"
itertools = "0.13.0"

# Include Z3 dependencies as strictly as possible, we don't want
# this to change because it might break our release builds.
z3="^0.12.1"
z3-sys = "^0.8.1"

[build-dependencies]
pyo3-build-config = "0.22.1"
pyo3-build-config = "0.23.2"
4 changes: 3 additions & 1 deletion biodivine_aeon/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,7 @@ class AsynchronousPerturbationGraph(AsynchronousGraph):
def mk_perturbations(self, perturbations: Union[Mapping[VariableIdType, Optional[bool]], PerturbationModel]) -> PerturbationSet: ...
def mk_perturbations_with_size(self, size: int, up_to: bool) -> PerturbationSet: ...
def colored_robustness(self, set: ColorSet) -> float: ...
def strip_perturbation_data(self, set: ColorSet) -> ColorSet: ...

class Control:
@staticmethod
Expand All @@ -1811,7 +1812,8 @@ class Control:
phenotype: VertexSet,
oscillation_type: Optional[PhenotypeOscillation] = None,
size_limit: Optional[int] = None,
stop_when_found: bool = False) -> ColoredPerturbationSet: ...
stop_when_found: bool = False,
initial_states: VertexSet | None = None) -> ColoredPerturbationSet: ...

BddVariableType = Union[BddVariable, str]
VariableIdType = Union[VariableId, str]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=1.6.0,<1.7.0"]
requires = ["maturin>=1.7.0,<1.8.0"]
build-backend = "maturin"

[project]
Expand Down
6 changes: 3 additions & 3 deletions src/bindings/bn_classifier/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl Class {
Ok(Class { items })
}

fn __richcmp__(&self, py: Python, other: &Class, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &Class, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |x| &x.items)
}

Expand All @@ -97,8 +97,8 @@ impl Class {
format!("Class({})", self.__str__())
}

fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
PyTuple::new_bound(py, vec![self.feature_list()])
fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
PyTuple::new(py, vec![self.feature_list()])
}

fn __len__(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion src/bindings/lib_bdd/bdd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl Bdd {

/// Convert this `Bdd` into a serialized `bytes` format that can be read using the `Bdd` constructor.
fn data_bytes<'a>(&self, py: Python<'a>) -> Bound<'a, PyBytes> {
PyBytes::new_bound(py, &self.as_native().to_bytes())
PyBytes::new(py, &self.as_native().to_bytes())
}

/// Produce a `graphviz`-compatible `.dot` representation of the underlying graph. If `zero_pruned` is set,
Expand Down
4 changes: 2 additions & 2 deletions src/bindings/lib_bdd/bdd_pointer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ impl BddPointer {
self.0.to_index()
}

fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
PyTuple::new_bound(py, [self.0.to_index()])
fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
PyTuple::new(py, [self.0.to_index()])
}

/// Returns the `BddPointer` referencing the `0` terminal node.
Expand Down
11 changes: 8 additions & 3 deletions src/bindings/lib_bdd/bdd_valuation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl BddValuation {
let var_count = ctx.get().variable_count();
match values {
None => {
let var_count = u16::try_from(var_count).unwrap();
let var_count = u16::try_from(var_count)?;
let value = biodivine_lib_bdd::BddValuation::all_false(var_count);
Ok(BddValuation { ctx, value })
}
Expand All @@ -101,7 +101,7 @@ impl BddValuation {
}
}

fn __richcmp__(&self, py: Python, other: &BddValuation, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &BddValuation, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |x| &x.value)
}

Expand Down Expand Up @@ -323,7 +323,12 @@ impl BddPartialValuation {
}
}

fn __richcmp__(&self, py: Python, other: &BddPartialValuation, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(
&self,
py: Python,
other: &BddPartialValuation,
op: CompareOp,
) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |x| &x.value)
}

Expand Down
4 changes: 2 additions & 2 deletions src/bindings/lib_bdd/bdd_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl BddVariable {
self.0.to_index()
}

pub fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
PyTuple::new_bound(py, [self.0.to_index()])
pub fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
PyTuple::new(py, [self.0.to_index()])
}
}
6 changes: 3 additions & 3 deletions src/bindings/lib_bdd/bdd_variable_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl BddVariableSet {
throw_type_error("Expected `int` or `list[str]`.")
}

fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |x| x.variable_names())
}

Expand All @@ -81,8 +81,8 @@ impl BddVariableSet {
format!("BddVariableSet({:?})", names)
}

fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
PyTuple::new_bound(py, [self.variable_names()])
fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
PyTuple::new(py, [self.variable_names()])
}

/// Return the number of variables managed by this `BddVariableSet`.
Expand Down
2 changes: 1 addition & 1 deletion src/bindings/lib_bdd/bdd_variable_set_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl BddVariableSetBuilder {
BddVariableSetBuilder(inner)
}

fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, self, other, |x| x.__getstate__())
}

Expand Down
8 changes: 4 additions & 4 deletions src/bindings/lib_bdd/boolean_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl BooleanExpression {
hasher.finish()
}

fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |it| it.as_native())
}

Expand All @@ -71,11 +71,11 @@ impl BooleanExpression {
format!("BooleanExpression({:?})", self.__str__())
}

fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
// Technically, this is a "different" expression because it is created with a completely new `root`,
// but it is much easier (and more transparent) than serializing the root expression and trying to figure
// out how to serialize a pointer into the AST.
PyTuple::new_bound(py, [self.__str__()])
PyTuple::new(py, [self.__str__()])
}

fn __root__(&self) -> BooleanExpression {
Expand All @@ -91,7 +91,7 @@ impl BooleanExpression {
) -> PyResult<bool> {
match (valuation, kwargs) {
(Some(_), Some(_)) => throw_type_error("Cannot use both explicit and named arguments."),
(None, None) => eval(self.as_native(), &PyDict::new_bound(py)),
(None, None) => eval(self.as_native(), &PyDict::new(py)),
(Some(v), None) | (None, Some(v)) => eval(self.as_native(), v),
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/bindings/lib_hctl_model_checker/hctl_formula.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl HctlFormula {
hasher.finish()
}

fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |it| it.as_native())
}

Expand All @@ -221,11 +221,11 @@ impl HctlFormula {
HctlFormula::from_native(self.value.clone())
}

fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
// Technically, this is a "different" expression because it is created with a completely new `root`,
// but it is much easier (and more transparent) than serializing the root expression and trying to figure
// out how to serialize a pointer into the AST.
PyTuple::new_bound(py, [self.__str__()])
PyTuple::new(py, [self.__str__()])
}

fn __root__(&self) -> HctlFormula {
Expand Down
2 changes: 1 addition & 1 deletion src/bindings/lib_hctl_model_checker/model_checking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl ModelChecking {
)
})
.collect::<PyResult<Vec<Py<ColoredVertexSet>>>>()?;
let result_list = PyList::new_bound(py, result_iter);
let result_list = PyList::new(py, result_iter)?;

Ok(result_list.into_any())
}
Expand Down
2 changes: 1 addition & 1 deletion src/bindings/lib_param_bn/boolean_network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl BooleanNetwork {
)
}

pub fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> Py<PyAny> {
pub fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> PyResult<Py<PyAny>> {
// The BN and its underlying RG should be up-to-date, hence it should be ok to just compare the BN.
richcmp_eq_by_key(py, op, &self, &other, |x| x.as_native())
}
Expand Down
18 changes: 12 additions & 6 deletions src/bindings/lib_param_bn/model_annotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{throw_runtime_error, AsNative};
use macros::Wrapper;
use pyo3::basic::CompareOp;
use pyo3::prelude::*;
use pyo3::IntoPyObjectExt;

/*
I am sorry for this mess, but this seems to be the best solution at the moment
Expand Down Expand Up @@ -113,28 +114,33 @@ impl ModelAnnotation {
})
}

pub fn __richcmp__(&self, py: Python, other: &ModelAnnotation, op: CompareOp) -> Py<PyAny> {
pub fn __richcmp__(
&self,
py: Python,
other: &ModelAnnotation,
op: CompareOp,
) -> PyResult<Py<PyAny>> {
// First, check the paths.
match op {
CompareOp::Eq => {
if self.path != other.path {
return false.into_py(py);
return false.into_py_any(py);
}
}
CompareOp::Ne => {
if self.path != other.path {
return true.into_py(py);
return true.into_py_any(py);
}
}
_ => return py.NotImplemented(),
_ => return Ok(py.NotImplemented()),
}

// If paths match the operator, do the same thing with the root references.
// Here, we are not doing semantic checking, just pointer equivalence, which makes
// sure both objects reference the same underlying dictionary.
match op {
CompareOp::Eq => (self.root.as_ptr() == other.root.as_ptr()).into_py(py),
CompareOp::Ne => (self.root.as_ptr() != other.root.as_ptr()).into_py(py),
CompareOp::Eq => (self.root.as_ptr() == other.root.as_ptr()).into_py_any(py),
CompareOp::Ne => (self.root.as_ptr() != other.root.as_ptr()).into_py_any(py),
_ => unreachable!(),
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/bindings/lib_param_bn/parameter_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl ParameterId {
self.0.to_index()
}

pub fn __getnewargs__<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
PyTuple::new_bound(py, [self.0.to_index()])
pub fn __getnewargs__<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyTuple>> {
PyTuple::new(py, [self.0.to_index()])
}
}
14 changes: 7 additions & 7 deletions src/bindings/lib_param_bn/regulatory_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl RegulatoryGraph {
)
}

fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> Py<PyAny> {
fn __richcmp__(&self, py: Python, other: &Self, op: CompareOp) -> PyResult<Py<PyAny>> {
richcmp_eq_by_key(py, op, &self, &other, |x| x.as_native())
}

Expand Down Expand Up @@ -238,7 +238,7 @@ impl RegulatoryGraph {
/// Return the list of all regulations (represented as `IdRegulation` dictionaries) that are currently
/// managed by this `RegulatoryGraph`.
pub fn regulations<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyList>> {
let result = PyList::empty_bound(py);
let result = PyList::empty(py);
for reg in self.as_native().regulations() {
let reg = Self::encode_regulation(py, reg)?;
result.append(reg)?;
Expand Down Expand Up @@ -762,14 +762,14 @@ impl RegulatoryGraph {
py: Python<'a>,
regulation: &biodivine_lib_param_bn::Regulation,
) -> PyResult<Bound<'a, PyDict>> {
let result = PyDict::new_bound(py);
let result = PyDict::new(py);
let source = VariableId::from(regulation.get_regulator());
let target = VariableId::from(regulation.get_target());
result.set_item("source", source.into_py(py))?;
result.set_item("target", target.into_py(py))?;
result.set_item("essential", regulation.is_observable().into_py(py))?;
result.set_item("source", source)?;
result.set_item("target", target)?;
result.set_item("essential", regulation.is_observable())?;
match regulation.get_monotonicity() {
None => result.set_item("sign", Option::<&str>::None.into_py(py))?,
None => result.set_item("sign", py.None())?,
Some(Monotonicity::Activation) => result.set_item("sign", "+")?,
Some(Monotonicity::Inhibition) => result.set_item("sign", "-")?,
}
Expand Down
10 changes: 6 additions & 4 deletions src/bindings/lib_param_bn/symbolic/asynchronous_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use biodivine_lib_param_bn::symbolic_async_graph::{GraphColors, SymbolicAsyncGra
use either::{Left, Right};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::IntoPyObjectExt;
use std::collections::HashMap;

#[pyclass(module = "biodivine_aeon", frozen, subclass)]
Expand Down Expand Up @@ -420,18 +421,19 @@ impl AsynchronousGraph {
let set = if let Ok(set) = set.extract::<ColorSet>() {
self.as_native()
.transfer_colors_from(set.as_native(), original_ctx.as_native())
.map(|it| ColorSet::mk_native(self.ctx.clone(), it).into_py(py))
.map(|it| ColorSet::mk_native(self.ctx.clone(), it).into_py_any(py))
} else if let Ok(set) = set.extract::<VertexSet>() {
self.as_native()
.transfer_vertices_from(set.as_native(), original_ctx.as_native())
.map(|it| VertexSet::mk_native(self.ctx.clone(), it).into_py(py))
.map(|it| VertexSet::mk_native(self.ctx.clone(), it).into_py_any(py))
} else if let Ok(set) = set.extract::<ColoredVertexSet>() {
self.as_native()
.transfer_from(set.as_native(), original_ctx.as_native())
.map(|it| ColoredVertexSet::mk_native(self.ctx.clone(), it).into_py(py))
.map(|it| ColoredVertexSet::mk_native(self.ctx.clone(), it).into_py_any(py))
} else {
return throw_type_error("Expected `ColorSet`, `VartexSet`, or `ColoredVertexSet`.");
};
}
.transpose()?;
if let Some(set) = set {
Ok(set)
} else {
Expand Down
Loading

0 comments on commit fd16dde

Please sign in to comment.