From 9da1624160055ac39596a0c274f8463bce1671df Mon Sep 17 00:00:00 2001 From: akokoshn Date: Sat, 18 Nov 2023 19:51:40 +0200 Subject: [PATCH] Keep lookup constant/selector columns --- .../blueprint/blueprint/plonk/assignment.hpp | 45 +++++++ .../blueprint/plonk/assignment_proxy.hpp | 125 ++++++++++++++---- .../blueprint/plonk/circuit_proxy.hpp | 53 +++++++- test/proxy.cpp | 71 ++-------- test/test_plonk_component.hpp | 2 +- 5 files changed, 204 insertions(+), 92 deletions(-) diff --git a/include/nil/blueprint/blueprint/plonk/assignment.hpp b/include/nil/blueprint/blueprint/plonk/assignment.hpp index e42b1168d..54593de1c 100644 --- a/include/nil/blueprint/blueprint/plonk/assignment.hpp +++ b/include/nil/blueprint/blueprint/plonk/assignment.hpp @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -69,6 +70,8 @@ namespace nil { std::uint32_t assignment_allocated_rows = 0; std::vector assignment_private_storage; shared_container_type shared_storage; // results of the previously prover + std::set lookup_constant_cols; + std::set lookup_selector_cols; public: static constexpr const std::size_t private_storage_index = std::numeric_limits::max(); @@ -131,6 +134,19 @@ namespace nil { } } + void fill_selector(std::uint32_t index, const column_type& column) override { + lookup_selector_cols.insert(index); + zk_type::fill_selector(index, column); + } + + virtual const std::set& get_lookup_selector_cols() const { + return lookup_selector_cols; + } + + virtual std::uint32_t get_lookup_selector_amount() const { + return lookup_selector_cols.size(); + } + virtual value_type &shared(std::uint32_t shared_index, std::uint32_t row_index) { if (shared_storage[shared_index].size() <= row_index) { shared_storage[shared_index].resize(row_index + 1); @@ -151,6 +167,10 @@ namespace nil { return shared_storage.size(); } + virtual const column_type& shared(std::uint32_t index) const { + return shared_storage[index]; + } + virtual value_type &witness(std::uint32_t witness_index, std::uint32_t row_index) { BLUEPRINT_ASSERT(witness_index < ArithmetizationParams::WitnessColumns); @@ -176,6 +196,10 @@ namespace nil { return zk_type::witnesses_amount(); } + virtual const column_type& witness(std::uint32_t index) const { + return zk_type::witness(index); + } + virtual value_type &public_input( std::uint32_t public_input_index, std::uint32_t row_index) { @@ -204,6 +228,10 @@ namespace nil { return zk_type::public_inputs_amount(); } + virtual const column_type& public_input(std::uint32_t index) const { + return zk_type::public_input(index); + } + virtual value_type &constant( std::uint32_t constant_index, std::uint32_t row_index) { @@ -225,6 +253,23 @@ namespace nil { return zk_type::constant(constant_index)[row_index]; } + virtual const column_type& constant(std::uint32_t index) const { + return zk_type::constant(index); + } + + void fill_constant(std::uint32_t index, const column_type& column) override { + lookup_constant_cols.insert(index); + zk_type::fill_constant(index, column); + } + + virtual const std::set& get_lookup_constant_cols() const { + return lookup_constant_cols; + } + + virtual std::uint32_t get_lookup_constant_amount() const { + return lookup_constant_cols.size(); + } + virtual std::uint32_t constant_column_size(std::uint32_t col_idx) const { return this->_public_table.constant_column_size(col_idx); } diff --git a/include/nil/blueprint/blueprint/plonk/assignment_proxy.hpp b/include/nil/blueprint/blueprint/plonk/assignment_proxy.hpp index 12dede8ca..224dbda7a 100644 --- a/include/nil/blueprint/blueprint/plonk/assignment_proxy.hpp +++ b/include/nil/blueprint/blueprint/plonk/assignment_proxy.hpp @@ -66,6 +66,10 @@ namespace nil { return *assignment_ptr; } + assignment& get() { + return *assignment_ptr; + } + std::uint32_t get_id() const { return id; } @@ -92,9 +96,13 @@ namespace nil { } value_type selector(std::size_t selector_index, std::uint32_t row_index) const override { - if (check && used_rows.find(row_index) == used_rows.end()) { - std::cout << id << ": Not found selector " << selector_index << " on row " << row_index << std::endl; - BLUEPRINT_ASSERT(false); + if (check) { + const auto lookup_selector_cols = assignment_ptr->get_lookup_selector_cols(); + if (lookup_selector_cols.find(selector_index) == lookup_selector_cols.end() && + used_rows.find(row_index) == used_rows.end()) { + std::cout << id << ": Not found selector " << selector_index << " on row " << row_index << std::endl; + BLUEPRINT_ASSERT(false); + } } return std::const_pointer_cast>(assignment_ptr)->selector(selector_index, row_index); } @@ -126,12 +134,17 @@ namespace nil { } void fill_selector(std::uint32_t index, const column_type& column) override { - for (std::uint32_t i = 0; i < column.size(); i++) { - used_rows.insert(i); - } assignment_ptr->fill_selector(index, column); } + const std::set& get_lookup_selector_cols() const override { + return assignment_ptr->get_lookup_selector_cols(); + } + + std::uint32_t get_lookup_selector_amount() const override { + return assignment_ptr->get_lookup_selector_amount(); + } + value_type &shared(std::uint32_t shared_index, std::uint32_t row_index) override { return assignment_ptr->shared(shared_index, row_index); } @@ -148,6 +161,10 @@ namespace nil { return assignment_ptr->shared_column_size(index); } + const column_type& shared(std::uint32_t index) const override { + return assignment_ptr->shared(index); + } + std::uint32_t shareds_amount() const override { return assignment_ptr->shareds_amount(); } @@ -165,6 +182,10 @@ namespace nil { return std::const_pointer_cast>(assignment_ptr)->witness(witness_index, row_index); } + const column_type& witness(std::uint32_t index) const override { + return assignment_ptr->witness(index); + } + std::uint32_t witnesses_amount() const override { return assignment_ptr->witnesses_amount(); } @@ -183,6 +204,10 @@ namespace nil { return std::const_pointer_cast>(assignment_ptr)->public_input(public_input_index, row_index); } + const column_type& public_input(std::uint32_t index) const override { + return assignment_ptr->public_input(index); + } + std::uint32_t public_inputs_amount() const override { return assignment_ptr->public_inputs_amount(); } @@ -198,20 +223,17 @@ namespace nil { } value_type constant(std::uint32_t constant_index, std::uint32_t row_index) const override { - if (check && used_rows.find(row_index) == used_rows.end()) { - std::cout << id << ": Not found constant " << constant_index << " on row " << row_index << std::endl; - BLUEPRINT_ASSERT(false); + if (check) { + const auto lookup_constant_cols = assignment_ptr->get_lookup_constant_cols(); + if (lookup_constant_cols.find(constant_index) == lookup_constant_cols.end() && + used_rows.find(row_index) == used_rows.end()) { + std::cout << id << ": Not found constant " << constant_index << " on row " << row_index << std::endl; + BLUEPRINT_ASSERT(false); + } } return std::const_pointer_cast>(assignment_ptr)->constant(constant_index, row_index); } - void fill_constant(std::uint32_t index, const column_type& column) override { - for (std::uint32_t i = 0; i < column.size(); i++) { - used_rows.insert(i); - } - assignment_ptr->fill_constant(index, column); - } - std::uint32_t constants_amount() const override { return assignment_ptr->constants_amount(); } @@ -220,6 +242,22 @@ namespace nil { return assignment_ptr->constant_column_size(index); } + void fill_constant(std::uint32_t index, const column_type& column) override { + assignment_ptr->fill_constant(index, column); + } + + const column_type& constant(std::uint32_t index) const override { + return assignment_ptr->constant(index); + } + + const std::set& get_lookup_constant_cols() const override { + return assignment_ptr->get_lookup_constant_cols(); + } + + std::uint32_t get_lookup_constant_amount() const override { + return assignment_ptr->get_lookup_constant_amount(); + } + value_type private_storage(std::uint32_t storage_index) const override { return assignment_ptr->private_storage(storage_index); } @@ -286,11 +324,23 @@ namespace nil { << "max_size: " << max_size << " " << "internal used rows size: " << used_rows.size() << "\n"; - std::cout << "internal used rows: "; + os << "internal used rows: "; for (const auto& it : used_rows) { - std::cout << it << " "; + os << it << " "; + } + os << "\n"; + + os << "lookup constants: "; + for (const auto &it : assignment_ptr->get_lookup_constant_cols()) { + os << it << " "; + } + os << "\n"; + + os << "lookup selectors: "; + for (const auto &it : assignment_ptr->get_lookup_selector_cols()) { + os << it << " "; } - std::cout << "\n"; + os << "\n"; os << std::dec; os << std::hex << std::setfill('0'); @@ -319,13 +369,13 @@ namespace nil { os << "| "; for (std::uint32_t j = 0; j < constants_size; j++) { os << std::setw(width) - << (i < assignment_ptr->constant_column_size(j) && is_used_row ? + << (i < assignment_ptr->constant_column_size(j) ? assignment_ptr->constant(j, i) : 0).data << " "; } os << "| "; // Selectors only need a single bit, so we do not renew the size here for (std::uint32_t j = 0; j < selectors_size - 1; j++) { - os << (i < assignment_ptr->selector_column_size(j) && is_used_row ? + os << (i < assignment_ptr->selector_column_size(j) ? assignment_ptr->selector(j, i) : 0).data << " "; } os << "\n"; @@ -369,15 +419,19 @@ namespace nil { ArithmetizationParams>> &assignments){ using variable_type = crypto3::zk::snark::plonk_variable; - const auto private_rows = assignments.get_used_rows(); + const auto& private_rows = assignments.get_used_rows(); + const auto& lookup_selector_cols = assignments.get_lookup_selector_cols(); const std::vector>> &gates = bp.gates(); - const std::set& used_gates = bp.get_used_gates(); + const auto& used_gates = bp.get_used_gates(); const std::vector> ©_constraints = bp.copy_constraints(); - const std::set& used_copy_constraints = bp.get_used_copy_constraints(); + const auto& used_copy_constraints = bp.get_used_copy_constraints(); + + const auto& lookup_gates = bp.lookup_gates(); + const auto& used_lookup_gates = bp.get_used_lookup_gates(); std::uint32_t row_index = 0; auto check_var = [&assignments, &row_index](const variable_type& var) { @@ -434,6 +488,29 @@ namespace nil { check_var(copy_constraints[i].second); } + for (const auto& i : used_lookup_gates) { + if (i >= lookup_gates.size()) { + std::cout << "No lookup gate " << i << "\n"; + return false; + } + row_index = 0; + crypto3::math::expression_for_each_variable_visitor visitor(check_var); + + crypto3::zk::snark::plonk_column selector = + assignments.selector(lookup_gates[i].tag_index); + + for (std::size_t selector_row = 0; + selector_row < selector.size(); selector_row++) { + if (!selector[selector_row].is_zero()) { + row_index = selector_row; + for (const auto &lookup_constraint: lookup_gates[i].constraints) { + for (const auto &constraint : lookup_constraint.lookup_input) { + visitor.visit(constraint); + } + } + } + } + } return true; } diff --git a/include/nil/blueprint/blueprint/plonk/circuit_proxy.hpp b/include/nil/blueprint/blueprint/plonk/circuit_proxy.hpp index 58ad77413..d0ec191a3 100644 --- a/include/nil/blueprint/blueprint/plonk/circuit_proxy.hpp +++ b/include/nil/blueprint/blueprint/plonk/circuit_proxy.hpp @@ -70,6 +70,10 @@ namespace nil { return *circuit_ptr; } + circuit& get() { + return *circuit_ptr; + } + std::uint32_t get_id() const { return id; } @@ -116,37 +120,37 @@ namespace nil { std::size_t add_gate(const std::vector &args) override { const auto selector_index = circuit_ptr->add_gate(args); - used_gates.insert(selector_index); + used_gates.insert(circuit_ptr->num_gates() - 1); return selector_index; } std::size_t add_gate(const constraint_type &args) override { const auto selector_index = circuit_ptr->add_gate(args); - used_gates.insert(selector_index); + used_gates.insert(circuit_ptr->num_gates() - 1); return selector_index; } std::size_t add_gate(const std::initializer_list &&args) override { const auto selector_index = circuit_ptr->add_gate(args); - used_gates.insert(selector_index); + used_gates.insert(circuit_ptr->num_gates() - 1); return selector_index; } std::size_t add_lookup_gate(const std::vector &args) override { const auto selector_index = circuit_ptr->add_lookup_gate(args); - used_lookup_gates.insert(selector_index); + used_lookup_gates.insert(circuit_ptr->num_lookup_gates() - 1); return selector_index; } std::size_t add_lookup_gate(const lookup_constraint_type &args) override { const auto selector_index = circuit_ptr->add_lookup_gate(args); - used_lookup_gates.insert(selector_index); + used_lookup_gates.insert(circuit_ptr->num_lookup_gates() - 1); return selector_index; } std::size_t add_lookup_gate(const std::initializer_list &&args) override { const auto selector_index = circuit_ptr->add_lookup_gate(args); - used_lookup_gates.insert(selector_index); + used_lookup_gates.insert(circuit_ptr->num_lookup_gates() - 1); return selector_index; } @@ -155,7 +159,6 @@ namespace nil { } void add_lookup_table(const typename ArithmetizationType::lookup_table_type &table) override { - used_lookup_tables.insert(circuit_ptr->lookup_tables().size()); circuit_ptr->add_lookup_table(table); } @@ -165,6 +168,8 @@ namespace nil { void reserve_table(std::string name) override { circuit_ptr->reserve_table(name); + const auto idx = circuit_ptr->get_reserved_indices().at(name) - 1; + used_lookup_tables.insert(idx); } const typename lookup_library::left_reserved_type @@ -197,6 +202,9 @@ namespace nil { const auto gates = circuit_ptr->gates(); const auto copy_constraints = circuit_ptr->copy_constraints(); + const auto lookup_gates = circuit_ptr->lookup_gates(); + const auto lookup_tables = circuit_ptr->lookup_tables(); + const auto lookup_table_indexes = circuit_ptr->get_reserved_indices(); os << "used_gates_size: " << used_gates.size() << " " << "gates_size: " << gates.size() << " " @@ -216,6 +224,37 @@ namespace nil { os << i << ": " << copy_constraints[i].first << " " << copy_constraints[i].second << "\n"; } + + std::cout << "\nlookup gates:\n"; + for (const auto& i : used_lookup_gates) { + os << i << ": selector: " << lookup_gates[i].tag_index + << " lookup constraints size: " << lookup_gates[i].constraints.size() << "\n"; + for (std::size_t j = 0; j < lookup_gates[i].constraints.size(); j++) { + os << "constraints size: " << lookup_gates[i].constraints[j].lookup_input.size() << "\n"; + os << "table index: " << lookup_gates[i].constraints[j].table_id << "\n"; + for (std::size_t k = 0; k < lookup_gates[i].constraints[j].lookup_input.size(); k++) { + os << lookup_gates[i].constraints[j].lookup_input[k] << "\n"; + } + std::cout << "\n"; + } + std::cout << "\n"; + } + + std::cout << "\nlookup tables:\n"; + for (const auto& i : used_lookup_tables) { + bool found = false; + for (const auto it : lookup_table_indexes) { + if (it.second == (i + 1)) { + os << i << ": " << it.first << "\n"; + found = true; + break; + } + } + if (!found) { + os << i << ": not found\n"; + } + } + os.flush(); os.flags(os_flags); } diff --git a/test/proxy.cpp b/test/proxy.cpp index 0b839dd6f..7d1c8daca 100644 --- a/test/proxy.cpp +++ b/test/proxy.cpp @@ -149,16 +149,13 @@ BOOST_AUTO_TEST_CASE(blueprint_circuit_proxy_lookup_tables_test) { circuits.emplace_back(bp_ptr, 0); circuits.emplace_back(bp_ptr, 1); - const auto lookup_tbale_0 = ArithmetizationType::lookup_table_type(); - const auto lookup_tbale_1 = ArithmetizationType::lookup_table_type(); - const auto lookup_tbale_2 = ArithmetizationType::lookup_table_type(); + const std::string lookup_tbale_name_0 = "sha256_sparse_base4/full"; + const std::string lookup_tbale_name_1 = "sha256_sparse_base4/first_column"; + const std::string lookup_tbale_name_2 = "sha256_reverse_sparse_base4/full"; - circuits[0].add_lookup_table(lookup_tbale_0); - circuits[0].add_lookup_table(lookup_tbale_1); - circuits[1].add_lookup_table(lookup_tbale_2); - - const auto &tables = circuits[0].lookup_tables(); - BOOST_ASSERT(tables.size() == 3); + circuits[0].reserve_table(lookup_tbale_name_0); + circuits[0].reserve_table(lookup_tbale_name_1); + circuits[1].reserve_table(lookup_tbale_name_2); std::set used_lookup_tables_0 = {0, 1}; BOOST_ASSERT(circuits[0].get_used_lookup_tables() == used_lookup_tables_0); @@ -391,54 +388,6 @@ BOOST_AUTO_TEST_CASE(blueprint_assignment_proxy_save_shared_test) { BOOST_ASSERT(assignment.shared_column_size(0) == 2); } -BOOST_AUTO_TEST_CASE(blueprint_assignment_proxy_fill_constant_test) { - using BlueprintFieldType = typename nil::crypto3::algebra::curves::pallas::base_field_type; - constexpr std::size_t WitnessColumns = 15; - constexpr std::size_t PublicInputColumns = 1; - constexpr std::size_t ConstantColumns = 5; - constexpr std::size_t SelectorColumns = 35; - - using ArithmetizationParams = - nil::crypto3::zk::snark::plonk_arithmetization_params; - using ArithmetizationType = nil::crypto3::zk::snark::plonk_constraint_system; - using var = nil::crypto3::zk::snark::plonk_variable; - using column_type = typename nil::crypto3::zk::snark::plonk_column; - - auto assignment_ptr = std::make_shared>(); - assignment_proxy assignment(assignment_ptr, 0); - - const column_type constant_col = {1, 2, 3, 4, 5}; - assignment.fill_constant(1, constant_col); - - BOOST_ASSERT(assignment.constant_column_size(1) == 5); - std::set used_rows = {0, 1, 2, 3, 4}; - BOOST_ASSERT(assignment.get_used_rows() == used_rows); -} - -BOOST_AUTO_TEST_CASE(blueprint_assignment_proxy_fill_selector_test) { - using BlueprintFieldType = typename nil::crypto3::algebra::curves::pallas::base_field_type; - constexpr std::size_t WitnessColumns = 15; - constexpr std::size_t PublicInputColumns = 1; - constexpr std::size_t ConstantColumns = 5; - constexpr std::size_t SelectorColumns = 35; - - using ArithmetizationParams = - nil::crypto3::zk::snark::plonk_arithmetization_params; - using ArithmetizationType = nil::crypto3::zk::snark::plonk_constraint_system; - using var = nil::crypto3::zk::snark::plonk_variable; - using column_type = typename nil::crypto3::zk::snark::plonk_column; - - auto assignment_ptr = std::make_shared>(); - assignment_proxy assignment(assignment_ptr, 0); - - const column_type selector_col = {1, 2, 3, 4, 5}; - assignment.fill_selector(1, selector_col); - - BOOST_ASSERT(assignment.selector_column_size(1) == 5); - std::set used_rows = {0, 1, 2, 3, 4}; - BOOST_ASSERT(assignment.get_used_rows() == used_rows); -} - BOOST_AUTO_TEST_CASE(blueprint_proxy_call_pack_lookup_tables_test) { using BlueprintFieldType = typename nil::crypto3::algebra::curves::pallas::base_field_type; constexpr std::size_t WitnessColumns = 15; @@ -465,9 +414,11 @@ BOOST_AUTO_TEST_CASE(blueprint_proxy_call_pack_lookup_tables_test) { nil::crypto3::zk::snark::pack_lookup_tables( bp.get_reserved_indices(), bp.get_reserved_tables(), - bp, assignment, lookup_columns_indices, + bp.get(), assignment.get(), lookup_columns_indices, usable_rows_amount); - std::set used_rows = {0, 1, 2, 3, 4}; - BOOST_ASSERT(assignment.get_used_rows() == used_rows); + std::set lookup_constant_cols = {0, 1, 2, 3, 4}; + BOOST_ASSERT(assignment.get_lookup_constant_cols() == lookup_constant_cols); + std::set lookup_selector_cols = {1}; + BOOST_ASSERT(assignment.get_lookup_selector_cols() == lookup_selector_cols); } \ No newline at end of file diff --git a/test/test_plonk_component.hpp b/test/test_plonk_component.hpp index bb31588bb..a1e5c8c51 100644 --- a/test/test_plonk_component.hpp +++ b/test/test_plonk_component.hpp @@ -287,7 +287,7 @@ namespace nil { desc.usable_rows_amount = zk::snark::pack_lookup_tables_horizontal( bp.get_reserved_indices(), bp.get_reserved_tables(), - bp, assignment, lookup_columns_indices, + bp, assignment, lookup_columns_indices, 0, desc.usable_rows_amount, 500000 );