Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delete redudent struct from json.rs #273

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions stwo_cairo_prover/crates/prover/src/input/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use cairo_vm::air_public_input::MemorySegmentAddresses;
use mem::Memory;
use serde::{Deserialize, Serialize};
use state_transitions::StateTransitions;

mod decode;
Expand All @@ -19,17 +19,5 @@ pub struct CairoInput {
pub public_mem_addresses: Vec<u32>,

// Builtins.
pub range_check_builtin: SegmentAddrs,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SegmentAddrs {
pub begin_addr: u32,
pub end_addr: u32,
}

impl SegmentAddrs {
pub fn addresses(&self) -> Vec<u32> {
(self.begin_addr..self.end_addr).collect()
}
pub range_check_builtin: MemorySegmentAddresses,
}
7 changes: 4 additions & 3 deletions stwo_cairo_prover/crates/prover/src/input/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use itertools::Itertools;
use super::mem::{MemConfig, MemoryBuilder};
use super::state_transitions::StateTransitions;
use super::vm_import::MemEntry;
use super::{CairoInput, SegmentAddrs};
use super::CairoInput;
use crate::input::MemorySegmentAddresses;

// TODO(Ohad): remove dev_mode after adding the rest of the opcodes.
/// Translates a plain casm into a CairoInput by running the program and extracting the memory and
Expand Down Expand Up @@ -83,9 +84,9 @@ pub fn input_from_finished_runner(runner: CairoRunner, dev_mode: bool) -> CairoI
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: SegmentAddrs {
range_check_builtin: MemorySegmentAddresses {
begin_addr: 24,
end_addr: 64,
stop_ptr: 64,
},
}
}
97 changes: 4 additions & 93 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/json.rs
Original file line number Diff line number Diff line change
@@ -1,99 +1,10 @@
use std::collections::BTreeMap;

use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PublicInput {
pub layout: String,
pub rc_min: u64,
pub rc_max: u64,
pub n_steps: u64,
pub memory_segments: BTreeMap<String, Segment>,
pub public_memory: Vec<PublicMemEntry>,
pub dynamic_params: Option<()>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Segment {
pub begin_addr: u64,
pub stop_ptr: u64,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PublicMemEntry {
pub address: u64,
pub value: FeltValue,
pub page: u64,
}
use serde::{Deserialize, Serialize};

// TODO(Stav): Replace with original struct once fields are public.
/// Struct to store Cairo private input.
/// Replicated from `cairo_vm::air_private_input::AirPrivateInputSerializable`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PrivateInput {
pub trace_path: String,
pub memory_path: String,
pub pedersen: Vec<PedersenData>,
pub range_check: Vec<RangeCheckData>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PedersenData {
pub index: u64,
pub x: FeltValue,
pub y: FeltValue,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RangeCheckData {
pub index: u64,
pub value: FeltValue,
}

#[derive(Clone, Debug)]
pub struct FeltValue([u8; 32]);

impl Serialize for FeltValue {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Convert the [u8; 32] to a hexadecimal string
let hex_string = format!("0x{}", hex::encode(self.0));
serializer.serialize_str(&hex_string)
}
}

impl<'de> Deserialize<'de> for FeltValue {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let hex_string = String::deserialize(deserializer)?;

// Remove the "0x" prefix if present
let hex_str = hex_string.strip_prefix("0x").unwrap_or(&hex_string);
let hex_str = format!("{:0>64}", hex_str);

// Convert the hexadecimal string back into a [u8; 32]
let mut bytes = [0u8; 32];
hex::decode_to_slice(hex_str, &mut bytes).map_err(serde::de::Error::custom)?;

Ok(FeltValue(bytes))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_felt_value_serde() {
let felt_value = FeltValue([0x12; 32]);
let json = sonic_rs::to_string(&felt_value).unwrap();
assert_eq!(
json,
r#""0x1212121212121212121212121212121212121212121212121212121212121212""#
);

let deserialized: FeltValue = sonic_rs::from_str(&json).unwrap();
assert_eq!(felt_value.0, deserialized.0);
}
}
14 changes: 8 additions & 6 deletions stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ use std::io::Read;
use std::path::Path;

use bytemuck::{bytes_of_mut, Pod, Zeroable};
use cairo_vm::air_public_input::PublicInput;
use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry;
use json::{PrivateInput, PublicInput};
use json::PrivateInput;
use thiserror::Error;
use tracing::{span, Level};

use super::mem::MemConfig;
use super::state_transitions::StateTransitions;
use super::CairoInput;
use crate::input::mem::MemoryBuilder;
use crate::input::SegmentAddrs;
use crate::input::MemorySegmentAddresses;

#[derive(Debug, Error)]
pub enum VmImportError {
Expand All @@ -32,7 +33,8 @@ pub fn import_from_vm_output(
dev_mod: bool,
) -> Result<CairoInput, VmImportError> {
let _span = span!(Level::INFO, "import_from_vm_output").entered();
let pub_data: PublicInput = sonic_rs::from_str(&std::fs::read_to_string(pub_json)?)?;
let pub_data_string = std::fs::read_to_string(pub_json)?;
let pub_data: PublicInput<'_> = sonic_rs::from_str(&pub_data_string)?;
let priv_data: PrivateInput = sonic_rs::from_str(&std::fs::read_to_string(priv_json)?)?;

let end_addr = pub_data
Expand Down Expand Up @@ -63,9 +65,9 @@ pub fn import_from_vm_output(
state_transitions,
mem: mem.build(),
public_mem_addresses,
range_check_builtin: SegmentAddrs {
begin_addr: pub_data.memory_segments["range_check"].begin_addr as u32,
end_addr: pub_data.memory_segments["range_check"].stop_ptr as u32,
range_check_builtin: MemorySegmentAddresses {
begin_addr: pub_data.memory_segments["range_check"].begin_addr as usize,
stop_ptr: pub_data.memory_segments["range_check"].stop_ptr as usize,
},
})
}
Expand Down
Loading