Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the shift ALU table into left and right #62

Merged
merged 5 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 27 additions & 45 deletions core/src/alu/shift/mod.rs → core/src/alu/shift/left.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ use crate::disassembler::WORD_SIZE;
use crate::runtime::{Opcode, Segment};
use crate::utils::{pad_to_power_of_two, Chip};

pub const NUM_SHIFT_COLS: usize = size_of::<ShiftCols<u8>>();
pub const NUM_LEFT_SHIFT_COLS: usize = size_of::<LeftShiftCols<u8>>();

pub const BYTE_SIZE: usize = 8;

/// The column layout for the chip.
#[derive(AlignedBorrow, Default, Debug)]
#[repr(C)]
pub struct ShiftCols<T> {
pub struct LeftShiftCols<T> {
/// The output operand.
pub a: Word<T>,

Expand Down Expand Up @@ -83,41 +83,33 @@ pub struct ShiftCols<T> {
/// A boolean array whose `i`th element indicates whether `num_bytes_to_shift = i`.
pub shift_by_n_bytes: [T; WORD_SIZE],

/// Selector flags for the operation to perform.
pub is_sll: T,
pub is_srl: T,
pub is_sra: T,

pub is_real: T,
}

/// A chip that implements bitwise operations for the opcodes SLL, SLLI, SRL, SRLI, SRA, and SRAI.
pub struct ShiftChip;
/// A chip that implements bitwise operations for the opcodes SLL and SLLI.
pub struct LeftShiftChip;

impl ShiftChip {
impl LeftShiftChip {
pub fn new() -> Self {
Self {}
}
}

impl<F: PrimeField> Chip<F> for ShiftChip {
impl<F: PrimeField> Chip<F> for LeftShiftChip {
fn generate_trace(&self, segment: &mut Segment) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let rows = segment
.shift_events
.left_shift_events
.par_iter()
.map(|event| {
let mut row = [F::zero(); NUM_SHIFT_COLS];
let cols: &mut ShiftCols<F> = unsafe { transmute(&mut row) };
let mut row = [F::zero(); NUM_LEFT_SHIFT_COLS];
let cols: &mut LeftShiftCols<F> = unsafe { transmute(&mut row) };
let a = event.a.to_le_bytes();
let b = event.b.to_le_bytes();
let c = event.c.to_le_bytes();
cols.a = Word(a.map(F::from_canonical_u8));
cols.b = Word(b.map(F::from_canonical_u8));
cols.c = Word(c.map(F::from_canonical_u8));
cols.is_sll = F::from_bool(event.opcode == Opcode::SLL);
cols.is_srl = F::from_bool(event.opcode == Opcode::SRL);
cols.is_sra = F::from_bool(event.opcode == Opcode::SRA);
cols.is_real = F::one();
for i in 0..BYTE_SIZE {
cols.c_least_sig_byte[i] = F::from_canonical_u32((event.c >> i) & 1);
Expand Down Expand Up @@ -162,45 +154,44 @@ impl<F: PrimeField> Chip<F> for ShiftChip {
// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
NUM_SHIFT_COLS,
NUM_LEFT_SHIFT_COLS,
);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_SHIFT_COLS, F>(&mut trace.values);
pad_to_power_of_two::<NUM_LEFT_SHIFT_COLS, F>(&mut trace.values);

// Create the template for the padded rows. These are fake rows that don't fail on some
// sanity checks.
let padded_row_template = {
let mut row = [F::zero(); NUM_SHIFT_COLS];
let cols: &mut ShiftCols<F> = unsafe { transmute(&mut row) };
cols.is_sll = F::one();
let mut row = [F::zero(); NUM_LEFT_SHIFT_COLS];
let cols: &mut LeftShiftCols<F> = unsafe { transmute(&mut row) };
cols.shift_by_n_bits[0] = F::one();
cols.shift_by_n_bytes[0] = F::one();
cols.bit_shift_multiplier = F::one();
row
};
debug_assert!(padded_row_template.len() == NUM_SHIFT_COLS);
for i in segment.shift_events.len() * NUM_SHIFT_COLS..trace.values.len() {
trace.values[i] = padded_row_template[i % NUM_SHIFT_COLS];
debug_assert!(padded_row_template.len() == NUM_LEFT_SHIFT_COLS);
for i in segment.left_shift_events.len() * NUM_LEFT_SHIFT_COLS..trace.values.len() {
trace.values[i] = padded_row_template[i % NUM_LEFT_SHIFT_COLS];
}

trace
}
}

impl<F> BaseAir<F> for ShiftChip {
impl<F> BaseAir<F> for LeftShiftChip {
fn width(&self) -> usize {
NUM_SHIFT_COLS
NUM_LEFT_SHIFT_COLS
}
}

impl<AB> Air<AB> for ShiftChip
impl<AB> Air<AB> for LeftShiftChip
where
AB: CurtaAirBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &ShiftCols<AB::Var> = main.row_slice(0).borrow();
let local: &LeftShiftCols<AB::Var> = main.row_slice(0).borrow();

let zero: AB::Expr = AB::F::zero().into();
let one: AB::Expr = AB::F::one().into();
Expand Down Expand Up @@ -318,24 +309,15 @@ where
one.clone(),
);

builder.assert_bool(local.is_sll);
builder.assert_bool(local.is_srl);
builder.assert_bool(local.is_sra);

// Exactly one of them must be true.
builder.assert_eq(local.is_sll + local.is_srl + local.is_sra, one.clone());

builder.assert_bool(local.is_real);

// Receive the arguments.
builder.receive_alu(
local.is_sll * AB::F::from_canonical_u32(Opcode::SLL as u32)
+ local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32)
+ local.is_sra * AB::F::from_canonical_u32(Opcode::SRA as u32),
AB::F::from_canonical_u32(Opcode::SLL as u32),
local.a,
local.b,
local.c,
local.is_sll + local.is_srl + local.is_sra,
local.is_real,
);

// A dummy constraint to keep the degree at least 3.
Expand Down Expand Up @@ -371,13 +353,13 @@ mod tests {
};
use p3_commit::ExtensionMmcs;

use super::ShiftChip;
use super::LeftShiftChip;

#[test]
fn generate_trace() {
let mut segment = Segment::default();
segment.shift_events = vec![AluEvent::new(0, Opcode::SLL, 16, 8, 1)];
let chip = ShiftChip::new();
segment.left_shift_events = vec![AluEvent::new(0, Opcode::SLL, 16, 8, 1)];
let chip = LeftShiftChip::new();
let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&mut segment);
println!("{:?}", trace.values)
}
Expand Down Expand Up @@ -456,8 +438,8 @@ mod tests {
}

let mut segment = Segment::default();
segment.shift_events = shift_events;
let chip = ShiftChip::new();
segment.left_shift_events = shift_events;
let chip = LeftShiftChip::new();
let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&mut segment);
let proof = prove::<MyConfig, _>(&config, &chip, &mut challenger, trace);

Expand Down
198 changes: 198 additions & 0 deletions core/src/alu/shift/right.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use core::mem::transmute;
use p3_air::{Air, BaseAir};

use p3_field::AbstractField;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use valida_derive::AlignedBorrow;

use crate::air::{CurtaAirBuilder, Word};

use crate::runtime::{Opcode, Segment};
use crate::utils::{pad_to_power_of_two, Chip};

pub const RIGHT_NUM_SHIFT_COLS: usize = size_of::<RightShiftCols<u8>>();

/// The column layout for the chip.
#[derive(AlignedBorrow, Default)]
pub struct RightShiftCols<T> {
/// The output operand.
pub a: Word<T>,

/// The first input operand.
pub b: Word<T>,

/// The second input operand.
pub c: Word<T>,

/// Selector flags for the operation to perform.
pub is_srl: T,
pub is_sra: T,
}

/// A chip that implements bitwise operations for the opcodes SRL, SRLI, SRA, and SRAI.
pub struct RightShiftChip;

impl RightShiftChip {
pub fn new() -> Self {
Self {}
}
}

impl<F: PrimeField> Chip<F> for RightShiftChip {
fn generate_trace(&self, segment: &mut Segment) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let rows = segment
.right_shift_events
.par_iter()
.map(|event| {
let mut row = [F::zero(); RIGHT_NUM_SHIFT_COLS];
let cols: &mut RightShiftCols<F> = unsafe { transmute(&mut row) };
let a = event.a.to_le_bytes();
let b = event.b.to_le_bytes();
let c = event.c.to_le_bytes();
cols.a = Word(a.map(F::from_canonical_u8));
cols.b = Word(b.map(F::from_canonical_u8));
cols.c = Word(c.map(F::from_canonical_u8));
cols.is_srl = F::from_bool(event.opcode == Opcode::SRL);
cols.is_sra = F::from_bool(event.opcode == Opcode::SRA);
row
})
.collect::<Vec<_>>();

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
RIGHT_NUM_SHIFT_COLS,
);

// Pad the trace to a power of two.
pad_to_power_of_two::<RIGHT_NUM_SHIFT_COLS, F>(&mut trace.values);

trace
}
}

impl<F> BaseAir<F> for RightShiftChip {
fn width(&self) -> usize {
RIGHT_NUM_SHIFT_COLS
}
}

impl<AB> Air<AB> for RightShiftChip
where
AB: CurtaAirBuilder,
{
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local: &RightShiftCols<AB::Var> = main.row_slice(0).borrow();

builder.assert_zero(
local.a[0] * local.b[0] * local.c[0] - local.a[0] * local.b[0] * local.c[0],
);

// Receive the arguments.
builder.receive_alu(
local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32)
+ local.is_sra * AB::F::from_canonical_u32(Opcode::SRA as u32),
local.a,
local.b,
local.c,
local.is_srl + local.is_sra,
);
}
}

#[cfg(test)]
mod tests {
use p3_challenger::DuplexChallenger;
use p3_dft::Radix2DitParallel;
use p3_field::Field;

use p3_baby_bear::BabyBear;
use p3_field::extension::BinomialExtensionField;
use p3_fri::{FriBasedPcs, FriConfigImpl, FriLdt};
use p3_keccak::Keccak256Hash;
use p3_ldt::QuotientMmcs;
use p3_matrix::dense::RowMajorMatrix;
use p3_mds::coset_mds::CosetMds;
use p3_merkle_tree::FieldMerkleTreeMmcs;
use p3_poseidon2::{DiffusionMatrixBabybear, Poseidon2};
use p3_symmetric::{CompressionFunctionFromHasher, SerializingHasher32};
use p3_uni_stark::{prove, verify, StarkConfigImpl};
use rand::thread_rng;

use crate::{
alu::AluEvent,
runtime::{Opcode, Segment},
utils::Chip,
};
use p3_commit::ExtensionMmcs;

use super::RightShiftChip;

#[test]
fn generate_trace() {
let mut segment = Segment::default();
segment.right_shift_events = vec![AluEvent::new(0, Opcode::SRL, 6, 12, 1)];
let chip = RightShiftChip::new();
let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&mut segment);
println!("{:?}", trace.values)
}

#[test]
fn prove_babybear() {
type Val = BabyBear;
type Domain = Val;
type Challenge = BinomialExtensionField<Val, 4>;
type PackedChallenge = BinomialExtensionField<<Domain as Field>::Packing, 4>;

type MyMds = CosetMds<Val, 16>;
let mds = MyMds::default();

type Perm = Poseidon2<Val, MyMds, DiffusionMatrixBabybear, 16, 5>;
let perm = Perm::new_from_rng(8, 22, mds, DiffusionMatrixBabybear, &mut thread_rng());

type MyHash = SerializingHasher32<Keccak256Hash>;
let hash = MyHash::new(Keccak256Hash {});

type MyCompress = CompressionFunctionFromHasher<Val, MyHash, 2, 8>;
let compress = MyCompress::new(hash);

type ValMmcs = FieldMerkleTreeMmcs<Val, MyHash, MyCompress, 8>;
let val_mmcs = ValMmcs::new(hash, compress);

type ChallengeMmcs = ExtensionMmcs<Val, Challenge, ValMmcs>;
let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone());

type Dft = Radix2DitParallel;
let dft = Dft {};

type Challenger = DuplexChallenger<Val, Perm, 16>;

type Quotient = QuotientMmcs<Domain, Challenge, ValMmcs>;
type MyFriConfig = FriConfigImpl<Val, Challenge, Quotient, ChallengeMmcs, Challenger>;
let fri_config = MyFriConfig::new(40, challenge_mmcs);
let ldt = FriLdt { config: fri_config };

type Pcs = FriBasedPcs<MyFriConfig, ValMmcs, Dft, Challenger>;
type MyConfig = StarkConfigImpl<Val, Challenge, PackedChallenge, Pcs, Challenger>;

let pcs = Pcs::new(dft, val_mmcs, ldt);
let config = StarkConfigImpl::new(pcs);
let mut challenger = Challenger::new(perm.clone());

let mut segment = Segment::default();
segment.right_shift_events = vec![AluEvent::new(0, Opcode::SRL, 6, 12, 1)].repeat(1000);
let chip = RightShiftChip::new();
let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&mut segment);
let proof = prove::<MyConfig, _>(&config, &chip, &mut challenger, trace);

let mut challenger = Challenger::new(perm);
verify(&config, &chip, &mut challenger, &proof).unwrap();
}
}
Loading