Skip to content

Commit

Permalink
Add Python binding to scalar_multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jul 8, 2024
1 parent 4e6bfd8 commit 6702403
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions crates/chia-bls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repository = "https://github.com/Chia-Network/chia_rs"
workspace = true

[features]
py-bindings = ["dep:pyo3", "chia_py_streamable_macro", "chia-traits/py-bindings"]
py-bindings = ["dep:pyo3", "dep:num-bigint", "chia_py_streamable_macro", "chia-traits/py-bindings"]
arbitrary = ["dep:arbitrary"]

[dependencies]
Expand All @@ -24,9 +24,10 @@ hkdf = { workspace = true }
blst = { workspace = true }
hex = { workspace = true }
thiserror = { workspace = true }
pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true }
pyo3 = { workspace = true, features = ["multiple-pymethods", "num-bigint"], optional = true }
arbitrary = { workspace = true, optional = true }
lru = { workspace = true }
num-bigint = { workspace = true, optional = true }

[dev-dependencies]
rand = { workspace = true }
Expand Down
13 changes: 13 additions & 0 deletions crates/chia-bls/src/public_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ pub fn hash_to_g1_with_dst(msg: &[u8], dst: &[u8]) -> PublicKey {
PublicKey(p1)
}

#[cfg(feature = "py-bindings")]
use num_bigint::BigInt;

#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl PublicKey {
Expand Down Expand Up @@ -340,6 +343,16 @@ impl PublicKey {
pub fn __iadd__(&mut self, rhs: &Self) {
*self += rhs;
}

#[must_use]
#[pyo3(name = "scalar_multiply")]
#[allow(clippy::needless_pass_by_value)]
pub fn py_scalar_multiply(&self, scalar: BigInt) -> Self {
let mut clone = *self;
let bytes = scalar.to_signed_bytes_be();
clone.scalar_multiply(&bytes);
clone
}
}

#[cfg(feature = "py-bindings")]
Expand Down
34 changes: 34 additions & 0 deletions crates/chia-bls/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ pub fn sign<Msg: AsRef<[u8]>>(sk: &SecretKey, msg: Msg) -> Signature {
sign_raw(sk, aug_msg)
}

#[cfg(feature = "py-bindings")]
use num_bigint::BigInt;

#[cfg(feature = "py-bindings")]
#[pyo3::pymethods]
impl Signature {
Expand Down Expand Up @@ -509,6 +512,16 @@ impl Signature {
pub fn __iadd__(&mut self, rhs: &Self) {
*self += rhs;
}

#[must_use]
#[pyo3(name = "scalar_multiply")]
#[allow(clippy::needless_pass_by_value)]
pub fn py_scalar_multiply(&self, scalar: BigInt) -> Self {
let mut clone = self.clone();
let bytes = scalar.to_signed_bytes_be();
clone.scalar_multiply(&bytes);
clone
}
}

#[cfg(feature = "py-bindings")]
Expand Down Expand Up @@ -1203,6 +1216,27 @@ mod tests {
}
}

#[test]
fn test_scalar_multiply_large() {
let mut rng = StdRng::seed_from_u64(1337);
let mut data = [0; 4198];
rng.fill(data.as_mut_slice());
let seed: [u8; 32] = rng.gen();
let msg: [u8; 32] = rng.gen();
let sk = SecretKey::from_seed(&seed);
let mut g2 = sign(&sk, msg);
g2.scalar_multiply(&data);
assert_eq!(
hex::encode(g2.to_bytes()),
"
ae4d384a25c51283b8be8c6546d23e1555995c87fbcc0fe12169b63a052c4a1f
c9d3a020e8e010d4be619e3c0980a1f213b951fe75375012c5df6a690548a637
ef25f8da1e4f9c8f4d2062531ce688c040258a76543831abde774872e00af74b
"
.replace([' ', '\n'], "")
);
}

#[test]
fn test_hash_to_g2_different_dst() {
const DEFAULT_DST: &[u8] = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_AUG_";
Expand Down
2 changes: 2 additions & 0 deletions wheel/generate_type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(
"def __add__(self, other: G1Element) -> G1Element: ...",
"def __iadd__(self, other: G1Element) -> G1Element: ...",
"def derive_unhardened(self, int) -> G1Element: ...",
"def scalar_multiply(self, value: int) -> G1Element: ...",
],
)
print_class(
Expand All @@ -404,6 +405,7 @@ def __init__(
"def __str__(self) -> str: ...",
"def __add__(self, other: G2Element) -> G2Element: ...",
"def __iadd__(self, other: G2Element) -> G2Element: ...",
"def scalar_multiply(self, value: int) -> G2Element: ...",
],
)
print_class(
Expand Down
1 change: 1 addition & 0 deletions wheel/python/chia_rs/chia_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class G2Element:
def __str__(self) -> str: ...
def __add__(self, other: G2Element) -> G2Element: ...
def __iadd__(self, other: G2Element) -> G2Element: ...
def scalar_multiply(self, value: int) -> G2Element: ...
def __init__(
self
) -> None: ...
Expand Down

0 comments on commit 6702403

Please sign in to comment.