Skip to content

Commit

Permalink
hash to curve OK
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Sep 28, 2024
1 parent 00aa8fd commit cc966e4
Show file tree
Hide file tree
Showing 9 changed files with 1,155 additions and 20 deletions.
8 changes: 8 additions & 0 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
RHSFinalizeAccCircuit,
SlopeInterceptSamePointCircuit,
)
from garaga.precompiled_circuits.compilable_circuits.isogeny import ApplyIsogenyCircuit
from garaga.starknet.cli.utils import create_directory


Expand Down Expand Up @@ -78,6 +79,7 @@ class CircuitID(Enum):
MP_CHECK_FINALIZE_BLS = int.from_bytes(b"mp_check_finalize_bls", "big")
FP12_MUL_ASSERT_ONE = int.from_bytes(b"fp12_mul_assert_one", "big")
EVAL_E12D = int.from_bytes(b"eval_e12d", "big")
APPLY_ISOGENY = int.from_bytes(b"apply_isogeny", "big")


ALL_CAIRO_CIRCUITS = {
Expand Down Expand Up @@ -217,6 +219,12 @@ class CircuitID(Enum):
"filename": "extf_mul",
"curve_ids": [CurveID.BN254, CurveID.BLS12_381],
},
CircuitID.APPLY_ISOGENY: {
"class": ApplyIsogenyCircuit,
"params": None,
"filename": "isogeny",
"curve_ids": [CurveID.BLS12_381],
},
}


Expand Down
59 changes: 59 additions & 0 deletions hydra/garaga/precompiled_circuits/compilable_circuits/isogeny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import garaga.modulo_circuit_structs as structs
from garaga.definitions import CurveID
from garaga.precompiled_circuits.compilable_circuits.base import (
BaseModuloCircuit,
ModuloCircuit,
PyFelt,
)
from garaga.signature import get_isogeny_to_g1_map


class ApplyIsogenyCircuit(BaseModuloCircuit):
def __init__(
self, curve_id: int, auto_run: bool = True, compilation_mode: int = 1
) -> None:
super().__init__(
name=f"apply_isogeny_{CurveID(curve_id).name.lower()}",
input_len=2,
curve_id=curve_id,
auto_run=auto_run,
compilation_mode=compilation_mode,
)

def build_input(self) -> list[PyFelt]:
return [self.field(44), self.field(4)]

def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
circuit = ModuloCircuit(
self.name,
self.curve_id,
compilation_mode=self.compilation_mode,
)
px, py = circuit.write_struct(structs.G1PointCircuit(name="pt", elmts=input))
x_rational, y_rational = get_isogeny_to_g1_map(CurveID(self.curve_id))
x_num = [
circuit.set_or_get_constant(c) for c in x_rational.numerator.coefficients
]
x_den = [
circuit.set_or_get_constant(c) for c in x_rational.denominator.coefficients
]

y_num = [
circuit.set_or_get_constant(c) for c in y_rational.numerator.coefficients
]
y_den = [
circuit.set_or_get_constant(c) for c in y_rational.denominator.coefficients
]

x_affine_num = circuit.eval_horner(x_num, px, "x_num")
x_affine_den = circuit.eval_horner(x_den, px, "x_den")
x_affine = circuit.div(x_affine_num, x_affine_den)
y_affine_num = circuit.eval_horner(y_num, px, "y_num")
y_affine_den = circuit.eval_horner(y_den, px, "y_den")
y_affine_eval = circuit.div(y_affine_num, y_affine_den)
y_affine = circuit.mul(y_affine_eval, py)
circuit.extend_struct_output(
structs.G1PointCircuit(name="res", elmts=[x_affine, y_affine])
)

return circuit
16 changes: 9 additions & 7 deletions hydra/garaga/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import hashlib
from typing import Protocol, TypeVar

from garaga.algebra import Polynomial, RationalFunction
from garaga.algebra import Polynomial, PyFelt, RationalFunction
from garaga.definitions import CURVES, CurveID, G1Point, get_base_field
from garaga.hints.io import bytes_to_u32_array

Expand Down Expand Up @@ -163,7 +163,7 @@ def hash_to_field(
print(f"element {element.bit_length()}")
output.append(element)

return [field(x).value for x in output]
return [field(x) for x in output]


def hash_to_curve(
Expand Down Expand Up @@ -192,13 +192,13 @@ def hash_to_curve(
return apply_isogeny(sum).scalar_mul(cofactor)


def map_to_curve(field_element: int, curve_id: CurveID) -> G1Point:
def map_to_curve(field_element: PyFelt, curve_id: CurveID) -> G1Point:
field = get_base_field(curve_id)
a = field(CURVES[curve_id.value].swu_params.A)
b = field(CURVES[curve_id.value].swu_params.B)
z = field(CURVES[curve_id.value].swu_params.Z)

u = field(field_element)
u = field_element
zeta_u2 = z * u**2
ta = zeta_u2**2 + zeta_u2
num_x1 = b * (ta + field.one())
Expand Down Expand Up @@ -351,6 +351,7 @@ def apply_isogeny(pt: G1Point) -> G1Point:
if __name__ == "__main__":
from garaga.hints.io import int_to_u384

field = get_base_field(CurveID.BLS12_381)
message = b"Hello, World!"
sha_message = hashlib.sha256(message).digest()
print(f"sha_message {sha_message.hex()}")
Expand All @@ -371,7 +372,7 @@ def test_hash_to_field(message: bytes):
# assert res == expected, f"Expected {expected}, got {res}"

def test_map_to_curve():
u = 42
u = field(42)
res = map_to_curve(field_element=u, curve_id=CurveID.BLS12_381)
print(f"res {int_to_u384(res.x)} {int_to_u384(res.y)}")

Expand Down Expand Up @@ -403,9 +404,10 @@ def test_hash_to_curve(message: bytes):
message=message, curve_id=CurveID.BLS12_381, hash_name="sha256"
)

assert res == expected, f"Expected {expected}, got {res}"
# assert res == expected, f"Expected {expected}, got {res}"
print(f"res {int_to_u384(res.x)} {int_to_u384(res.y)}")

# test_hash_to_field(message=message)

test_map_to_curve()
# test_hash_to_curve(message=message)
test_hash_to_curve(message=message)
134 changes: 134 additions & 0 deletions hydra/garaga/starknet/tests_and_calldata_generators/map_to_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from dataclasses import dataclass

import garaga.modulo_circuit_structs as structs
from garaga.algebra import PyFelt
from garaga.definitions import CURVES, CurveID, G1Point, get_base_field
from garaga.hints.io import bigint_split, int_to_u384
from garaga.signature import apply_isogeny, hash_to_field
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder


@dataclass(slots=True)
class MapToCurveHint:
gx1_is_square: bool
y1: PyFelt
y_flag: bool

def to_cairo(self) -> str:
return f"MapToCurveHint {{ gx1_is_square: {str(self.gx1_is_square).lower()}, y1: {int_to_u384(self.y1.value, as_hex=True)}, y_flag: {str(self.y_flag).lower()} }}"

def to_calldata(self) -> list[int]:
cd = []
cd.append(int(self.gx1_is_square))
cd.extend(bigint_split(self.y1.value))
cd.append(int(self.y_flag))
return cd


@dataclass(slots=True)
class HashToCurveHint:
f0_hint: MapToCurveHint
f1_hint: MapToCurveHint
scalar_mul_hint: structs.Struct
derive_point_from_x_hint: structs.Struct

def to_cairo(self) -> str:
return f"HashToCurveHint {{ f0_hint: {self.f0_hint.to_cairo()}, f1_hint: {self.f1_hint.to_cairo()}, scalar_mul_hint: {self.scalar_mul_hint.serialize(raw=True)}, derive_point_from_x_hint: {self.derive_point_from_x_hint.serialize(raw=True)} }}"

def to_calldata(self) -> list[int]:
cd = []
cd.extend(self.f0_hint.to_calldata())
cd.extend(self.f1_hint.to_calldata())
cd.extend(self.scalar_mul_hint.serialize_to_calldata())
cd.extend(self.derive_point_from_x_hint.serialize_to_calldata())
return cd


def build_map_to_curve_hint(u: PyFelt) -> tuple[G1Point, MapToCurveHint]:
field = get_base_field(CurveID.BLS12_381)
a = field(CURVES[CurveID.BLS12_381.value].swu_params.A)
b = field(CURVES[CurveID.BLS12_381.value].swu_params.B)
z = field(CURVES[CurveID.BLS12_381.value].swu_params.Z)

zeta_u2 = z * u**2
ta = zeta_u2**2 + zeta_u2
num_x1 = b * (ta + field.one())

if ta.value == 0:
div = a * z
else:
div = a * -ta

num2_x1 = num_x1**2
div2 = div**2
div3 = div2 * div
assert div3.value != 0

num_gx1 = (num2_x1 + a * div2) * num_x1 + b * div3
num_x2 = zeta_u2 * num_x1

gx1 = num_gx1 / div3
gx1_square = gx1.is_quad_residue()
if gx1_square:
print(f"square res")
y1 = gx1.sqrt(min_root=False)
assert y1 * y1 == gx1
else:
print(f"not square res")
y1 = (z * gx1).sqrt(min_root=False)
assert y1 * y1 == z * gx1

y2 = zeta_u2 * u * y1
y = y1 if gx1_square else y2
y_flag = y.value % 2 == u.value % 2

num_x = num_x1 if gx1_square else num_x2
x_affine = num_x / div
y_affine = -y if y.value % 2 != u.value % 2 else y

point_on_curve = G1Point(
x_affine.value, y_affine.value, CurveID.BLS12_381, iso_point=True
)
return point_on_curve, MapToCurveHint(
gx1_is_square=gx1_square, y1=y1, y_flag=y_flag
)


def build_hash_to_curve_hint(message: bytes) -> HashToCurveHint:
felt0, felt1 = hash_to_field(message, 2, CurveID.BLS12_381.value, "sha256")
pt0, f0_hint = build_map_to_curve_hint(felt0)
pt1, f1_hint = build_map_to_curve_hint(felt1)
sum_pt = pt0.add(pt1)
print(
f"sum_pt: {int_to_u384(sum_pt.x, as_hex=False)} {int_to_u384(sum_pt.y, as_hex=False)}"
)
sum_pt = apply_isogeny(sum_pt)
print(
f"sum_pt: {int_to_u384(sum_pt.x, as_hex=False)} {int_to_u384(sum_pt.y, as_hex=False)}"
)
x = CURVES[CurveID.BLS12_381.value].x
n = CURVES[CurveID.BLS12_381.value].n
cofactor = (1 - (x % n)) % n
print(f"cofactor: {cofactor}, hex :{hex(cofactor)}")

msm_builder = MSMCalldataBuilder(
curve_id=CurveID.BLS12_381, points=[sum_pt], scalars=[cofactor]
)
msm_hint, derive_point_from_x_hint = msm_builder.build_msm_hints(risc0_mode=True)

return HashToCurveHint(
f0_hint=f0_hint,
f1_hint=f1_hint,
scalar_mul_hint=msm_hint,
derive_point_from_x_hint=derive_point_from_x_hint,
)


if __name__ == "__main__":
field = get_base_field(CurveID.BLS12_381)

import hashlib

hint = build_hash_to_curve_hint(hashlib.sha256(b"Hello, World!").digest())
print(hint.to_cairo())
# print(hint.to_calldata())
1 change: 1 addition & 0 deletions src/src/circuits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ mod ec;
mod dummy;
mod multi_pairing_check;
mod extf_mul;
mod isogeny;
Loading

0 comments on commit cc966e4

Please sign in to comment.