Skip to content

Commit

Permalink
Keep lookup constant/selector columns
Browse files Browse the repository at this point in the history
  • Loading branch information
akokoshn committed Nov 30, 2023
1 parent ff5c51b commit 6f0f090
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 110 deletions.
45 changes: 45 additions & 0 deletions include/nil/blueprint/blueprint/plonk/assignment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <algorithm>
#include <limits>
#include <unordered_set>

#include <nil/crypto3/zk/snark/arithmetization/plonk/table_description.hpp>
#include <nil/crypto3/zk/snark/arithmetization/plonk/constraint_system.hpp>
Expand Down Expand Up @@ -69,6 +70,8 @@ namespace nil {
std::uint32_t assignment_allocated_rows = 0;
std::vector<value_type> assignment_private_storage;
shared_container_type shared_storage; // results of the previously prover
std::set<std::uint32_t> lookup_constant_cols;
std::set<std::uint32_t> lookup_selector_cols;
public:
static constexpr const std::size_t private_storage_index = std::numeric_limits<std::size_t>::max();

Expand Down Expand Up @@ -131,6 +134,19 @@ namespace nil {
}
}

void fill_selector(std::uint32_t index, const column_type& column) override {
lookup_selector_cols.insert(index);
zk_type::fill_selector(index, column);
}

virtual const std::set<std::uint32_t>& get_lookup_selector_cols() const {
return lookup_selector_cols;
}

virtual std::uint32_t get_lookup_selector_amount() const {
return lookup_selector_cols.size();
}

virtual value_type &shared(std::uint32_t shared_index, std::uint32_t row_index) {
if (shared_storage[shared_index].size() <= row_index) {
shared_storage[shared_index].resize(row_index + 1);
Expand All @@ -151,6 +167,10 @@ namespace nil {
return shared_storage.size();
}

virtual const column_type& shared(std::uint32_t index) const {
return shared_storage[index];
}

virtual value_type &witness(std::uint32_t witness_index, std::uint32_t row_index) {
BLUEPRINT_ASSERT(witness_index < ArithmetizationParams::WitnessColumns);

Expand All @@ -176,6 +196,10 @@ namespace nil {
return zk_type::witnesses_amount();
}

virtual const column_type& witness(std::uint32_t index) const {
return zk_type::witness(index);
}

virtual value_type &public_input(
std::uint32_t public_input_index, std::uint32_t row_index) {

Expand Down Expand Up @@ -204,6 +228,10 @@ namespace nil {
return zk_type::public_inputs_amount();
}

virtual const column_type& public_input(std::uint32_t index) const {
return zk_type::public_input(index);
}

virtual value_type &constant(
std::uint32_t constant_index, std::uint32_t row_index) {

Expand All @@ -225,6 +253,23 @@ namespace nil {
return zk_type::constant(constant_index)[row_index];
}

virtual const column_type& constant(std::uint32_t index) const {
return zk_type::constant(index);
}

void fill_constant(std::uint32_t index, const column_type& column) override {
lookup_constant_cols.insert(index);
zk_type::fill_constant(index, column);
}

virtual const std::set<std::uint32_t>& get_lookup_constant_cols() const {
return lookup_constant_cols;
}

virtual std::uint32_t get_lookup_constant_amount() const {
return lookup_constant_cols.size();
}

virtual std::uint32_t constant_column_size(std::uint32_t col_idx) const {
return this->_public_table.constant_column_size(col_idx);
}
Expand Down
125 changes: 101 additions & 24 deletions include/nil/blueprint/blueprint/plonk/assignment_proxy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ namespace nil {
return *assignment_ptr;
}

assignment<ArithmetizationType>& get() {
return *assignment_ptr;
}

std::uint32_t get_id() const {
return id;
}
Expand All @@ -92,9 +96,13 @@ namespace nil {
}

value_type selector(std::size_t selector_index, std::uint32_t row_index) const override {
if (check && used_rows.find(row_index) == used_rows.end()) {
std::cout << id << ": Not found selector " << selector_index << " on row " << row_index << std::endl;
BLUEPRINT_ASSERT(false);
if (check) {
const auto lookup_selector_cols = assignment_ptr->get_lookup_selector_cols();
if (lookup_selector_cols.find(selector_index) == lookup_selector_cols.end() &&
used_rows.find(row_index) == used_rows.end()) {
std::cout << id << ": Not found selector " << selector_index << " on row " << row_index << std::endl;
BLUEPRINT_ASSERT(false);
}
}
return std::const_pointer_cast<const assignment<ArithmetizationType>>(assignment_ptr)->selector(selector_index, row_index);
}
Expand Down Expand Up @@ -126,12 +134,17 @@ namespace nil {
}

void fill_selector(std::uint32_t index, const column_type& column) override {
for (std::uint32_t i = 0; i < column.size(); i++) {
used_rows.insert(i);
}
assignment_ptr->fill_selector(index, column);
}

const std::set<std::uint32_t>& get_lookup_selector_cols() const override {
return assignment_ptr->get_lookup_selector_cols();
}

std::uint32_t get_lookup_selector_amount() const override {
return assignment_ptr->get_lookup_selector_amount();
}

value_type &shared(std::uint32_t shared_index, std::uint32_t row_index) override {
return assignment_ptr->shared(shared_index, row_index);
}
Expand All @@ -148,6 +161,10 @@ namespace nil {
return assignment_ptr->shared_column_size(index);
}

const column_type& shared(std::uint32_t index) const override {
return assignment_ptr->shared(index);
}

std::uint32_t shareds_amount() const override {
return assignment_ptr->shareds_amount();
}
Expand All @@ -165,6 +182,10 @@ namespace nil {
return std::const_pointer_cast<const assignment<ArithmetizationType>>(assignment_ptr)->witness(witness_index, row_index);
}

const column_type& witness(std::uint32_t index) const override {
return assignment_ptr->witness(index);
}

std::uint32_t witnesses_amount() const override {
return assignment_ptr->witnesses_amount();
}
Expand All @@ -183,6 +204,10 @@ namespace nil {
return std::const_pointer_cast<const assignment<ArithmetizationType>>(assignment_ptr)->public_input(public_input_index, row_index);
}

const column_type& public_input(std::uint32_t index) const override {
return assignment_ptr->public_input(index);
}

std::uint32_t public_inputs_amount() const override {
return assignment_ptr->public_inputs_amount();
}
Expand All @@ -198,20 +223,17 @@ namespace nil {
}

value_type constant(std::uint32_t constant_index, std::uint32_t row_index) const override {
if (check && used_rows.find(row_index) == used_rows.end()) {
std::cout << id << ": Not found constant " << constant_index << " on row " << row_index << std::endl;
BLUEPRINT_ASSERT(false);
if (check) {
const auto lookup_constant_cols = assignment_ptr->get_lookup_constant_cols();
if (lookup_constant_cols.find(constant_index) == lookup_constant_cols.end() &&
used_rows.find(row_index) == used_rows.end()) {
std::cout << id << ": Not found constant " << constant_index << " on row " << row_index << std::endl;
BLUEPRINT_ASSERT(false);
}
}
return std::const_pointer_cast<const assignment<ArithmetizationType>>(assignment_ptr)->constant(constant_index, row_index);
}

void fill_constant(std::uint32_t index, const column_type& column) override {
for (std::uint32_t i = 0; i < column.size(); i++) {
used_rows.insert(i);
}
assignment_ptr->fill_constant(index, column);
}

std::uint32_t constants_amount() const override {
return assignment_ptr->constants_amount();
}
Expand All @@ -220,6 +242,22 @@ namespace nil {
return assignment_ptr->constant_column_size(index);
}

void fill_constant(std::uint32_t index, const column_type& column) override {
assignment_ptr->fill_constant(index, column);
}

const column_type& constant(std::uint32_t index) const override {
return assignment_ptr->constant(index);
}

const std::set<std::uint32_t>& get_lookup_constant_cols() const override {
return assignment_ptr->get_lookup_constant_cols();
}

std::uint32_t get_lookup_constant_amount() const override {
return assignment_ptr->get_lookup_constant_amount();
}

value_type private_storage(std::uint32_t storage_index) const override {
return assignment_ptr->private_storage(storage_index);
}
Expand Down Expand Up @@ -286,11 +324,23 @@ namespace nil {
<< "max_size: " << max_size << " "
<< "internal used rows size: " << used_rows.size() << "\n";

std::cout << "internal used rows: ";
os << "internal used rows: ";
for (const auto& it : used_rows) {
std::cout << it << " ";
os << it << " ";
}
os << "\n";

os << "lookup constants: ";
for (const auto &it : assignment_ptr->get_lookup_constant_cols()) {
os << it << " ";
}
os << "\n";

os << "lookup selectors: ";
for (const auto &it : assignment_ptr->get_lookup_selector_cols()) {
os << it << " ";
}
std::cout << "\n";
os << "\n";

os << std::dec;
os << std::hex << std::setfill('0');
Expand Down Expand Up @@ -319,13 +369,13 @@ namespace nil {
os << "| ";
for (std::uint32_t j = 0; j < constants_size; j++) {
os << std::setw(width)
<< (i < assignment_ptr->constant_column_size(j) && is_used_row ?
<< (i < assignment_ptr->constant_column_size(j) ?
assignment_ptr->constant(j, i) : 0).data << " ";
}
os << "| ";
// Selectors only need a single bit, so we do not renew the size here
for (std::uint32_t j = 0; j < selectors_size - 1; j++) {
os << (i < assignment_ptr->selector_column_size(j) && is_used_row ?
os << (i < assignment_ptr->selector_column_size(j) ?
assignment_ptr->selector(j, i) : 0).data << " ";
}
os << "\n";
Expand Down Expand Up @@ -369,15 +419,19 @@ namespace nil {
ArithmetizationParams>> &assignments){

using variable_type = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;
const auto private_rows = assignments.get_used_rows();
const auto& private_rows = assignments.get_used_rows();
const auto& lookup_selector_cols = assignments.get_lookup_selector_cols();

const std::vector<crypto3::zk::snark::plonk_gate<BlueprintFieldType,
crypto3::zk::snark::plonk_constraint<BlueprintFieldType>>> &gates = bp.gates();
const std::set<std::uint32_t>& used_gates = bp.get_used_gates();
const auto& used_gates = bp.get_used_gates();

const std::vector<crypto3::zk::snark::plonk_copy_constraint<BlueprintFieldType>> &copy_constraints =
bp.copy_constraints();
const std::set<std::uint32_t>& used_copy_constraints = bp.get_used_copy_constraints();
const auto& used_copy_constraints = bp.get_used_copy_constraints();

const auto& lookup_gates = bp.lookup_gates();
const auto& used_lookup_gates = bp.get_used_lookup_gates();

std::uint32_t row_index = 0;
auto check_var = [&assignments, &row_index](const variable_type& var) {
Expand Down Expand Up @@ -434,6 +488,29 @@ namespace nil {
check_var(copy_constraints[i].second);
}

for (const auto& i : used_lookup_gates) {
if (i >= lookup_gates.size()) {
std::cout << "No lookup gate " << i << "\n";
return false;
}
row_index = 0;
crypto3::math::expression_for_each_variable_visitor<variable_type> visitor(check_var);

crypto3::zk::snark::plonk_column <BlueprintFieldType> selector =
assignments.selector(lookup_gates[i].tag_index);

for (std::size_t selector_row = 0;
selector_row < selector.size(); selector_row++) {
if (!selector[selector_row].is_zero()) {
row_index = selector_row;
for (const auto &lookup_constraint: lookup_gates[i].constraints) {
for (const auto &constraint : lookup_constraint.lookup_input) {
visitor.visit(constraint);
}
}
}
}
}
return true;
}

Expand Down
Loading

0 comments on commit 6f0f090

Please sign in to comment.