Skip to content

Commit

Permalink
make GTElement also use parse_hex_string (just like PublicKey, Signat…
Browse files Browse the repository at this point in the history
…ure and PrivateKey). Add test cases to python tests
  • Loading branch information
arvidn committed Aug 5, 2024
1 parent 044bc77 commit 7a05d00
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
32 changes: 9 additions & 23 deletions crates/chia-bls/src/gtelement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,39 +128,25 @@ impl GTElement {
mod pybindings {
use super::*;

use crate::parse_hex::parse_hex_string;
use chia_traits::{FromJsonDict, ToJsonDict};
use pyo3::{exceptions::PyValueError, prelude::*};
use pyo3::prelude::*;

impl ToJsonDict for GTElement {
fn to_json_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
let bytes = self.to_bytes();
Ok(hex::encode(bytes).into_py(py))
Ok(("0x".to_string() + &hex::encode(bytes)).into_py(py))
}
}

impl FromJsonDict for GTElement {
fn from_json_dict(o: &Bound<'_, PyAny>) -> PyResult<Self> {
let s: String = o.extract()?;
if !s.starts_with("0x") {
return Err(PyValueError::new_err(
"bytes object is expected to start with 0x",
));
}
let s = &s[2..];
let buf = match hex::decode(s) {
Err(_) => {
return Err(PyValueError::new_err("invalid hex"));
}
Ok(v) => v,
};
if buf.len() != Self::SIZE {
return Err(PyValueError::new_err(format!(
"GTElement, invalid length {} expected {}",
buf.len(),
Self::SIZE
)));
}
Ok(Self::from_bytes(buf.as_slice().try_into().unwrap()))
Ok(Self::from_bytes(
parse_hex_string(o, Self::SIZE, "GTElement")?
.as_slice()
.try_into()
.unwrap(),
))
}
}
}
12 changes: 12 additions & 0 deletions tests/test_blspy_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import chia_rs
from random import getrandbits
import sys
from typing import Any, Type


def randbytes(n: int) -> bytes:
Expand Down Expand Up @@ -185,6 +186,17 @@ def test_bls() -> None:
# get_fingerprint()
assert pk1.get_fingerprint() == pk2.get_fingerprint()

obj: Any
klass: Any
for obj, klass in [(pk2, G1Element), (sig2, G2Element), (sk2, PrivateKey), (pair2, chia_rs.GTElement)]:
print(f"{klass}")
# to_json_dict
expected_json = "0x" + bytes(obj).hex()
assert obj.to_json_dict() == expected_json
# from_json_dict
assert obj == klass.from_json_dict(expected_json)
# binary blobs are also accepted in JSON dicts
assert obj == klass.from_json_dict(bytes(obj))

# ------------------------------------- 8< ----------------------------------
#
Expand Down

0 comments on commit 7a05d00

Please sign in to comment.