Skip to content

Commit

Permalink
Merge pull request #3714 from kevaundray/kw/use-optimized-bls-msm
Browse files Browse the repository at this point in the history
chore: use py-arkworks's multi-exp method inside of `g1_lincomb` and `g2_lincomb`
  • Loading branch information
ralexstokes authored Apr 23, 2024
2 parents 830b255 + a526cdf commit b13e03e
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 8 deletions.
1 change: 1 addition & 0 deletions pysetup/spec_builders/deneb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def imports(cls, preset_name: str):
def preparations(cls):
return '''
T = TypeVar('T') # For generic function
TPoint = TypeVar('TPoint') # For generic function. G1 or G2 point.
'''

@classmethod
Expand Down
14 changes: 10 additions & 4 deletions specs/_features/eip7594/polynomial-commitments-sampling.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,18 @@ def coset_evals_to_cell(coset_evals: CosetEvals) -> Cell:
```python
def g2_lincomb(points: Sequence[G2Point], scalars: Sequence[BLSFieldElement]) -> Bytes96:
"""
BLS multiscalar multiplication in G2. This function can be optimized using Pippenger's algorithm and variants.
BLS multiscalar multiplication in G2. This can be naively implemented using double-and-add.
"""
assert len(points) == len(scalars)
result = bls.Z2()
for x, a in zip(points, scalars):
result = bls.add(result, bls.multiply(bls.bytes96_to_G2(x), a))

if len(points) == 0:
return bls.G2_to_bytes96(bls.Z2())

points_g2 = []
for point in points:
points_g2.append(bls.bytes96_to_G2(point))

result = bls.multi_exp(points_g2, scalars)
return Bytes96(bls.G2_to_bytes96(result))
```

Expand Down
27 changes: 23 additions & 4 deletions specs/deneb/polynomial-commitments.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- [`reverse_bits`](#reverse_bits)
- [`bit_reversal_permutation`](#bit_reversal_permutation)
- [BLS12-381 helpers](#bls12-381-helpers)
- [`multi_exp`](#multi_exp)
- [`hash_to_bls_field`](#hash_to_bls_field)
- [`bytes_to_bls_field`](#bytes_to_bls_field)
- [`bls_field_to_bytes`](#bls_field_to_bytes)
Expand Down Expand Up @@ -146,6 +147,18 @@ def bit_reversal_permutation(sequence: Sequence[T]) -> Sequence[T]:

### BLS12-381 helpers


#### `multi_exp`

This function performs a multi-scalar multiplication between `points` and `integers`. `points` can either be in G1 or G2.

```python
def multi_exp(points: Sequence[TPoint],
integers: Sequence[uint64]) -> Sequence[TPoint]:
# pylint: disable=unused-argument
...
```

#### `hash_to_bls_field`

```python
Expand Down Expand Up @@ -274,12 +287,18 @@ def div(x: BLSFieldElement, y: BLSFieldElement) -> BLSFieldElement:
```python
def g1_lincomb(points: Sequence[KZGCommitment], scalars: Sequence[BLSFieldElement]) -> KZGCommitment:
"""
BLS multiscalar multiplication. This function can be optimized using Pippenger's algorithm and variants.
BLS multiscalar multiplication in G1. This can be naively implemented using double-and-add.
"""
assert len(points) == len(scalars)
result = bls.Z1()
for x, a in zip(points, scalars):
result = bls.add(result, bls.multiply(bls.bytes48_to_G1(x), a))

if len(points) == 0:
return bls.G1_to_bytes48(bls.Z1())

points_g1 = []
for point in points:
points_g1.append(bls.bytes48_to_G1(point))

result = bls.multi_exp(points_g1, scalars)
return KZGCommitment(bls.G1_to_bytes48(result))
```

Expand Down
39 changes: 39 additions & 0 deletions tests/core/pyspec/eth2spec/utils/bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,45 @@ def multiply(point, scalar):
return py_ecc_mul(point, scalar)


def multi_exp(points, integers):
"""
Performs a multi-scalar multiplication between
`points` and `integers`.
`points` can either be in G1 or G2.
"""
# Since this method accepts either G1 or G2, we need to know
# the type of the point to return. Hence, we need at least one point.
if not points or not integers:
raise Exception("Cannot call multi_exp with zero points or zero integers")

if bls == arkworks_bls or bls == fastest_bls:
# Convert integers into arkworks Scalars
scalars = []
for integer in integers:
int_as_bytes = integer.to_bytes(32, 'little')
scalars.append(arkworks_Scalar.from_le_bytes(int_as_bytes))

# Check if we need to perform a G1 or G2 multiexp
if isinstance(points[0], arkworks_G1):
return arkworks_G1.multiexp_unchecked(points, scalars)
elif isinstance(points[0], arkworks_G2):
return arkworks_G2.multiexp_unchecked(points, scalars)
else:
raise Exception("Invalid point type")

result = None
if isinstance(points[0], py_ecc_G1):
result = Z1()
elif isinstance(points[0], py_ecc_G2):
result = Z2()
else:
raise Exception("Invalid point type")

for point, scalar in points.zip(integers):
result = add(result, multiply(point, scalar))
return result


def neg(point):
"""
Returns the point negation of `point`
Expand Down

0 comments on commit b13e03e

Please sign in to comment.