Skip to content

Commit

Permalink
Merge pull request #1545 from o1-labs/feature/safe-foreign-add
Browse files Browse the repository at this point in the history
Fix foreign EC add soundness
  • Loading branch information
mitschabaude authored Apr 9, 2024
2 parents 6731ad3 + 770c7b6 commit 88822f9
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 32 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
_Security_ in case of vulnerabilities.
-->

## [Unreleased](https://github.com/o1-labs/o1js/compare/1b6fd8b8e...HEAD)

### Breaking changes

- Add assertion to the foreign EC addition gadget that prevents degenerate cases https://github.com/o1-labs/o1js/pull/1545
- Fixes soundness of ECDSA; slightly increases its constraints from ~28k to 29k
- Breaks circuits that used EC addition, like ECDSA

## [0.18.0](https://github.com/o1-labs/o1js/compare/74948acac...1b6fd8b8e) - 2024-04-09

### Breaking changes
Expand Down
23 changes: 19 additions & 4 deletions src/lib/provable/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
assertSquare,
assertBoolean,
} from './gadgets/compatible.js';
import { toLinearCombination } from './gadgets/basic.js';
import { assertBilinear, toLinearCombination } from './gadgets/basic.js';
import {
FieldType,
FieldVar,
Expand All @@ -29,6 +29,7 @@ import {
lessThanFull,
lessThanOrEqualFull,
} from './gadgets/comparison.js';
import { toVar } from './gadgets/common.js';

// external API
export { Field };
Expand Down Expand Up @@ -768,6 +769,16 @@ class Field {
}
// inv() proves that a field element is non-zero, using 1 constraint.
// so this takes 1-2 generic gates, while x.equals(y).assertTrue() takes 3-5
if (isConstant(y)) {
// custom single generic gate for (x - y) * z = 1
// TODO remove once assertMul() handles these cases
let x = toVar(this);
let y0 = toFp(y);
let z = existsOne(() => Fp.inverse(this.toBigInt() - y0) ?? 0n);
// 1*x*z + 0*x + (-y)*z + (-1) = 0
assertBilinear(x, z, [1n, 0n, -y0, -1n]);
return;
}
this.sub(y).inv();
} catch (err) {
throw withMessage(err, message);
Expand Down Expand Up @@ -874,12 +885,12 @@ class Field {
*
* @return A {@link Field} element that is equal to the result of AST that was previously on this {@link Field} element.
*/
seal() {
seal(): VarField | ConstantField {
let { constant, terms } = toLinearCombination(this.value);
if (terms.length === 0) return new Field(constant);
if (terms.length === 0) return ConstantField(constant);
if (terms.length === 1 && constant === 0n) {
let [c, x] = terms[0];
if (c === 1n) return new Field(x);
if (c === 1n) return VarField(x);
}
let x = existsOne(() => this.toBigInt());
this.assertEquals(x);
Expand Down Expand Up @@ -1226,3 +1237,7 @@ Warning: whatever happens inside asProver() will not be part of the zk proof.
function VarField(x: VarFieldVar): VarField {
return new Field(x) as VarField;
}

function ConstantField(x: ConstantFieldVar | bigint): ConstantField {
return new Field(x) as ConstantField;
}
55 changes: 53 additions & 2 deletions src/lib/provable/gadgets/basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,17 @@ import {
import { toVar } from './common.js';
import { Gates, fieldVar } from '../gates.js';
import { TupleN } from '../../util/types.js';
import { existsOne } from '../core/exists.js';
import { exists, existsOne } from '../core/exists.js';
import { createField } from '../core/field-constructor.js';
import { assert } from '../../util/assert.js';

export { assertMul, arrayGet, assertOneOf };
export {
assertMul,
assertBilinear,
arrayGet,
assertOneOf,
assertNotVectorEquals,
};

// internal
export {
Expand Down Expand Up @@ -133,6 +140,50 @@ function assertOneOf(x: Field, allowed: [bigint, bigint, ...bigint[]]) {
}
}

/**
* Assert that x does not equal a constant vector c:
*
* `(x[0],...,x[n-1]) !== (c[0],...,c[n-1])`
*
* We prove this by witnessing a vector z such that:
*
* `sum_i (x[i] - c[i])*z[i] === 1`
*
* If we had `x[i] === c[i]` for all i, the left-hand side would be 0 regardless of z.
*/
function assertNotVectorEquals(x: Field[], c: [bigint, bigint, ...bigint[]]) {
let xv = x.map(toVar);
let n = c.length;
assert(n > 1 && x.length === n, 'vector lengths must match');

// witness vector z
let z = exists(n, () => {
let z = Array(n).fill(0n);

// find index where x[i] !== c[i]
let i = x.findIndex((xi, i) => xi.toBigInt() !== c[i]);
if (i === -1) return z;

// z[i] = (x[i] - c[i])^-1
z[i] = Fp.inverse(Fp.sub(x[i].toBigInt(), c[i])) ?? 0n;
return z;
});

let products = xv.map((xi, i) => {
// (xi - ci)*zi = xi*zi + 0*xi - ci*zi + 0
return bilinear(xi, z[i], [1n, 0n, -c[i], 0n]);
});

// sum_i (xi - ci)*zi = 1
let sum = products[0];
for (let i = 1; i < n - 1; i++) {
// sum = sum + products[i]
sum = bilinear(sum, products[i], [0n, 1n, 1n, 0n]);
}
// sum + products[n - 1] - 1 === 0
assertBilinear(sum, products[n - 1], [0n, 1n, 1n, -1n]);
}

// low-level helpers to create generic gates

/**
Expand Down
33 changes: 22 additions & 11 deletions src/lib/provable/gadgets/elliptic-curve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Field } from '../field.js';
import { Provable } from '../provable.js';
import { assert } from './common.js';
import { Field3, ForeignField, split, weakBound } from './foreign-field.js';
import { l2, multiRangeCheck } from './range-check.js';
import { l, l2, multiRangeCheck } from './range-check.js';
import { sha256 } from 'js-sha256';
import {
bigIntToBytes,
Expand All @@ -17,7 +17,7 @@ import {
import { Bool } from '../bool.js';
import { provable } from '../types/struct.js';
import { assertPositiveInteger } from '../../../bindings/crypto/non-negative.js';
import { arrayGet } from './basic.js';
import { arrayGet, assertNotVectorEquals } from './basic.js';
import { sliceField3 } from './bit-slices.js';
import { Hashed } from '../packed.js';
import { exists } from '../core/exists.js';
Expand Down Expand Up @@ -56,6 +56,8 @@ function add(p1: Point, p2: Point, Curve: { modulus: bigint }) {
let { x: x1, y: y1 } = p1;
let { x: x2, y: y2 } = p2;
let f = Curve.modulus;
let [f0, f1, f2] = split(f);
let [, , fx22] = split(f * 2n);

// constant case
if (Point.isConstant(p1) && Point.isConstant(p2)) {
Expand All @@ -66,11 +68,10 @@ function add(p1: Point, p2: Point, Curve: { modulus: bigint }) {
// witness and range-check slope, x3, y3
let witnesses = exists(9, () => {
let [x1_, x2_, y1_, y2_] = Field3.toBigints(x1, x2, y1, y2);
let denom = inverse(mod(x1_ - x2_, f), f);
let denom = inverse(mod(x1_ - x2_, f), f) ?? 0n;

let m = denom !== undefined ? mod((y1_ - y2_) * denom, f) : 0n;
let m2 = mod(m * m, f);
let x3 = mod(m2 - x1_ - x2_, f);
let m = mod((y1_ - y2_) * denom, f);
let x3 = mod(m * m - x1_ - x2_, f);
let y3 = mod(m * (x1_ - x3) - y1_, f);

return [...split(m), ...split(x3), ...split(y3)];
Expand All @@ -81,8 +82,16 @@ function add(p1: Point, p2: Point, Curve: { modulus: bigint }) {
let y3: Field3 = [y30, y31, y32];
ForeignField.assertAlmostReduced([m, x3, y3], f);

// check that x1 != x2
// we assume x1, x2 are almost reduced, so deltaX <= x1 - x2 + f < 3f
// which means we need to check that deltaX != 0, f, 2f
let deltaX = ForeignField.sub(x1, x2, f);
let deltaX01 = deltaX[0].add(deltaX[1].mul(1n << l)).seal();
assertNotVectorEquals([deltaX01, deltaX[2]], [0n, 0n]); // != 0
assertNotVectorEquals([deltaX01, deltaX[2]], [f0 + (f1 << l), f2]); // != f
deltaX[2].assertNotEquals(fx22); // != 2f (stronger check bc assuming deltaX < f doesn't harm completeness)

// (x1 - x2)*m = y1 - y2
let deltaX = ForeignField.Sum(x1).sub(x2);
let deltaY = ForeignField.Sum(y1).sub(y2);
ForeignField.assertMul(deltaX, m, deltaY, f);

Expand Down Expand Up @@ -111,11 +120,10 @@ function double(p1: Point, Curve: { modulus: bigint; a: bigint }) {
// witness and range-check slope, x3, y3
let witnesses = exists(9, () => {
let [x1_, y1_] = Field3.toBigints(x1, y1);
let denom = inverse(mod(2n * y1_, f), f);
let denom = inverse(mod(2n * y1_, f), f) ?? 0n;

let m = denom !== undefined ? mod(3n * mod(x1_ ** 2n, f) * denom, f) : 0n;
let m2 = mod(m * m, f);
let x3 = mod(m2 - 2n * x1_, f);
let m = mod(3n * mod(x1_ ** 2n, f) * denom, f);
let x3 = mod(m * m - 2n * x1_, f);
let y3 = mod(m * (x1_ - x3) - y1_, f);

return [...split(m), ...split(x3), ...split(y3)];
Expand Down Expand Up @@ -423,6 +431,9 @@ function multiScalarMul(
table.map((point) => HashedPoint.hash(point))
);

// initialize sum to the initial aggregator, which is expected to be unrelated to any point that this gadget is used with
// note: this is a trick to ensure _completeness_ of the gadget
// soundness follows because add() and double() are sound, on all inputs that are valid non-zero curve points
ia ??= initialAggregator(Curve);
let sum = Point.from(ia);

Expand Down
8 changes: 8 additions & 0 deletions src/lib/testing/equivalent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,14 @@ function record<Specs extends { [k in string]: Spec<any, any> }>(
};
}

function map<T1, T2, S1, S2>(
{ from, to }: { from: ProvableSpec<T1, T2>; to: ProvableSpec<S1, S2> },
there: (t: T1) => S1
): ProvableSpec<S1, S2>;
function map<T1, T2, S1, S2>(
{ from, to }: { from: FromSpec<T1, T2>; to: Spec<S1, S2> },
there: (t: T1) => S1
): Spec<S1, S2>;
function map<T1, T2, S1, S2>(
{ from, to }: { from: FromSpec<T1, T2>; to: Spec<S1, S2> },
there: (t: T1) => S1
Expand Down
Loading

0 comments on commit 88822f9

Please sign in to comment.