Skip to content

Commit

Permalink
test: add tests for op and computation, and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mhchia committed Jan 24, 2024
1 parent ad3666c commit 9cc9c62
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 143 deletions.
2 changes: 1 addition & 1 deletion examples/computation/computation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@
" out_1 = state.median(x_0)\n",
" return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))\n",
"\n",
"prover_model = create_model(computation)\n",
"_, prover_model = create_model(computation)\n",
"prover_gen_settings([data_path], comb_data_path, prover_model, prover_model_path, \"default\", \"resources\", settings_path)\n"
]
},
Expand Down
Empty file added tests/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
import torch


@pytest.fixture
def error() -> float:
return 0.01


@pytest.fixture
def column_0():
return torch.tensor([3.0, 4.5, 1.0, 2.0, 7.5, 6.4, 5.5])


@pytest.fixture
def column_1():
return torch.tensor([2.7, 3.3, 1.1, 2.2, 3.8, 8.2, 4.4])

57 changes: 57 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import json
from typing import Type
from pathlib import Path

import torch

from zkstats.core import prover_gen_settings, verifier_setup, prover_gen_proof, verifier_verify
from zkstats.computation import IModel, IsResultPrecise


def compute(basepath: Path, data: list[torch.Tensor], model: Type[IModel]) -> IsResultPrecise:
comb_data_path = basepath / "comb_data.json"
model_path = basepath / "model.onnx"
settings_path = basepath / "settings.json"
witness_path = basepath / "witness.json"
compiled_model_path = basepath / "model.compiled"
proof_path = basepath / "model.proof"
pk_path = basepath / "model.pk"
vk_path = basepath / "model.vk"
data_paths = [basepath / f"data_{i}.json" for i in range(len(data))]

for i, d in enumerate(data):
filename = data_paths[i]
data_json = {"input_data": [d.tolist()]}
with open(filename, "w") as f:
f.write(json.dumps(data_json))

prover_gen_settings(
data_path_array=[str(i) for i in data_paths],
comb_data_path=str(comb_data_path),
prover_model=model,
prover_model_path=str(model_path),
scale="default",
mode="resources",
settings_path=str(settings_path),
)
verifier_setup(
str(model_path),
str(compiled_model_path),
str(settings_path),
str(vk_path),
str(pk_path),
)
prover_gen_proof(
str(model_path),
str(comb_data_path),
str(witness_path),
str(compiled_model_path),
str(settings_path),
str(proof_path),
str(pk_path),
)
verifier_verify(
str(proof_path),
str(settings_path),
str(vk_path),
)
31 changes: 31 additions & 0 deletions tests/test_computation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import statistics
import torch
import torch

from zkstats.computation import State, create_model
from zkstats.ops import Mean, Median

from .helpers import compute


def computation(state: State, x: list[torch.Tensor]):
out_0 = state.median(x[0])
out_1 = state.median(x[1])
return state.mean(torch.tensor([out_0, out_1]).reshape(1,-1,1))


def test_computation(tmp_path, column_0: torch.Tensor, column_1: torch.Tensor, error: float):
state, model = create_model(computation, error)
compute(tmp_path, [column_0, column_1], model)
assert state.current_op_index == 3

ops = state.ops
op0 = ops[0]
assert isinstance(op0, Median)
assert op0.result == statistics.median(column_0)
op1 = ops[1]
assert isinstance(op1, Median)
assert op1.result == statistics.median(column_1)
op2 = ops[2]
assert isinstance(op2, Mean)
assert op2.result == statistics.mean([op0.result.tolist(), op1.result.tolist()])
36 changes: 36 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
from typing import Type, Callable
from dataclasses import dataclass
from pathlib import Path
import statistics

import pytest

import torch
from zkstats.computation import Operation, Mean, Median, IModel, IsResultPrecise

from .helpers import compute


@pytest.mark.parametrize(
"op_type, expected_func",
[
(Mean, statistics.mean),
(Median, statistics.median),
]
)
def test_1d(tmp_path, column_0: torch.Tensor, error: float, op_type: Type[Operation], expected_func: Callable[[list[float]], float]):
op = op_type.create(column_0, error)
expected_res = expected_func(column_0.tolist())
assert expected_res == op.result
model = op_to_model(op)
compute(tmp_path, [column_0], model)


def op_to_model(op: Operation) -> Type[IModel]:
class Model(IModel):
def forward(self, x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return op.ezkl(x), op.result
return Model


35 changes: 32 additions & 3 deletions zkstats/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os
import sys
from typing import Type
import importlib.util

import click
import torch

from .core import prover_gen_proof, prover_setup, load_model, verifier_verify, gen_data_commitment
from .core import prover_gen_proof, prover_gen_settings, verifier_setup, verifier_verify, gen_data_commitment

cwd = os.getcwd()
# TODO: Should make this configurable
Expand Down Expand Up @@ -29,15 +34,19 @@ def cli():
def prove(model_path: str, data_path: str):
model = load_model(model_path)
print("Loaded model:", model)
prover_setup(
prover_gen_settings(
[data_path],
comb_data_path,
model,
model_onnx_path,
compiled_model_path,
"default",
"resources",
settings_path,
)
verifier_setup(
model_path,
compiled_model_path,
settings_path,
vk_path,
pk_path,
)
Expand Down Expand Up @@ -80,6 +89,26 @@ def main():
cli()


def load_model(module_path: str) -> Type[torch.nn.Module]:
"""
Load a model from a Python module.
"""
# FIXME: This is unsafe since malicious code can be executed

model_name = "Model"
module_name = os.path.splitext(os.path.basename(module_path))[0]
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)

try:
cls = getattr(module, model_name)
except AttributeError:
raise ImportError(f"class {model_name} does not exist in {module_name}")
return cls


# Register commands
cli.add_command(prove)
cli.add_command(verify)
Expand Down
4 changes: 2 additions & 2 deletions zkstats/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor
TComputation = Callable[[State, list[torch.Tensor]], tuple[IsResultPrecise, torch.Tensor]]


def create_model(computation: TComputation, error: float = DEFAULT_ERROR) -> Type[IModel]:
def create_model(computation: TComputation, error: float = DEFAULT_ERROR) -> tuple[State, Type[IModel]]:
"""
Create a torch model from a `computation` function defined by user
"""
Expand All @@ -118,4 +118,4 @@ def preprocess(self, x: list[torch.Tensor]) -> None:
def forward(self, *x: list[torch.Tensor]) -> tuple[IsResultPrecise, torch.Tensor]:
return computation(state, x)

return Model
return state, Model
Loading

0 comments on commit 9cc9c62

Please sign in to comment.