diff --git a/crates/chia-protocol/src/bytes.rs b/crates/chia-protocol/src/bytes.rs index e1d1bf00b..87426e71b 100644 --- a/crates/chia-protocol/src/bytes.rs +++ b/crates/chia-protocol/src/bytes.rs @@ -410,6 +410,10 @@ impl ChiaToPython for BytesImpl { let bytes_module = PyModule::import_bound(py, "chia_rs.sized_bytes")?; let ty = bytes_module.getattr("bytes32")?; ty.call1((self.0.into_py(py),)) + } else if N == 48 { + let bytes_module = PyModule::import_bound(py, "chia_rs.sized_bytes")?; + let ty = bytes_module.getattr("bytes48")?; + ty.call1((self.0.into_py(py),)) } else { Ok(PyBytes::new_bound(py, &self.0).into_any()) } diff --git a/crates/chia_py_streamable_macro/src/lib.rs b/crates/chia_py_streamable_macro/src/lib.rs index c95bd8e4e..9ab532b77 100644 --- a/crates/chia_py_streamable_macro/src/lib.rs +++ b/crates/chia_py_streamable_macro/src/lib.rs @@ -216,10 +216,15 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS } } - pub fn get_hash<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult> { + pub fn get_hash<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult> { + use pyo3::IntoPy; + use pyo3::types::PyModule; + use pyo3::prelude::PyAnyMethods; let mut ctx = clvmr::sha2::Sha256::new(); #crate_name::Streamable::update_digest(self, &mut ctx); - Ok(pyo3::types::PyBytes::new_bound(py, &ctx.finalize())) + let bytes_module = PyModule::import_bound(py, "chia_rs.sized_bytes")?; + let ty = bytes_module.getattr("bytes32")?; + ty.call1((&ctx.finalize().into_py(py),)) } #[pyo3(name = "to_bytes")] pub fn py_to_bytes<'p>(&self, py: pyo3::Python<'p>) -> pyo3::PyResult> { diff --git a/tests/test_streamable.py b/tests/test_streamable.py index d13e06ea1..a19fbb94f 100644 --- a/tests/test_streamable.py +++ b/tests/test_streamable.py @@ -11,8 +11,10 @@ from chia_rs.sized_bytes import bytes32 import pytest import copy +import random -sk = AugSchemeMPL.key_gen(bytes32.random()) +rng = random.Random(1337) +sk = AugSchemeMPL.key_gen(bytes32.random(rng)) pk = sk.get_g1() coin = b"bcbcbcbcbcbcbcbcbcbcbcbcbcbcbcbc" @@ -71,6 +73,9 @@ def test_hash_spend() -> None: assert type(c) is int assert b != c + assert a1.get_hash() == bytes32.fromhex("2b72a6614da0368147fa6cb785445d6569603e38f2de230e5f30692bf6410245") + assert str(a1.get_hash()) == "2b72a6614da0368147fa6cb785445d6569603e38f2de230e5f30692bf6410245" + def test_hash_spend_bundle_conditions() -> None: