Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(avm): cpp msm changes #7056

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,13 @@ std::vector<Row> Execution::gen_trace(std::vector<Instruction> const& instructio
std::get<uint32_t>(inst.operands.at(6)),
std::get<uint32_t>(inst.operands.at(7)));
break;
case OpCode::MSM:
trace_builder.op_variable_msm(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
std::get<uint32_t>(inst.operands.at(2)),
std::get<uint32_t>(inst.operands.at(3)),
std::get<uint32_t>(inst.operands.at(4)));
break;
case OpCode::REVERT:
trace_builder.op_revert(std::get<uint8_t>(inst.operands.at(0)),
std::get<uint32_t>(inst.operands.at(1)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ std::string to_string(OpCode opcode)
return "SHA256";
case OpCode::PEDERSEN:
return "PEDERSEN";
case OpCode::ECADD:
return "ECADD";
case OpCode::MSM:
return "MSM";
case OpCode::TORADIXLE:
return "TORADIXLE";
case OpCode::SHA256COMPRESSION:
Expand Down
251 changes: 251 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#include <vector>

#include "barretenberg/common/throw_or_abort.hpp"
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/vm/avm_trace/avm_common.hpp"
#include "barretenberg/vm/avm_trace/avm_helper.hpp"
#include "barretenberg/vm/avm_trace/avm_opcode.hpp"
Expand Down Expand Up @@ -3632,6 +3634,255 @@ void AvmTraceBuilder::op_ec_add(uint8_t indirect,
FF(internal_return_ptr),
{ result.is_point_at_infinity() });
}

// This function is a bit overloaded with logic around reconstructing points and scalars that could probably be moved to
// the gadget at some stage (although this is another temporary gadget..)
void AvmTraceBuilder::op_variable_msm(uint8_t indirect,
uint32_t points_offset,
uint32_t scalars_offset,
uint32_t output_offset,
uint32_t point_length_offset)
{
auto clk = static_cast<uint32_t>(main_trace.size()) + 1;
// This will all get refactored as part of the indirection refactor
bool tag_match = true;
uint32_t direct_points_offset = points_offset;
uint32_t direct_scalars_offset = scalars_offset;
uint32_t direct_output_offset = output_offset;
// Resolve the indirects
bool indirect_points_flag = is_operand_indirect(indirect, 0);
bool indirect_scalars_flag = is_operand_indirect(indirect, 1);
bool indirect_output_flag = is_operand_indirect(indirect, 2);

// Read in the points first
if (indirect_points_flag) {
auto read_ind_a =
mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, points_offset);
direct_points_offset = uint32_t(read_ind_a.val);
tag_match = tag_match && read_ind_a.tag_match;
}

auto read_points = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, direct_points_offset, AvmMemoryTag::FF, AvmMemoryTag::U0);

// Read in the scalars
if (indirect_scalars_flag) {
auto read_ind_b = mem_trace_builder.indirect_read_and_load_from_memory(
call_ptr, clk, IndirectRegister::IND_B, scalars_offset);
direct_scalars_offset = uint32_t(read_ind_b.val);
tag_match = tag_match && read_ind_b.tag_match;
}
auto read_scalars = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, direct_scalars_offset, AvmMemoryTag::FF, AvmMemoryTag::U0);

// In the refactor we will have the read_slice function handle indirects as well
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = read_points.val,
.avm_main_ib = read_scalars.val,
.avm_main_ind_a = indirect_points_flag ? FF(points_offset) : FF(0),
.avm_main_ind_b = indirect_scalars_flag ? FF(scalars_offset) : FF(0),
.avm_main_ind_op_a = FF(static_cast<uint32_t>(indirect_points_flag)),
.avm_main_ind_op_b = FF(static_cast<uint32_t>(indirect_scalars_flag)),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_points_offset),
.avm_main_mem_idx_b = FF(direct_scalars_offset),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_pc = FF(pc++),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::FF)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!tag_match)),
});
clk++;

// Read the points length (different row since it has a different memory tag)
auto points_length_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, point_length_offset, AvmMemoryTag::U32, AvmMemoryTag::U0);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = points_length_read.val,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(point_length_offset),
.avm_main_mem_op_a = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U32)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!points_length_read.tag_match)),
});
clk++;

// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] with the types [FF, FF, U8, FF, FF, U8, ...]
uint32_t num_points = uint32_t(points_length_read.val) / 3; // 3 elements per point
// We need to split up the reads due to the memory tags,
std::vector<FF> points_coords_vec;
std::vector<FF> points_inf_vec;
std::vector<FF> scalars_vec;
// Read the coordinates first, +2 since we read 2 points per row
for (uint32_t i = 0; i < num_points; i += 2) {
// We can read up to 4 coordinates per row (x1,y1,x2,y2)
// Each pair of coordinates are separated by 3 memory addressess
auto point_x1_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, direct_points_offset + i * 3, AvmMemoryTag::FF, AvmMemoryTag::U0);
auto point_y1_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, direct_points_offset + i * 3 + 1, AvmMemoryTag::FF, AvmMemoryTag::U0);
auto point_x2_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IC, direct_points_offset + (i + 1) * 3, AvmMemoryTag::FF, AvmMemoryTag::U0);
auto point_y2_read = mem_trace_builder.read_and_load_from_memory(call_ptr,
clk,
IntermRegister::ID,
direct_points_offset + (i + 1) * 3 + 1,
AvmMemoryTag::FF,
AvmMemoryTag::U0);
bool tag_match =
point_x1_read.tag_match && point_y1_read.tag_match && point_x2_read.tag_match && point_y2_read.tag_match;
points_coords_vec.insert(points_coords_vec.end(),
{ point_x1_read.val, point_y1_read.val, point_x2_read.val, point_y2_read.val });
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = point_x1_read.val,
.avm_main_ib = point_y1_read.val,
.avm_main_ic = point_x2_read.val,
.avm_main_id = point_y2_read.val,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_points_offset + i * 3),
.avm_main_mem_idx_b = FF(direct_points_offset + i * 3 + 1),
.avm_main_mem_idx_c = FF(direct_points_offset + (i + 1) * 3),
.avm_main_mem_idx_d = FF(direct_points_offset + (i + 1) * 3 + 1),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_mem_op_c = FF(1),
.avm_main_mem_op_d = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::FF)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!tag_match)),
});
clk++;
}
// Read the Infinities flags, +4 since we read 4 points row
for (uint32_t i = 0; i < num_points; i += 4) {
// We can read up to 4 infinities per row
// Each infinity flag is separated by 3 memory addressess
uint32_t offset = direct_points_offset + i * 3 + 2;
auto point_inf1_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IA, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);
offset += 3;

auto point_inf2_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IB, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);
offset += 3;

auto point_inf3_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::IC, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);
offset += 3;

auto point_inf4_read = mem_trace_builder.read_and_load_from_memory(
call_ptr, clk, IntermRegister::ID, offset, AvmMemoryTag::U8, AvmMemoryTag::U0);

points_inf_vec.insert(points_inf_vec.end(),
{ point_inf1_read.val, point_inf2_read.val, point_inf3_read.val, point_inf4_read.val });
bool tag_match = point_inf1_read.tag_match && point_inf2_read.tag_match && point_inf3_read.tag_match &&
point_inf4_read.tag_match;
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = point_inf1_read.val,
.avm_main_ib = point_inf2_read.val,
.avm_main_ic = point_inf3_read.val,
.avm_main_id = point_inf4_read.val,
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_points_offset + i * 3 + 2),
.avm_main_mem_idx_b = FF(direct_points_offset + (i + 1) * 3 + 2),
.avm_main_mem_idx_c = FF(direct_points_offset + (i + 2) * 3 + 2),
.avm_main_mem_idx_d = FF(direct_points_offset + (i + 3) * 3 + 2),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_mem_op_c = FF(1),
.avm_main_mem_op_d = FF(1),
.avm_main_pc = FF(pc),
.avm_main_r_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
.avm_main_tag_err = FF(static_cast<uint32_t>(!tag_match)),
});
clk++;
}
// Scalar read length is num_points* 2 since scalars are stored as lo and hi limbs
uint32_t scalar_read_length = num_points * 2;
auto num_scalar_rows = read_slice_to_memory(call_ptr,
clk,
direct_scalars_offset,
AvmMemoryTag::FF,
AvmMemoryTag::U0,
FF(internal_return_ptr),
scalar_read_length,
scalars_vec);
clk += num_scalar_rows;
// Reconstruct Grumpkin points
std::vector<grumpkin::g1::affine_element> points;
for (size_t i = 0; i < num_points; i++) {
grumpkin::g1::Fq x = points_coords_vec[i * 2];
grumpkin::g1::Fq y = points_coords_vec[i * 2 + 1];
bool is_inf = points_inf_vec[i] == 1;
if (is_inf) {
points.emplace_back(grumpkin::g1::affine_element::infinity());
} else {
points.emplace_back(x, y);
}
}
// Reconstruct Grumpkin scalars
// Scalars are stored as [lo1, hi1, lo2, hi2, ...] with the types [FF, FF, FF, FF, ...]
std::vector<grumpkin::fr> scalars;
for (size_t i = 0; i < num_points; i++) {
FF lo = scalars_vec[i * 2];
FF hi = scalars_vec[i * 2 + 1];
// hi is shifted 128 bits
uint256_t scalar = (uint256_t(hi) << 128) + uint256_t(lo);
scalars.emplace_back(scalar);
}
// Perform the variable MSM - could just put the logic in here since there are no constraints.
auto result = ecc_trace_builder.variable_msm(points, scalars, clk);
// Write the result back to memory [x, y, inf] with tags [FF, FF, U8]
if (indirect_output_flag) {
auto read_ind_a =
mem_trace_builder.indirect_read_and_load_from_memory(call_ptr, clk, IndirectRegister::IND_A, output_offset);
direct_output_offset = uint32_t(read_ind_a.val);
}
mem_trace_builder.write_into_memory(
call_ptr, clk, IntermRegister::IA, direct_output_offset, result.x, AvmMemoryTag::U0, AvmMemoryTag::FF);
mem_trace_builder.write_into_memory(
call_ptr, clk, IntermRegister::IB, direct_output_offset + 1, result.y, AvmMemoryTag::U0, AvmMemoryTag::FF);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = result.x,
.avm_main_ib = result.y,
.avm_main_ind_a = indirect_output_flag ? FF(output_offset) : FF(0),
.avm_main_ind_op_a = FF(static_cast<uint32_t>(indirect_output_flag)),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_output_offset),
.avm_main_mem_idx_b = FF(direct_output_offset + 1),
.avm_main_mem_op_a = FF(1),
.avm_main_mem_op_b = FF(1),
.avm_main_pc = FF(pc),
.avm_main_rwa = FF(1),
.avm_main_rwb = FF(1),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::FF)),
});
clk++;
// Write the infinity
mem_trace_builder.write_into_memory(call_ptr,
clk,
IntermRegister::IA,
direct_output_offset + 2,
result.is_point_at_infinity(),
AvmMemoryTag::U0,
AvmMemoryTag::U8);
main_trace.push_back(Row{
.avm_main_clk = clk,
.avm_main_ia = static_cast<uint8_t>(result.is_point_at_infinity()),
.avm_main_internal_return_ptr = FF(internal_return_ptr),
.avm_main_mem_idx_a = FF(direct_output_offset + 2),
.avm_main_mem_op_a = FF(1),
.avm_main_pc = FF(pc),
.avm_main_rwa = FF(1),
.avm_main_w_in_tag = FF(static_cast<uint32_t>(AvmMemoryTag::U8)),
});
}
// Finalise Lookup Counts
//
// For log derivative lookups, we require a column that contains the number of times each lookup is consumed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ class AvmTraceBuilder {
uint32_t rhs_y_offset,
uint32_t rhs_is_inf_offset,
uint32_t output_offset);
void op_variable_msm(uint8_t indirect,
uint32_t points_offset,
uint32_t scalars_offset,
uint32_t output_offset,
uint32_t point_length_offset);

private:
// Used for the standard indirect address resolution of three operands opcode.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,21 @@ element AvmEccTraceBuilder::embedded_curve_add(element lhs, element rhs, uint32_
return result;
}

element AvmEccTraceBuilder::variable_msm(const std::vector<element>& points,
const std::vector<grumpkin::fr>& scalars,
uint32_t clk)
{
// Replace this with pippenger if/when we have the time
auto result = grumpkin::g1::affine_point_at_infinity;
for (size_t i = 0; i < points.size(); ++i) {
result = result + points[i] * scalars[i];
}

std::tuple<FF, FF, bool> result_tuple = { result.x, result.y, result.is_point_at_infinity() };

ecc_trace.push_back({ .clk = clk, .result = result_tuple });

return result;
}

} // namespace bb::avm_trace
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class AvmEccTraceBuilder {
public:
struct EccTraceEntry {
uint32_t clk = 0;
std::tuple<FF, FF, bool> p1; // x, y, is_infinity
std::tuple<FF, FF, bool> p2;
std::tuple<FF, FF, bool> result;
std::tuple<FF, FF, bool> p1 = { FF(0), FF(0), true }; // x, y, is_infinity
std::tuple<FF, FF, bool> p2 = { FF(0), FF(0), true };
std::tuple<FF, FF, bool> result = { FF(0), FF(0), true };
};

AvmEccTraceBuilder();
Expand All @@ -21,6 +21,9 @@ class AvmEccTraceBuilder {
grumpkin::g1::affine_element embedded_curve_add(grumpkin::g1::affine_element lhs,
grumpkin::g1::affine_element rhs,
uint32_t clk);
grumpkin::g1::affine_element variable_msm(const std::vector<grumpkin::g1::affine_element>& points,
const std::vector<grumpkin::fr>& scalars,
uint32_t clk);

private:
std::vector<EccTraceEntry> ecc_trace;
Expand Down
Loading
Loading