From 73218595c603cabde45cf8ddbc51e80e70e60e38 Mon Sep 17 00:00:00 2001 From: samkim-crypto Date: Fri, 19 Jan 2024 08:40:42 +0900 Subject: [PATCH] [zk-token-sdk] Restrict Edwards and Ristretto multiscalar multiplication vector length to at most 512 (#34763) * restrict curve25519 multiscalar multiplication vector length to 512 * add syscall tests for msm vector length * add new feature gate `curve25519_restrict_msm_length` * update tests for feature new gate * Update programs/bpf_loader/src/syscalls/mod.rs Co-authored-by: Trent Nelson * remove length guard on the multisicalar mult lib function --------- Co-authored-by: Trent Nelson --- programs/bpf_loader/src/syscalls/mod.rs | 127 ++++++++++++++++++++++++ sdk/src/feature_set.rs | 5 + 2 files changed, 132 insertions(+) diff --git a/programs/bpf_loader/src/syscalls/mod.rs b/programs/bpf_loader/src/syscalls/mod.rs index 3e6562b8ed7b8a..42853d7a6fd503 100644 --- a/programs/bpf_loader/src/syscalls/mod.rs +++ b/programs/bpf_loader/src/syscalls/mod.rs @@ -1139,6 +1139,17 @@ declare_builtin_function!( use solana_zk_token_sdk::curve25519::{ curve_syscall_traits::*, edwards, ristretto, scalar, }; + + let restrict_msm_length = invoke_context + .feature_set + .is_active(&feature_set::curve25519_restrict_msm_length::id()); + #[allow(clippy::collapsible_if)] + if restrict_msm_length { + if points_len > 512 { + return Err(Box::new(SyscallError::InvalidLength)); + } + } + match curve_id { CURVE25519_EDWARDS => { let cost = invoke_context @@ -3146,6 +3157,122 @@ mod tests { assert_eq!(expected_product, result_point); } + #[test] + fn test_syscall_multiscalar_multiplication_maximum_length_exceeded() { + use solana_zk_token_sdk::curve25519::curve_syscall_traits::{ + CURVE25519_EDWARDS, CURVE25519_RISTRETTO, + }; + + let config = Config::default(); + prepare_mockup!(invoke_context, program_id, bpf_loader::id()); + + let scalar: [u8; 32] = [ + 254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250, + 78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6, + ]; + let scalars = [scalar; 513]; + let scalars_va = 0x100000000; + + let edwards_point: [u8; 32] = [ + 252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168, + 53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133, + ]; + let edwards_points = [edwards_point; 513]; + let edwards_points_va = 0x200000000; + + let ristretto_point: [u8; 32] = [ + 130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240, + 179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112, + ]; + let ristretto_points = [ristretto_point; 513]; + let ristretto_points_va = 0x300000000; + + let mut result_point: [u8; 32] = [0; 32]; + let result_point_va = 0x400000000; + + let mut memory_mapping = MemoryMapping::new( + vec![ + MemoryRegion::new_readonly(bytes_of_slice(&scalars), scalars_va), + MemoryRegion::new_readonly(bytes_of_slice(&edwards_points), edwards_points_va), + MemoryRegion::new_readonly(bytes_of_slice(&ristretto_points), ristretto_points_va), + MemoryRegion::new_writable(bytes_of_slice_mut(&mut result_point), result_point_va), + ], + &config, + &SBPFVersion::V2, + ) + .unwrap(); + + // test Edwards + invoke_context.mock_set_remaining(500_000); + let result = SyscallCurveMultiscalarMultiplication::rust( + &mut invoke_context, + CURVE25519_EDWARDS, + scalars_va, + edwards_points_va, + 512, // below maximum vector length + result_point_va, + &mut memory_mapping, + ); + + assert_eq!(0, result.unwrap()); + let expected_product = [ + 20, 146, 226, 37, 22, 61, 86, 249, 208, 40, 38, 11, 126, 101, 10, 82, 81, 77, 88, 209, + 15, 76, 82, 251, 180, 133, 84, 243, 162, 0, 11, 145, + ]; + assert_eq!(expected_product, result_point); + + invoke_context.mock_set_remaining(500_000); + let result = SyscallCurveMultiscalarMultiplication::rust( + &mut invoke_context, + CURVE25519_EDWARDS, + scalars_va, + edwards_points_va, + 513, // above maximum vector length + result_point_va, + &mut memory_mapping, + ) + .unwrap_err() + .downcast::() + .unwrap(); + + assert_eq!(*result, SyscallError::InvalidLength); + + // test Ristretto + invoke_context.mock_set_remaining(500_000); + let result = SyscallCurveMultiscalarMultiplication::rust( + &mut invoke_context, + CURVE25519_RISTRETTO, + scalars_va, + ristretto_points_va, + 512, // below maximum vector length + result_point_va, + &mut memory_mapping, + ); + + assert_eq!(0, result.unwrap()); + let expected_product = [ + 146, 224, 127, 193, 252, 64, 196, 181, 246, 104, 27, 116, 183, 52, 200, 239, 2, 108, + 21, 27, 97, 44, 95, 65, 26, 218, 223, 39, 197, 132, 51, 49, + ]; + assert_eq!(expected_product, result_point); + + invoke_context.mock_set_remaining(500_000); + let result = SyscallCurveMultiscalarMultiplication::rust( + &mut invoke_context, + CURVE25519_RISTRETTO, + scalars_va, + ristretto_points_va, + 513, // above maximum vector length + result_point_va, + &mut memory_mapping, + ) + .unwrap_err() + .downcast::() + .unwrap(); + + assert_eq!(*result, SyscallError::InvalidLength); + } + fn create_filled_type(zero_init: bool) -> T { let mut val = T::default(); let p = &mut val as *mut _ as *mut u8; diff --git a/sdk/src/feature_set.rs b/sdk/src/feature_set.rs index 6c3a2bfb3b4b8f..f2e9c63ff1b2c9 100644 --- a/sdk/src/feature_set.rs +++ b/sdk/src/feature_set.rs @@ -146,6 +146,10 @@ pub mod curve25519_syscall_enabled { solana_sdk::declare_id!("7rcw5UtqgDTBBv2EcynNfYckgdAaH1MAsCjKgXMkN7Ri"); } +pub mod curve25519_restrict_msm_length { + solana_sdk::declare_id!("eca6zf6JJRjQsYYPkBHF3N32MTzur4n2WL4QiiacPCL"); +} + pub mod versioned_tx_message_enabled { solana_sdk::declare_id!("3KZZ6Ks1885aGBQ45fwRcPXVBCtzUvxhUTkwKMR41Tca"); } @@ -950,6 +954,7 @@ lazy_static! { (disable_bpf_loader_instructions::id(), "disable bpf loader management instructions #34194"), (deprecate_executable_meta_update_in_bpf_loader::id(), "deprecate executable meta flag update in bpf loader #34194"), (enable_zk_proof_from_account::id(), "Enable zk token proof program to read proof from accounts instead of instruction data #34750"), + (curve25519_restrict_msm_length::id(), "restrict curve25519 multiscalar multiplication vector lengths #34763"), /*************** ADD NEW FEATURES HERE ***************/ ] .iter()