Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[zk-token-sdk] Restrict Edwards and Ristretto multiscalar multiplication vector length to at most 512 #34763

Merged
merged 6 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions programs/bpf_loader/src/syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,17 @@ declare_builtin_function!(
use solana_zk_token_sdk::curve25519::{
curve_syscall_traits::*, edwards, ristretto, scalar,
};

let restrict_msm_length = invoke_context
.feature_set
.is_active(&feature_set::curve25519_restrict_msm_length::id());
#[allow(clippy::collapsible_if)]
if restrict_msm_length {
if points_len > 512 {
return Err(Box::new(SyscallError::InvalidLength));
}
}

match curve_id {
CURVE25519_EDWARDS => {
let cost = invoke_context
Expand Down Expand Up @@ -3146,6 +3157,122 @@ mod tests {
assert_eq!(expected_product, result_point);
}

#[test]
fn test_syscall_multiscalar_multiplication_maximum_length_exceeded() {
use solana_zk_token_sdk::curve25519::curve_syscall_traits::{
CURVE25519_EDWARDS, CURVE25519_RISTRETTO,
};

let config = Config::default();
prepare_mockup!(invoke_context, program_id, bpf_loader::id());

let scalar: [u8; 32] = [
254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250,
78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6,
];
let scalars = [scalar; 513];
let scalars_va = 0x100000000;

let edwards_point: [u8; 32] = [
252, 31, 230, 46, 173, 95, 144, 148, 158, 157, 63, 10, 8, 68, 58, 176, 142, 192, 168,
53, 61, 105, 194, 166, 43, 56, 246, 236, 28, 146, 114, 133,
];
let edwards_points = [edwards_point; 513];
let edwards_points_va = 0x200000000;

let ristretto_point: [u8; 32] = [
130, 35, 97, 25, 18, 199, 33, 239, 85, 143, 119, 111, 49, 51, 224, 40, 167, 185, 240,
179, 25, 194, 213, 41, 14, 155, 104, 18, 181, 197, 15, 112,
];
let ristretto_points = [ristretto_point; 513];
let ristretto_points_va = 0x300000000;

let mut result_point: [u8; 32] = [0; 32];
let result_point_va = 0x400000000;

let mut memory_mapping = MemoryMapping::new(
vec![
MemoryRegion::new_readonly(bytes_of_slice(&scalars), scalars_va),
MemoryRegion::new_readonly(bytes_of_slice(&edwards_points), edwards_points_va),
MemoryRegion::new_readonly(bytes_of_slice(&ristretto_points), ristretto_points_va),
MemoryRegion::new_writable(bytes_of_slice_mut(&mut result_point), result_point_va),
],
&config,
&SBPFVersion::V2,
)
.unwrap();

// test Edwards
invoke_context.mock_set_remaining(500_000);
let result = SyscallCurveMultiscalarMultiplication::rust(
&mut invoke_context,
CURVE25519_EDWARDS,
scalars_va,
edwards_points_va,
512, // below maximum vector length
result_point_va,
&mut memory_mapping,
);

assert_eq!(0, result.unwrap());
let expected_product = [
20, 146, 226, 37, 22, 61, 86, 249, 208, 40, 38, 11, 126, 101, 10, 82, 81, 77, 88, 209,
15, 76, 82, 251, 180, 133, 84, 243, 162, 0, 11, 145,
];
assert_eq!(expected_product, result_point);

invoke_context.mock_set_remaining(500_000);
let result = SyscallCurveMultiscalarMultiplication::rust(
&mut invoke_context,
CURVE25519_EDWARDS,
scalars_va,
edwards_points_va,
513, // above maximum vector length
result_point_va,
&mut memory_mapping,
)
.unwrap_err()
.downcast::<SyscallError>()
.unwrap();

assert_eq!(*result, SyscallError::InvalidLength);

// test Ristretto
invoke_context.mock_set_remaining(500_000);
let result = SyscallCurveMultiscalarMultiplication::rust(
&mut invoke_context,
CURVE25519_RISTRETTO,
scalars_va,
ristretto_points_va,
512, // below maximum vector length
result_point_va,
&mut memory_mapping,
);

assert_eq!(0, result.unwrap());
let expected_product = [
146, 224, 127, 193, 252, 64, 196, 181, 246, 104, 27, 116, 183, 52, 200, 239, 2, 108,
21, 27, 97, 44, 95, 65, 26, 218, 223, 39, 197, 132, 51, 49,
];
assert_eq!(expected_product, result_point);

invoke_context.mock_set_remaining(500_000);
let result = SyscallCurveMultiscalarMultiplication::rust(
&mut invoke_context,
CURVE25519_RISTRETTO,
scalars_va,
ristretto_points_va,
513, // above maximum vector length
result_point_va,
&mut memory_mapping,
)
.unwrap_err()
.downcast::<SyscallError>()
.unwrap();

assert_eq!(*result, SyscallError::InvalidLength);
}

fn create_filled_type<T: Default>(zero_init: bool) -> T {
let mut val = T::default();
let p = &mut val as *mut _ as *mut u8;
Expand Down
5 changes: 5 additions & 0 deletions sdk/src/feature_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ pub mod curve25519_syscall_enabled {
solana_sdk::declare_id!("7rcw5UtqgDTBBv2EcynNfYckgdAaH1MAsCjKgXMkN7Ri");
}

pub mod curve25519_restrict_msm_length {
solana_sdk::declare_id!("eca6zf6JJRjQsYYPkBHF3N32MTzur4n2WL4QiiacPCL");
t-nelson marked this conversation as resolved.
Show resolved Hide resolved
}

pub mod versioned_tx_message_enabled {
solana_sdk::declare_id!("3KZZ6Ks1885aGBQ45fwRcPXVBCtzUvxhUTkwKMR41Tca");
}
Expand Down Expand Up @@ -950,6 +954,7 @@ lazy_static! {
(disable_bpf_loader_instructions::id(), "disable bpf loader management instructions #34194"),
(deprecate_executable_meta_update_in_bpf_loader::id(), "deprecate executable meta flag update in bpf loader #34194"),
(enable_zk_proof_from_account::id(), "Enable zk token proof program to read proof from accounts instead of instruction data #34750"),
(curve25519_restrict_msm_length::id(), "restrict curve25519 multiscalar multiplication vector lengths #34763"),
/*************** ADD NEW FEATURES HERE ***************/
]
.iter()
Expand Down