From 6702403c4f38160eba00fe79290cc6844ab6d550 Mon Sep 17 00:00:00 2001 From: Rigidity Date: Mon, 8 Jul 2024 11:23:07 -0400 Subject: [PATCH] Add Python binding to scalar_multiply --- Cargo.lock | 1 + crates/chia-bls/Cargo.toml | 5 +++-- crates/chia-bls/src/public_key.rs | 13 ++++++++++++ crates/chia-bls/src/signature.rs | 34 +++++++++++++++++++++++++++++++ wheel/generate_type_stubs.py | 2 ++ wheel/python/chia_rs/chia_rs.pyi | 1 + 6 files changed, 54 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9a4c64dea..8c64b4bfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -320,6 +320,7 @@ dependencies = [ "hex", "hkdf", "lru", + "num-bigint", "pyo3", "rand", "rstest", diff --git a/crates/chia-bls/Cargo.toml b/crates/chia-bls/Cargo.toml index 8bb42ce23..b4ffa35ef 100644 --- a/crates/chia-bls/Cargo.toml +++ b/crates/chia-bls/Cargo.toml @@ -12,7 +12,7 @@ repository = "https://github.com/Chia-Network/chia_rs" workspace = true [features] -py-bindings = ["dep:pyo3", "chia_py_streamable_macro", "chia-traits/py-bindings"] +py-bindings = ["dep:pyo3", "dep:num-bigint", "chia_py_streamable_macro", "chia-traits/py-bindings"] arbitrary = ["dep:arbitrary"] [dependencies] @@ -24,9 +24,10 @@ hkdf = { workspace = true } blst = { workspace = true } hex = { workspace = true } thiserror = { workspace = true } -pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true } +pyo3 = { workspace = true, features = ["multiple-pymethods", "num-bigint"], optional = true } arbitrary = { workspace = true, optional = true } lru = { workspace = true } +num-bigint = { workspace = true, optional = true } [dev-dependencies] rand = { workspace = true } diff --git a/crates/chia-bls/src/public_key.rs b/crates/chia-bls/src/public_key.rs index 867e1203d..1543dc9a5 100644 --- a/crates/chia-bls/src/public_key.rs +++ b/crates/chia-bls/src/public_key.rs @@ -296,6 +296,9 @@ pub fn hash_to_g1_with_dst(msg: &[u8], dst: &[u8]) -> PublicKey { PublicKey(p1) } +#[cfg(feature = "py-bindings")] +use num_bigint::BigInt; + #[cfg(feature = "py-bindings")] #[pyo3::pymethods] impl PublicKey { @@ -340,6 +343,16 @@ impl PublicKey { pub fn __iadd__(&mut self, rhs: &Self) { *self += rhs; } + + #[must_use] + #[pyo3(name = "scalar_multiply")] + #[allow(clippy::needless_pass_by_value)] + pub fn py_scalar_multiply(&self, scalar: BigInt) -> Self { + let mut clone = *self; + let bytes = scalar.to_signed_bytes_be(); + clone.scalar_multiply(&bytes); + clone + } } #[cfg(feature = "py-bindings")] diff --git a/crates/chia-bls/src/signature.rs b/crates/chia-bls/src/signature.rs index 2031757a0..83815f4bd 100644 --- a/crates/chia-bls/src/signature.rs +++ b/crates/chia-bls/src/signature.rs @@ -475,6 +475,9 @@ pub fn sign>(sk: &SecretKey, msg: Msg) -> Signature { sign_raw(sk, aug_msg) } +#[cfg(feature = "py-bindings")] +use num_bigint::BigInt; + #[cfg(feature = "py-bindings")] #[pyo3::pymethods] impl Signature { @@ -509,6 +512,16 @@ impl Signature { pub fn __iadd__(&mut self, rhs: &Self) { *self += rhs; } + + #[must_use] + #[pyo3(name = "scalar_multiply")] + #[allow(clippy::needless_pass_by_value)] + pub fn py_scalar_multiply(&self, scalar: BigInt) -> Self { + let mut clone = self.clone(); + let bytes = scalar.to_signed_bytes_be(); + clone.scalar_multiply(&bytes); + clone + } } #[cfg(feature = "py-bindings")] @@ -1203,6 +1216,27 @@ mod tests { } } + #[test] + fn test_scalar_multiply_large() { + let mut rng = StdRng::seed_from_u64(1337); + let mut data = [0; 4198]; + rng.fill(data.as_mut_slice()); + let seed: [u8; 32] = rng.gen(); + let msg: [u8; 32] = rng.gen(); + let sk = SecretKey::from_seed(&seed); + let mut g2 = sign(&sk, msg); + g2.scalar_multiply(&data); + assert_eq!( + hex::encode(g2.to_bytes()), + " + ae4d384a25c51283b8be8c6546d23e1555995c87fbcc0fe12169b63a052c4a1f + c9d3a020e8e010d4be619e3c0980a1f213b951fe75375012c5df6a690548a637 + ef25f8da1e4f9c8f4d2062531ce688c040258a76543831abde774872e00af74b + " + .replace([' ', '\n'], "") + ); + } + #[test] fn test_hash_to_g2_different_dst() { const DEFAULT_DST: &[u8] = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_AUG_"; diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py index 8ae0938de..313931ae7 100644 --- a/wheel/generate_type_stubs.py +++ b/wheel/generate_type_stubs.py @@ -389,6 +389,7 @@ def __init__( "def __add__(self, other: G1Element) -> G1Element: ...", "def __iadd__(self, other: G1Element) -> G1Element: ...", "def derive_unhardened(self, int) -> G1Element: ...", + "def scalar_multiply(self, value: int) -> G1Element: ...", ], ) print_class( @@ -404,6 +405,7 @@ def __init__( "def __str__(self) -> str: ...", "def __add__(self, other: G2Element) -> G2Element: ...", "def __iadd__(self, other: G2Element) -> G2Element: ...", + "def scalar_multiply(self, value: int) -> G2Element: ...", ], ) print_class( diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi index f1c815408..f55607e05 100644 --- a/wheel/python/chia_rs/chia_rs.pyi +++ b/wheel/python/chia_rs/chia_rs.pyi @@ -158,6 +158,7 @@ class G2Element: def __str__(self) -> str: ... def __add__(self, other: G2Element) -> G2Element: ... def __iadd__(self, other: G2Element) -> G2Element: ... + def scalar_multiply(self, value: int) -> G2Element: ... def __init__( self ) -> None: ...