From 011a54a5d18734f4538f4ccc5d2b4e9bc9c92be8 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Thu, 9 Jan 2025 21:41:10 -0500 Subject: [PATCH] and add python exceptions --- crates/chia-datalayer/src/merkle.rs | 115 +++++++++++++++++++++------- tests/test_datalayer.py | 11 ++- wheel/src/api.rs | 6 +- 3 files changed, 101 insertions(+), 31 deletions(-) diff --git a/crates/chia-datalayer/src/merkle.rs b/crates/chia-datalayer/src/merkle.rs index 6e29fc761..b4d57b85d 100644 --- a/crates/chia-datalayer/src/merkle.rs +++ b/crates/chia-datalayer/src/merkle.rs @@ -62,7 +62,6 @@ impl std::fmt::Display for KvId { } } -// ($enum_name:ident, $($variant_name:tt, $variant_string:literal, $($variant_type:tt,*))) => { macro_rules! create_errors { ( $enum:ident, @@ -70,6 +69,7 @@ macro_rules! create_errors { $( ( $name:ident, + $python_name:ident, $string:literal, ( $( @@ -89,6 +89,35 @@ macro_rules! create_errors { $name($($type_,)*), )* } + + #[cfg(feature = "py-bindings")] + pub mod python_exceptions { + use pyo3::prelude::*; + + $( + pyo3::create_exception!(chia_rs.chia_rs.datalayer, $python_name, pyo3::exceptions::PyException); + )* + + pub fn add_to_module(py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { + $( + module.add(stringify!($python_name), py.get_type::<$python_name>())?; + )* + + Ok(()) + } + } + + #[cfg(feature = "py-bindings")] + impl From for pyo3::PyErr { + fn from(err: Error) -> pyo3::PyErr { + let message = err.to_string(); + match err { + $( + Error::$name(..) => python_exceptions::$python_name::new_err(message), + )* + } + } + } } } @@ -98,74 +127,112 @@ create_errors!( // TODO: don't use String here ( FailedLoadingMetadata, + FailedLoadingMetadataError, "failed loading metadata: {0}", (String) ), // TODO: don't use String here - (FailedLoadingNode, "failed loading node: {0}", (String)), + ( + FailedLoadingNode, + FailedLoadingNodeError, + "failed loading node: {0}", + (String) + ), ( InvalidBlobLength, + InvalidBlobLengthError, "blob length must be a multiple of block count, found extra bytes: {0}", (usize) ), - (KeyAlreadyPresent, "key already present", ()), + ( + KeyAlreadyPresent, + KeyAlreadyPresentError, + "key already present", + () + ), ( UnableToInsertAsRootOfNonEmptyTree, + UnableToInsertAsRootOfNonEmptyTreeError, "requested insertion at root but tree not empty", () ), - (UnableToFindALeaf, "unable to find a leaf", ()), - (UnknownKey, "unknown key: {0:?}", (KvId)), + ( + UnableToFindALeaf, + UnableToFindALeafError, + "unable to find a leaf", + () + ), + (UnknownKey, UnknownKeyError, "unknown key: {0:?}", (KvId)), ( IntegrityKeyNotInCache, + IntegrityKeyNotInCacheError, "key not in key to index cache: {0:?}", (KvId) ), ( IntegrityKeyToIndexCacheIndex, + IntegrityKeyToIndexCacheIndexError, "key to index cache for {0:?} should be {1:?} got: {2:?}", (KvId, TreeIndex, TreeIndex) ), ( IntegrityParentChildMismatch, + IntegrityParentChildMismatchError, "parent and child relationship mismatched: {0:?}", (TreeIndex) ), ( IntegrityKeyToIndexCacheLength, + IntegrityKeyToIndexCacheLengthError, "found {0:?} leaves but key to index cache length is: {1}", (usize, usize) ), ( IntegrityUnmatchedChildParentRelationships, + IntegrityUnmatchedChildParentRelationshipsError, "unmatched parent -> child references found: {0}", (usize) ), ( IntegrityTotalNodeCount, + IntegrityTotalNodeCountError, "expected total node count {0:?} found: {1:?}", (TreeIndex, usize) ), ( ZeroLengthSeedNotAllowed, + ZeroLengthSeedNotAllowedError, "zero-length seed bytes not allowed", () ), ( BlockIndexOutOfRange, + BlockIndexOutOfRangeError, "block index out of range: {0:?}", (TreeIndex) ), - (NodeNotALeaf, "node not a leaf: {0:?}", (InternalNode)), + ( + NodeNotALeaf, + NodeNotALeafError, + "node not a leaf: {0:?}", + (InternalNode) + ), ( Streaming, + StreamingError, "from streamable: {0:?}", (chia_traits::chia_error::Error) ), - (IndexIsNotAChild, "index not a child: {0}", (TreeIndex)), - (CycleFound, "cycle found", ()), + ( + IndexIsNotAChild, + IndexIsNotAChildError, + "index not a child: {0}", + (TreeIndex) + ), + (CycleFound, CycleFoundError, "cycle found", ()), ( BlockIndexOutOfBounds, + BlockIndexOutOfBoundsError, "block index out of bounds: {0}", (TreeIndex) ) @@ -1267,7 +1334,7 @@ impl MerkleBlob { let slice = unsafe { std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) }; - Self::new(Vec::from(slice)).map_err(|e| PyValueError::new_err(e.to_string())) + Ok(Self::new(Vec::from(slice))?) } #[pyo3(name = "insert", signature = (key, value, hash, reference_kid = None, side = None))] @@ -1286,39 +1353,37 @@ impl MerkleBlob { index: *self .key_to_index .get(&key) + // TODO: use a specific error .ok_or(PyValueError::new_err(format!( "unknown key id passed as insert location reference: {key}" )))?, side: Side::from_bytes(&[side])?, }, _ => { + // TODO: use a specific error return Err(PyValueError::new_err( "must specify neither or both of reference_kid and side", )); } }; - self.insert(key, value, &hash, insert_location) - .map_err(|e| PyValueError::new_err(e.to_string()))?; + self.insert(key, value, &hash, insert_location)?; Ok(()) } #[pyo3(name = "delete")] pub fn py_delete(&mut self, key: KvId) -> PyResult<()> { - self.delete(key) - .map_err(|e| PyValueError::new_err(e.to_string())) + Ok(self.delete(key)?) } #[pyo3(name = "get_raw_node")] pub fn py_get_raw_node(&mut self, index: TreeIndex) -> PyResult { - self.get_node(index) - .map_err(|e| PyValueError::new_err(e.to_string())) + Ok(self.get_node(index)?) } #[pyo3(name = "calculate_lazy_hashes")] pub fn py_calculate_lazy_hashes(&mut self) -> PyResult<()> { - self.calculate_lazy_hashes() - .map_err(|e| PyValueError::new_err(e.to_string())) + Ok(self.calculate_lazy_hashes()?) } #[pyo3(name = "get_lineage_with_indexes")] @@ -1329,10 +1394,7 @@ impl MerkleBlob { ) -> PyResult { let list = pyo3::types::PyList::empty(py); - for (index, node) in self - .get_lineage_with_indexes(index) - .map_err(|e| PyValueError::new_err(e.to_string()))? - { + for (index, node) in self.get_lineage_with_indexes(index)? { use pyo3::types::PyListMethods; list.append((index.into_pyobject(py)?, node.into_pyobject(py)?))?; } @@ -1346,7 +1408,7 @@ impl MerkleBlob { for item in MerkleBlobParentFirstIterator::new(&self.blob) { use pyo3::types::PyListMethods; - let (index, block) = item.map_err(|e| PyValueError::new_err(e.to_string()))?; + let (index, block) = item?; list.append((index.into_pyobject(py)?, block.node.into_pyobject(py)?))?; } @@ -1369,10 +1431,9 @@ impl MerkleBlob { return Ok(None); } - let block = self - .get_block(index) - .map_err(|e| PyValueError::new_err(e.to_string()))?; + let block = self.get_block(index)?; if block.metadata.dirty { + // TODO: use a specific error return Err(PyValueError::new_err("root hash is dirty")); } @@ -1386,13 +1447,13 @@ impl MerkleBlob { hashes: Vec, ) -> PyResult<()> { if keys_values.len() != hashes.len() { + // TODO: use a specific error return Err(PyValueError::new_err( "key/value and hash collection lengths must match", )); } - self.batch_insert(&mut zip(keys_values, hashes)) - .map_err(|e| PyValueError::new_err(e.to_string()))?; + self.batch_insert(&mut zip(keys_values, hashes))?; Ok(()) } diff --git a/tests/test_datalayer.py b/tests/test_datalayer.py index 78de456c1..991d9d530 100644 --- a/tests/test_datalayer.py +++ b/tests/test_datalayer.py @@ -1,8 +1,9 @@ -from chia_rs.datalayer import LeafNode, MerkleBlob +import pytest + +from chia_rs.datalayer import InvalidBlobLengthError, LeafNode, MerkleBlob from chia_rs.sized_bytes import bytes32 from chia_rs.sized_ints import int64, uint8 - def test_merkle_blob(): blob = bytes.fromhex( "000100770a5d50f980316e3a856b2f0447e1c1285064cd301c731e5b16c16d187d0ff90000000400000002000000000000000000000000010001000000060c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b00000000000000010000000000000001010001000000000c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b00000000000000000000000000000000010001000000040c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b0000000000000002000000000000000200010100000000770a5d50f980316e3a856b2f0447e1c1285064cd301c731e5b16c16d187d0ff900000003000000060000000000000000010001000000060c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b0000000000000003000000000000000300000100000004770a5d50f980316e3a856b2f0447e1c1285064cd301c731e5b16c16d187d0ff900000005000000010000000000000000" @@ -52,3 +53,9 @@ def test_checking_coverage() -> None: if isinstance(node, LeafNode) } assert keys == set(range(count)) + + +def test_invalid_blob_length_raised() -> None: + """Mostly verifying that the exceptions are available and raise.""" + with pytest.raises(InvalidBlobLengthError): + MerkleBlob(blob=b"\x00") diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 08eafeaf3..fa079c4f0 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -78,8 +78,6 @@ use chia_bls::{ Signature, }; -use chia_datalayer::{InternalNode, LeafNode, MerkleBlob}; - #[pyfunction] pub fn compute_merkle_set_root<'p>( py: Python<'p>, @@ -648,6 +646,8 @@ pub fn chia_rs(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { } pub fn add_datalayer_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) -> PyResult<()> { + use chia_datalayer::*; + let datalayer = PyModule::new(py, "datalayer")?; parent.add_submodule(&datalayer)?; @@ -655,6 +655,8 @@ pub fn add_datalayer_submodule(py: Python<'_>, parent: &Bound<'_, PyModule>) -> datalayer.add_class::()?; datalayer.add_class::()?; + python_exceptions::add_to_module(py, &datalayer)?; + // https://github.com/PyO3/pyo3/issues/1517#issuecomment-808664021 // https://github.com/PyO3/pyo3/issues/759 py.import("sys")?