diff --git a/wrappers/rust/icicle-core/src/poseidon/mod.rs b/wrappers/rust/icicle-core/src/poseidon/mod.rs index c4b4e7757..d5831f013 100644 --- a/wrappers/rust/icicle-core/src/poseidon/mod.rs +++ b/wrappers/rust/icicle-core/src/poseidon/mod.rs @@ -102,6 +102,18 @@ macro_rules! impl_poseidon_tests { check_poseidon_hash_sponge::<$field>(); } + #[test] + fn test_poseidon_hash_multi_device() { + initialize(); + test_utilities::test_set_main_device(); + let nof_devices = icicle_runtime::get_device_count().unwrap(); + if nof_devices > 1 { + check_poseidon_hash_multi_device::<$field>(); + } else { + println!("Skipping test_poseidon_hash_multi_device due to single device in the machine"); + } + } + #[test] fn test_poseidon_tree() { initialize(); diff --git a/wrappers/rust/icicle-core/src/poseidon/tests.rs b/wrappers/rust/icicle-core/src/poseidon/tests.rs index 440ebad76..83d190408 100644 --- a/wrappers/rust/icicle-core/src/poseidon/tests.rs +++ b/wrappers/rust/icicle-core/src/poseidon/tests.rs @@ -84,6 +84,55 @@ where } } +pub fn check_poseidon_hash_multi_device() +where + ::Config: PoseidonHasher + GenerateRandom, +{ + let t = 9; // t=9 is for Poseidon9 hash (t is the paper's terminology) + let inputs: Vec = F::Config::generate_random(t); + let mut outputs_main_0 = vec![F::zero(); 1]; + let mut outputs_main_1 = vec![F::zero(); 1]; + let mut outputs_ref = vec![F::zero(); 1]; + + test_utilities::test_set_ref_device(); + let poseidon_hasher_ref = Poseidon::new::(t as u32, None /*domain_tag*/).unwrap(); + + poseidon_hasher_ref + .hash( + HostSlice::from_slice(&inputs), + &HashConfig::default(), + HostSlice::from_mut_slice(&mut outputs_ref), + ) + .unwrap(); + + // initialize hasher on 2 devices + test_utilities::test_set_main_device_with_id(0); + let poseidon_hasher_main_dev_0 = Poseidon::new::(t as u32, None /*domain_tag*/).unwrap(); + test_utilities::test_set_main_device_with_id(1); + let poseidon_hasher_main_dev_1 = Poseidon::new::(t as u32, None /*domain_tag*/).unwrap(); + + // test device 1 + poseidon_hasher_main_dev_1 + .hash( + HostSlice::from_slice(&inputs), + &HashConfig::default(), + HostSlice::from_mut_slice(&mut outputs_main_1), + ) + .unwrap(); + assert_eq!(outputs_ref, outputs_main_1); + + // test device 0 + test_utilities::test_set_main_device_with_id(0); + poseidon_hasher_main_dev_0 + .hash( + HostSlice::from_slice(&inputs), + &HashConfig::default(), + HostSlice::from_mut_slice(&mut outputs_main_0), + ) + .unwrap(); + assert_eq!(outputs_ref, outputs_main_0); +} + pub fn check_poseidon_tree() where ::Config: PoseidonHasher,