Skip to content

Commit

Permalink
Merge pull request #311 from o1-labs/feature/transpose-party-leaf-types
Browse files Browse the repository at this point in the history
Transpose party leaf types
  • Loading branch information
mitschabaude authored Aug 1, 2022
2 parents d180ddb + 66d60d3 commit cce3fee
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 367 deletions.
2 changes: 1 addition & 1 deletion src/examples/party-witness.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ let fields = Types.Party.toFields(party);
let aux = Types.Party.toAuxiliary(party);

let partyRaw = Types.Party.fromFields(fields, aux);
let json = Types.Party.toJson(partyRaw);
let json = Types.Party.toJSON(partyRaw);

if (address.toBase58() !== json.body.publicKey) throw Error('fail');

Expand Down
2 changes: 1 addition & 1 deletion src/examples/to-hash-input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function testInput<T>(
toInputOcaml: (json: string) => InputOcaml,
value: T
) {
let json = Module.toJson(value);
let json = Module.toJSON(value);
// console.log(json);
let input1 = inputFromOcaml(toInputOcaml(JSON.stringify(json)));
let input2 = Module.toInput(value);
Expand Down
10 changes: 4 additions & 6 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
export {
Field,
Bool,
Group,
Scalar,
AsFieldElements,
Ledger,
isReady,
shutdown,
} from './snarky';
export { Field, Bool } from './lib/core';
export type { VerificationKey, Keypair } from './snarky';
export * from './snarky/addons';
export { Poseidon } from './lib/hash';
Expand Down Expand Up @@ -58,12 +57,11 @@ export { Character, CircuitString } from './lib/string';
// experimental APIs
import { Reducer } from './lib/zkapp';
import { createChildParty } from './lib/party';
import { memoizeWitness } from './lib/circuit_value';
import {
jsLayout,
asFieldsAndAux,
memoizeWitness,
AsFieldsAndAux as AsFieldsAndAux_,
} from './snarky/types';
} from './lib/circuit_value';
import { jsLayout, asFieldsAndAux } from './snarky/types';
import { packToFields } from './lib/hash';
export { Experimental };

Expand Down
117 changes: 111 additions & 6 deletions src/lib/circuit_value.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import 'reflect-metadata';
import { Circuit, Field, Bool, JSONValue, AsFieldElements } from '../snarky';
import { Circuit, JSONValue, AsFieldElements } from '../snarky';
import { Field, Bool } from './core';
import { Context } from './global-context';
import { HashInput } from './hash';
import { snarkContext } from './proof_system';

// external API
Expand All @@ -17,6 +19,8 @@ export {

// internal API
export {
AsFieldsExtended,
AsFieldsAndAux,
cloneCircuitValue,
circuitValueEquals,
circuitArray,
Expand Down Expand Up @@ -57,11 +61,10 @@ abstract class CircuitValue {
v: InstanceType<T>
): Field[] {
const res: Field[] = [];
const fields = (this as any).prototype._fields;
const fields = this.prototype._fields;
if (fields === undefined || fields === null) {
return res;
}

for (let i = 0, n = fields.length; i < n; ++i) {
const [key, propType] = fields[i];
const subElts: Field[] = propType.toFields((v as any)[key]);
Expand All @@ -70,6 +73,28 @@ abstract class CircuitValue {
return res;
}

static toInput<T extends AnyConstructor>(
this: T,
v: InstanceType<T>
): HashInput {
let input: HashInput = { fields: [], packed: [] };
let fields = this.prototype._fields;
if (fields === undefined) return input;
for (let i = 0, n = fields.length; i < n; ++i) {
let [key, type] = fields[i];
if ('toInput' in type) {
HashInput.append(input, type.toInput(v[key]));
continue;
}
// as a fallback, use toFields on the type
// TODO: this is problematic -- ignores if there's a toInput on a nested type
// so, remove this? should every circuit value define toInput?
let xs: Field[] = type.toFields(v[key]);
input.fields!.push(...xs);
}
return input;
}

toFields(): Field[] {
return (this.constructor as any).toFields(this);
}
Expand Down Expand Up @@ -116,7 +141,6 @@ abstract class CircuitValue {
if (fields === undefined || fields === null) {
return;
}

for (let i = 0; i < fields.length; ++i) {
const [key, propType] = fields[i];
const value = (v as any)[key];
Expand Down Expand Up @@ -339,13 +363,18 @@ function circuitMain(
let primitives = new Set(['Field', 'Bool', 'Scalar', 'Group']);
let complexTypes = new Set(['object', 'function']);

type AsFieldsExtended<T> = AsFieldElements<T> & {
toInput: (x: T) => { fields?: Field[]; packed?: [Field, number][] };
toJSON: (x: T) => JSONValue;
};

// TODO properly type this at the interface
// create recursive type that describes JSON-like structures of circuit types
// TODO unit-test this
function circuitValue<T>(
typeObj: any,
options?: { customObjectKeys: string[] }
): AsFieldElements<T> {
): AsFieldsExtended<T> {
let objectKeys =
typeof typeObj === 'object' && typeObj !== null
? options?.customObjectKeys ?? Object.keys(typeObj).sort()
Expand All @@ -367,6 +396,30 @@ function circuitValue<T>(
if ('toFields' in typeObj) return typeObj.toFields(obj);
return objectKeys.map((k) => toFields(typeObj[k], obj[k])).flat();
}
function toInput(typeObj: any, obj: any): HashInput {
if (!complexTypes.has(typeof typeObj) || typeObj === null) return {};
if (Array.isArray(typeObj)) {
return typeObj
.map((t, i) => toInput(t, obj[i]))
.reduce(HashInput.append, {});
}
if ('toInput' in typeObj) return typeObj.toInput(obj) as HashInput;
if ('toFields' in typeObj) {
return { fields: typeObj.toFields(obj) };
}
return objectKeys
.map((k) => toInput(typeObj[k], obj[k]))
.reduce(HashInput.append, {});
}
function toJSON(typeObj: any, obj: any): JSONValue {
if (!complexTypes.has(typeof typeObj) || typeObj === null)
return obj ?? null;
if (Array.isArray(typeObj)) return typeObj.map((t, i) => toJSON(t, obj[i]));
if ('toJSON' in typeObj) return typeObj.toJSON(obj);
return Object.fromEntries(
objectKeys.map((k) => [k, toJSON(typeObj[k], obj[k])])
);
}
function ofFields(typeObj: any, fields: Field[]): any {
if (!complexTypes.has(typeof typeObj) || typeObj === null) return null;
if (Array.isArray(typeObj)) {
Expand All @@ -387,7 +440,7 @@ function circuitValue<T>(
return Object.fromEntries(objectKeys.map((k, i) => [k, values[i]]));
}
function check(typeObj: any, obj: any): void {
if (typeof typeObj !== 'object' || typeObj === null) return;
if (!complexTypes.has(typeof typeObj) || typeObj === null) return;
if (Array.isArray(typeObj))
return typeObj.forEach((t, i) => check(t, obj[i]));
if ('check' in typeObj) return typeObj.check(obj);
Expand All @@ -396,6 +449,8 @@ function circuitValue<T>(
return {
sizeInFields: () => sizeInFields(typeObj),
toFields: (obj: T) => toFields(typeObj, obj),
toInput: (obj: T) => toInput(typeObj, obj),
toJSON: (obj: T) => toJSON(typeObj, obj),
ofFields: (fields: Field[]) => ofFields(typeObj, fields) as T,
check: (obj: T) => check(typeObj, obj),
};
Expand Down Expand Up @@ -581,3 +636,53 @@ function getBlindingValue() {
}
return context.blindingValue;
}

// "complex" circuit values which have auxiliary data, and have to be hashed

type AsFieldsAndAux<T, TJson> = {
sizeInFields(): number;
toFields(value: T): Field[];
toAuxiliary(value?: T): any[];
fromFields(fields: Field[], aux: any[]): T;
toJSON(value: T): TJson;
check(value: T): void;
toInput(value: T): HashInput;
};

// convert from circuit values
function fromCircuitValue<T, A extends AsFieldsExtended<T>, TJson = JSONValue>(
type: A
): AsFieldsAndAux<T, TJson> {
return {
sizeInFields() {
return type.sizeInFields();
},
toFields(value) {
return type.toFields(value);
},
toAuxiliary(_) {
return [];
},
fromFields(fields) {
let myFields: Field[] = [];
let size = type.sizeInFields();
for (let i = 0; i < size; i++) {
myFields.push(fields.pop()!);
}
return type.ofFields(myFields);
},
check(value) {
type.check(value);
},
toInput(value) {
return type.toInput(value);
},
toJSON(value) {
return type.toJSON(value) as any;
},
};
}

const AsFieldsAndAux = {
fromCircuitValue,
};
11 changes: 11 additions & 0 deletions src/lib/core.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { Bool, Field } from '../snarky';

export { Field, Bool };

Field.toInput = function (x) {
return { fields: [x] };
};

Bool.toInput = function (x) {
return { packed: [[x.toField(), 1] as [Field, number]] };
};
20 changes: 16 additions & 4 deletions src/lib/hash.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AsFieldsAndAux } from '../snarky/parties-helpers';
import { AsFieldsAndAux } from './circuit_value';
import { Poseidon as Poseidon_, Field } from '../snarky';
import { inCheckedComputation } from './proof_system';

Expand All @@ -7,6 +7,7 @@ export { Poseidon };

// internal API
export {
HashInput,
prefixes,
emptyHashWithPrefix,
hashWithPrefix,
Expand Down Expand Up @@ -91,7 +92,7 @@ function prefixToField(prefix: string) {
* Convert the {fields, packed} hash input representation to a list of field elements
* Random_oracle_input.Chunked.pack_to_fields
*/
function packToFields({ fields = [], packed = [] }: Input) {
function packToFields({ fields = [], packed = [] }: HashInput) {
if (packed.length === 0) return fields;
let packedBits = [];
let currentPackedField = Field.zero;
Expand All @@ -112,7 +113,18 @@ function packToFields({ fields = [], packed = [] }: Input) {
return fields.concat(packedBits);
}

type Input = { fields?: Field[]; packed?: [Field, number][] };
type HashInput = { fields?: Field[]; packed?: [Field, number][] };
const HashInput = {
append(input1: HashInput, input2: HashInput) {
if (input2.fields !== undefined) {
(input1.fields ??= []).push(...input2.fields);
}
if (input2.packed !== undefined) {
(input1.packed ??= []).push(...input2.packed);
}
return input1;
},
};

type TokenSymbol = { symbol: string; field: Field };

Expand All @@ -135,7 +147,7 @@ const TokenSymbolPure: AsFieldsAndAux<TokenSymbol, string> = {
let actual = field.rangeCheckHelper(48);
actual.assertEquals(field);
},
toJson({ symbol }) {
toJSON({ symbol }) {
return symbol;
},
toInput({ field }) {
Expand Down
21 changes: 21 additions & 0 deletions src/lib/int.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Circuit, Field, Bool } from '../snarky';
import { CircuitValue, prop } from './circuit_value';
import { Types } from '../snarky/types';
import { HashInput } from './hash';

// external API
export { UInt32, UInt64, Int64, Sign };
Expand All @@ -25,6 +26,12 @@ class UInt64 extends CircuitValue {
let actual = x.value.rangeCheckHelper(64);
actual.assertEquals(x.value);
}
static toInput(x: UInt64): HashInput {
return { packed: [[x.value, 64]] };
}
static toJSON(x: UInt64) {
return x.value.toString();
}

private static checkConstant(x: Field) {
if (!x.isConstant()) return x;
Expand Down Expand Up @@ -204,6 +211,12 @@ class UInt32 extends CircuitValue {
let actual = x.value.rangeCheckHelper(32);
actual.assertEquals(x.value);
}
static toInput(x: UInt32): HashInput {
return { packed: [[x.value, 32]] };
}
static toJSON(x: UInt32) {
return x.value.toString();
}

private static checkConstant(x: Field) {
if (!x.isConstant()) return x;
Expand Down Expand Up @@ -349,6 +362,14 @@ class Sign extends CircuitValue {
// x^2 == 1 <=> x == 1 or x == -1
x.value.square().assertEquals(Field.one);
}
static toInput(x: Sign): HashInput {
return { packed: [[x.isPositive().toField(), 1]] };
}
static toJSON(x: Sign) {
if (x.toString() === '1') return 'Positive';
if (x.neg().toString() === '1') return 'Negative';
throw Error(`Invalid Sign: ${x}`);
}
neg() {
return new Sign(this.value.neg());
}
Expand Down
6 changes: 3 additions & 3 deletions src/lib/party.ts
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ class Party implements Types.Party {
}

toJSON() {
return Types.Party.toJson(this);
return Types.Party.toJSON(this);
}

hash() {
Expand All @@ -821,7 +821,7 @@ class Party implements Types.Party {
let input = Types.Party.toInput(this);
return hashWithPrefix(prefixes.body, packToFields(input));
} else {
let json = Types.Party.toJson(this);
let json = Types.Party.toJSON(this);
return Ledger.hashPartyFromJson(JSON.stringify(json));
}
}
Expand Down Expand Up @@ -1034,7 +1034,7 @@ type PartiesProved = {

function partiesToJson({ feePayer, otherParties, memo }: Parties) {
memo = Ledger.memoToBase58(memo);
return Types.Parties.toJson({ feePayer, otherParties, memo });
return Types.Parties.toJSON({ feePayer, otherParties, memo });
}

const Authorization = {
Expand Down
2 changes: 1 addition & 1 deletion src/lib/zkapp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ function wrapMethod(
// then we're already in a witness block, and shouldn't open another one
let { party, result } =
methodCallDepth === 0
? Party.witness(
? Party.witness<any>(
returnType ?? circuitValue<null>(null),
runCalledContract,
true
Expand Down
Loading

0 comments on commit cce3fee

Please sign in to comment.