-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
00aa8fd
commit cc966e4
Showing
9 changed files
with
1,155 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
hydra/garaga/precompiled_circuits/compilable_circuits/isogeny.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import garaga.modulo_circuit_structs as structs | ||
from garaga.definitions import CurveID | ||
from garaga.precompiled_circuits.compilable_circuits.base import ( | ||
BaseModuloCircuit, | ||
ModuloCircuit, | ||
PyFelt, | ||
) | ||
from garaga.signature import get_isogeny_to_g1_map | ||
|
||
|
||
class ApplyIsogenyCircuit(BaseModuloCircuit): | ||
def __init__( | ||
self, curve_id: int, auto_run: bool = True, compilation_mode: int = 1 | ||
) -> None: | ||
super().__init__( | ||
name=f"apply_isogeny_{CurveID(curve_id).name.lower()}", | ||
input_len=2, | ||
curve_id=curve_id, | ||
auto_run=auto_run, | ||
compilation_mode=compilation_mode, | ||
) | ||
|
||
def build_input(self) -> list[PyFelt]: | ||
return [self.field(44), self.field(4)] | ||
|
||
def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit: | ||
circuit = ModuloCircuit( | ||
self.name, | ||
self.curve_id, | ||
compilation_mode=self.compilation_mode, | ||
) | ||
px, py = circuit.write_struct(structs.G1PointCircuit(name="pt", elmts=input)) | ||
x_rational, y_rational = get_isogeny_to_g1_map(CurveID(self.curve_id)) | ||
x_num = [ | ||
circuit.set_or_get_constant(c) for c in x_rational.numerator.coefficients | ||
] | ||
x_den = [ | ||
circuit.set_or_get_constant(c) for c in x_rational.denominator.coefficients | ||
] | ||
|
||
y_num = [ | ||
circuit.set_or_get_constant(c) for c in y_rational.numerator.coefficients | ||
] | ||
y_den = [ | ||
circuit.set_or_get_constant(c) for c in y_rational.denominator.coefficients | ||
] | ||
|
||
x_affine_num = circuit.eval_horner(x_num, px, "x_num") | ||
x_affine_den = circuit.eval_horner(x_den, px, "x_den") | ||
x_affine = circuit.div(x_affine_num, x_affine_den) | ||
y_affine_num = circuit.eval_horner(y_num, px, "y_num") | ||
y_affine_den = circuit.eval_horner(y_den, px, "y_den") | ||
y_affine_eval = circuit.div(y_affine_num, y_affine_den) | ||
y_affine = circuit.mul(y_affine_eval, py) | ||
circuit.extend_struct_output( | ||
structs.G1PointCircuit(name="res", elmts=[x_affine, y_affine]) | ||
) | ||
|
||
return circuit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
134 changes: 134 additions & 0 deletions
134
hydra/garaga/starknet/tests_and_calldata_generators/map_to_curve.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from dataclasses import dataclass | ||
|
||
import garaga.modulo_circuit_structs as structs | ||
from garaga.algebra import PyFelt | ||
from garaga.definitions import CURVES, CurveID, G1Point, get_base_field | ||
from garaga.hints.io import bigint_split, int_to_u384 | ||
from garaga.signature import apply_isogeny, hash_to_field | ||
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder | ||
|
||
|
||
@dataclass(slots=True) | ||
class MapToCurveHint: | ||
gx1_is_square: bool | ||
y1: PyFelt | ||
y_flag: bool | ||
|
||
def to_cairo(self) -> str: | ||
return f"MapToCurveHint {{ gx1_is_square: {str(self.gx1_is_square).lower()}, y1: {int_to_u384(self.y1.value, as_hex=True)}, y_flag: {str(self.y_flag).lower()} }}" | ||
|
||
def to_calldata(self) -> list[int]: | ||
cd = [] | ||
cd.append(int(self.gx1_is_square)) | ||
cd.extend(bigint_split(self.y1.value)) | ||
cd.append(int(self.y_flag)) | ||
return cd | ||
|
||
|
||
@dataclass(slots=True) | ||
class HashToCurveHint: | ||
f0_hint: MapToCurveHint | ||
f1_hint: MapToCurveHint | ||
scalar_mul_hint: structs.Struct | ||
derive_point_from_x_hint: structs.Struct | ||
|
||
def to_cairo(self) -> str: | ||
return f"HashToCurveHint {{ f0_hint: {self.f0_hint.to_cairo()}, f1_hint: {self.f1_hint.to_cairo()}, scalar_mul_hint: {self.scalar_mul_hint.serialize(raw=True)}, derive_point_from_x_hint: {self.derive_point_from_x_hint.serialize(raw=True)} }}" | ||
|
||
def to_calldata(self) -> list[int]: | ||
cd = [] | ||
cd.extend(self.f0_hint.to_calldata()) | ||
cd.extend(self.f1_hint.to_calldata()) | ||
cd.extend(self.scalar_mul_hint.serialize_to_calldata()) | ||
cd.extend(self.derive_point_from_x_hint.serialize_to_calldata()) | ||
return cd | ||
|
||
|
||
def build_map_to_curve_hint(u: PyFelt) -> tuple[G1Point, MapToCurveHint]: | ||
field = get_base_field(CurveID.BLS12_381) | ||
a = field(CURVES[CurveID.BLS12_381.value].swu_params.A) | ||
b = field(CURVES[CurveID.BLS12_381.value].swu_params.B) | ||
z = field(CURVES[CurveID.BLS12_381.value].swu_params.Z) | ||
|
||
zeta_u2 = z * u**2 | ||
ta = zeta_u2**2 + zeta_u2 | ||
num_x1 = b * (ta + field.one()) | ||
|
||
if ta.value == 0: | ||
div = a * z | ||
else: | ||
div = a * -ta | ||
|
||
num2_x1 = num_x1**2 | ||
div2 = div**2 | ||
div3 = div2 * div | ||
assert div3.value != 0 | ||
|
||
num_gx1 = (num2_x1 + a * div2) * num_x1 + b * div3 | ||
num_x2 = zeta_u2 * num_x1 | ||
|
||
gx1 = num_gx1 / div3 | ||
gx1_square = gx1.is_quad_residue() | ||
if gx1_square: | ||
print(f"square res") | ||
y1 = gx1.sqrt(min_root=False) | ||
assert y1 * y1 == gx1 | ||
else: | ||
print(f"not square res") | ||
y1 = (z * gx1).sqrt(min_root=False) | ||
assert y1 * y1 == z * gx1 | ||
|
||
y2 = zeta_u2 * u * y1 | ||
y = y1 if gx1_square else y2 | ||
y_flag = y.value % 2 == u.value % 2 | ||
|
||
num_x = num_x1 if gx1_square else num_x2 | ||
x_affine = num_x / div | ||
y_affine = -y if y.value % 2 != u.value % 2 else y | ||
|
||
point_on_curve = G1Point( | ||
x_affine.value, y_affine.value, CurveID.BLS12_381, iso_point=True | ||
) | ||
return point_on_curve, MapToCurveHint( | ||
gx1_is_square=gx1_square, y1=y1, y_flag=y_flag | ||
) | ||
|
||
|
||
def build_hash_to_curve_hint(message: bytes) -> HashToCurveHint: | ||
felt0, felt1 = hash_to_field(message, 2, CurveID.BLS12_381.value, "sha256") | ||
pt0, f0_hint = build_map_to_curve_hint(felt0) | ||
pt1, f1_hint = build_map_to_curve_hint(felt1) | ||
sum_pt = pt0.add(pt1) | ||
print( | ||
f"sum_pt: {int_to_u384(sum_pt.x, as_hex=False)} {int_to_u384(sum_pt.y, as_hex=False)}" | ||
) | ||
sum_pt = apply_isogeny(sum_pt) | ||
print( | ||
f"sum_pt: {int_to_u384(sum_pt.x, as_hex=False)} {int_to_u384(sum_pt.y, as_hex=False)}" | ||
) | ||
x = CURVES[CurveID.BLS12_381.value].x | ||
n = CURVES[CurveID.BLS12_381.value].n | ||
cofactor = (1 - (x % n)) % n | ||
print(f"cofactor: {cofactor}, hex :{hex(cofactor)}") | ||
|
||
msm_builder = MSMCalldataBuilder( | ||
curve_id=CurveID.BLS12_381, points=[sum_pt], scalars=[cofactor] | ||
) | ||
msm_hint, derive_point_from_x_hint = msm_builder.build_msm_hints(risc0_mode=True) | ||
|
||
return HashToCurveHint( | ||
f0_hint=f0_hint, | ||
f1_hint=f1_hint, | ||
scalar_mul_hint=msm_hint, | ||
derive_point_from_x_hint=derive_point_from_x_hint, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
field = get_base_field(CurveID.BLS12_381) | ||
|
||
import hashlib | ||
|
||
hint = build_hash_to_curve_hint(hashlib.sha256(b"Hello, World!").digest()) | ||
print(hint.to_cairo()) | ||
# print(hint.to_calldata()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ mod ec; | |
mod dummy; | ||
mod multi_pairing_check; | ||
mod extf_mul; | ||
mod isogeny; |
Oops, something went wrong.