diff --git a/src/client.ts b/src/client.ts index 2e9ee72..fcaed6b 100644 --- a/src/client.ts +++ b/src/client.ts @@ -3,34 +3,114 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Blind, Blinded, Evaluation, Oprf } from './oprf.js' -import { Group, Scalar } from './group.js' +import { + Blind, + Blinded, + Evaluation, + EvaluationRequest, + FinalizeData, + ModeID, + Oprf, + SuiteID +} from './oprf.js' +import { Elt, Group, Scalar, SerializedElt } from './group.js' + +class baseClient extends Oprf { + constructor(mode: ModeID, suite: SuiteID) { + super(mode, suite) + } -export class OPRFClient extends Oprf { async randomBlinder(): Promise<{ scalar: Scalar; blind: Blind }> { - const scalar = await this.params.gg.randomScalar() - const blind = new Blind(this.params.gg.serializeScalar(scalar)) + const scalar = await this.gg.randomScalar() + const blind = new Blind(this.gg.serializeScalar(scalar)) return { scalar, blind } } - async blind(input: Uint8Array): Promise<{ blind: Blind; blindedElement: Blinded }> { + async blind(input: Uint8Array): Promise<[FinalizeData, EvaluationRequest]> { const { scalar, blind } = await this.randomBlinder() - const dst = Oprf.getHashToGroupDST(this.params.id) - const P = await this.params.gg.hashToGroup(input, dst) - if (this.params.gg.isIdentity(P)) { + const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST)) + if (this.gg.isIdentity(P)) { throw new Error('InvalidInputError') } const Q = Group.mul(scalar, P) - const blindedElement = new Blinded(this.params.gg.serialize(Q)) - return { blind, blindedElement } + const evalReq = new EvaluationRequest(new Blinded(this.gg.serialize(Q))) + const finData = new FinalizeData(input, blind, evalReq) + return [finData, evalReq] } - finalize(input: Uint8Array, blind: Blind, evaluation: Evaluation): Promise { - const blindScalar = this.params.gg.deserializeScalar(blind) - const blindScalarInv = this.params.gg.invScalar(blindScalar) - const Z = this.params.gg.deserialize(evaluation) + doFinalize( + finData: FinalizeData, + evaluation: Evaluation, + info = new Uint8Array(0) + ): Promise { + const blindScalar = this.gg.deserializeScalar(finData.blind) + const blindScalarInv = this.gg.invScalar(blindScalar) + const Z = this.gg.deserialize(evaluation.element) const N = Group.mul(blindScalarInv, Z) - const unblinded = this.params.gg.serialize(N) - return this.coreFinalize(input, unblinded) + const unblinded = this.gg.serialize(N) + return this.coreFinalize(finData.input, unblinded, info) + } +} + +export class OPRFClient extends baseClient { + constructor(suite: SuiteID) { + super(Oprf.Mode.OPRF, suite) + } + finalize(finData: FinalizeData, evaluation: Evaluation): Promise { + return super.doFinalize(finData, evaluation) + } +} + +export class VOPRFClient extends baseClient { + constructor(suite: SuiteID, private readonly pubKeyServer: Uint8Array) { + super(Oprf.Mode.VOPRF, suite) + } + + finalize(finData: FinalizeData, evaluation: Evaluation): Promise { + if (!evaluation.proof) { + throw new Error('no proof provided') + } + const pkS = this.gg.deserialize(new SerializedElt(this.pubKeyServer)) + const Q = this.gg.deserialize(finData.evalReq.blinded) + const kQ = this.gg.deserialize(evaluation.element) + if (!evaluation.proof.verify([this.gg.generator(), pkS], [Q, kQ])) { + throw new Error('proof failed') + } + + return super.doFinalize(finData, evaluation) + } +} + +export class POPRFClient extends baseClient { + constructor(suite: SuiteID, private readonly pubKeyServer: Uint8Array) { + super(Oprf.Mode.POPRF, suite) + } + + private async pointFromInfo(info: Uint8Array): Promise { + const m = await this.scalarFromInfo(info) + const T = this.gg.mulBase(m) + const pkS = this.gg.deserialize(new SerializedElt(this.pubKeyServer)) + const tw = Group.add(T, pkS) + if (tw.isIdentity) { + throw new Error('invalid info') + } + return tw + } + + async finalize( + finData: FinalizeData, + evaluation: Evaluation, + info = new Uint8Array(0) + ): Promise { + if (!evaluation.proof) { + throw new Error('no proof provided') + } + const tw = await this.pointFromInfo(info) + const Q = this.gg.deserialize(evaluation.element) + const kQ = this.gg.deserialize(finData.evalReq.blinded) + if (!evaluation.proof.verify([this.gg.generator(), tw], [Q, kQ])) { + throw new Error('proof failed') + } + return super.doFinalize(finData, evaluation, info) } } diff --git a/src/group.ts b/src/group.ts index 8a6a043..b3cc7e7 100644 --- a/src/group.ts +++ b/src/group.ts @@ -3,7 +3,7 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { hashParams, joinAll, xor } from './util.js' +import { joinAll, xor } from './util.js' import sjcl from './sjcl/index.js' @@ -20,6 +20,24 @@ export type Scalar = sjcl.bn export type Curve = sjcl.ecc.curve export type FieldElt = sjcl.bn +function hashParams(hash: string): { + outLenBytes: number // returns the size in bytes of the output. + blockLenBytes: number // returns the size of the internal block. +} { + switch (hash) { + case 'SHA-1': + return { outLenBytes: 20, blockLenBytes: 64 } + case 'SHA-256': + return { outLenBytes: 32, blockLenBytes: 64 } + case 'SHA-384': + return { outLenBytes: 48, blockLenBytes: 128 } + case 'SHA-512': + return { outLenBytes: 64, blockLenBytes: 128 } + default: + throw new Error(`invalid hash name: ${hash}`) + } +} + async function expandXMD( hash: string, msg: Uint8Array, diff --git a/src/keys.ts b/src/keys.ts index 4399a3d..53c5a6d 100644 --- a/src/keys.ts +++ b/src/keys.ts @@ -3,18 +3,18 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Oprf, OprfID } from './oprf.js' -import { SerializedElt, SerializedScalar } from './group.js' +import { ModeID, Oprf, SuiteID } from './oprf.js' +import { Scalar, SerializedElt, SerializedScalar } from './group.js' import { joinAll, to16bits } from './util.js' -export function getKeySizes(id: OprfID): { Nsk: number; Npk: number } { - const { gg } = Oprf.params(id) +export function getKeySizes(id: SuiteID): { Nsk: number; Npk: number } { + const gg = Oprf.getGroup(id) return { Nsk: gg.size, Npk: 1 + gg.size } } -export function validatePrivateKey(id: OprfID, privateKey: Uint8Array): boolean { +export function validatePrivateKey(id: SuiteID, privateKey: Uint8Array): boolean { try { - const { gg } = Oprf.params(id) + const gg = Oprf.getGroup(id) const s = gg.deserializeScalar(new SerializedScalar(privateKey)) return !s.equals(0) } catch (_) { @@ -22,9 +22,9 @@ export function validatePrivateKey(id: OprfID, privateKey: Uint8Array): boolean } } -export function validatePublicKey(id: OprfID, publicKey: Uint8Array): boolean { +export function validatePublicKey(id: SuiteID, publicKey: Uint8Array): boolean { try { - const { gg } = Oprf.params(id) + const gg = Oprf.getGroup(id) const P = gg.deserialize(new SerializedElt(publicKey)) return !P.isIdentity } catch (_) { @@ -32,43 +32,44 @@ export function validatePublicKey(id: OprfID, publicKey: Uint8Array): boolean { } } -export async function randomPrivateKey(id: OprfID): Promise { - const { gg } = Oprf.params(id) +export async function randomPrivateKey(id: SuiteID): Promise { + const gg = Oprf.getGroup(id) const priv = await gg.randomScalar() return new Uint8Array(gg.serializeScalar(priv)) } export async function derivePrivateKey( - id: OprfID, + mode: ModeID, + id: SuiteID, seed: Uint8Array, info: Uint8Array ): Promise { - const { gg } = Oprf.params(id) + const gg = Oprf.getGroup(id) const deriveInput = joinAll([seed, to16bits(info.length), info]) let counter = 0 - let priv + let priv: Scalar do { if (counter > 255) { throw new Error('DeriveKeyPairError') } const hashInput = joinAll([deriveInput, Uint8Array.from([counter])]) - priv = await gg.hashToScalar(hashInput, Oprf.getDeriveKeyPairDST(id)) + priv = await gg.hashToScalar(hashInput, Oprf.getDST(mode, id, Oprf.LABELS.DeriveKeyPairDST)) counter++ } while (gg.isScalarZero(priv)) return new Uint8Array(gg.serializeScalar(priv)) } -export function generatePublicKey(id: OprfID, privateKey: Uint8Array): Uint8Array { - const { gg } = Oprf.params(id) +export function generatePublicKey(id: SuiteID, privateKey: Uint8Array): Uint8Array { + const gg = Oprf.getGroup(id) const priv = gg.deserializeScalar(new SerializedScalar(privateKey)) const pub = gg.mulBase(priv) return new Uint8Array(gg.serialize(pub)) } export async function generateKeyPair( - id: OprfID + id: SuiteID ): Promise<{ privateKey: Uint8Array; publicKey: Uint8Array }> { const privateKey = await randomPrivateKey(id) const publicKey = generatePublicKey(id, privateKey) @@ -76,11 +77,12 @@ export async function generateKeyPair( } export async function deriveKeyPair( - id: OprfID, + mode: ModeID, + id: SuiteID, seed: Uint8Array, info: Uint8Array ): Promise<{ privateKey: Uint8Array; publicKey: Uint8Array }> { - const privateKey = await derivePrivateKey(id, seed, info) + const privateKey = await derivePrivateKey(mode, id, seed, info) const publicKey = generatePublicKey(id, privateKey) return { privateKey, publicKey } } diff --git a/src/oprf.ts b/src/oprf.ts index 62299af..7501589 100644 --- a/src/oprf.ts +++ b/src/oprf.ts @@ -3,114 +3,150 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Group, GroupID, SerializedElt, SerializedScalar } from './group.js' +import { Group, GroupID, Scalar, SerializedElt, SerializedScalar } from './group.js' import { joinAll, to16bits } from './util.js' -export class Blind extends SerializedScalar { - readonly _BlindBrand = '' -} -export class Blinded extends SerializedElt { - readonly _BlindedBrand = '' -} +import { DLEQProof } from './dleq.js' -export class Evaluation extends SerializedElt { - readonly _EvaluationBrand = '' -} - -export enum OprfID { // eslint-disable-line no-shadow - OPRF_P256_SHA256 = 3, - OPRF_P384_SHA384 = 4, - OPRF_P521_SHA512 = 5 -} +export type ModeID = typeof Oprf.Mode[keyof typeof Oprf.Mode] +export type SuiteID = typeof Oprf.Suite[keyof typeof Oprf.Suite] -export interface OprfParams { - readonly id: OprfID - readonly gg: Group - readonly hash: string - readonly blindedSize: number - readonly evaluationSize: number - readonly blindSize: number +function assertNever(name: string, x: never): never { + throw new Error(`unexpected ${name} identifier: ${x}`) } export abstract class Oprf { - static readonly mode = 0 - - static readonly version = 'VOPRF09-' - - readonly params: OprfParams - - constructor(id: OprfID) { - this.params = Oprf.params(id) - } - - static validateID(id: OprfID): boolean { - switch (id) { - case OprfID.OPRF_P256_SHA256: - case OprfID.OPRF_P384_SHA384: - case OprfID.OPRF_P521_SHA512: - return true + static Mode = { + OPRF: 0, + VOPRF: 1, + POPRF: 2 + } as const + + static Suite = { + P256_SHA256: 3, + P384_SHA384: 4, + P521_SHA512: 5 + } as const + + static LABELS = { + Version: 'VOPRF09-', + FinalizeDST: 'Finalize', + HashToGroupDST: 'HashToGroup-', + HashToScalarDST: 'HashToScalar-', + DeriveKeyPairDST: 'DeriveKeyPair', + InfoLabel: 'Info' + } as const + + private static validateMode(m: ModeID): ModeID { + switch (m) { + case Oprf.Mode.OPRF: + case Oprf.Mode.VOPRF: + case Oprf.Mode.POPRF: + return m default: - throw new Error(`not supported ID: ${id}`) + assertNever('Oprf.Mode', m) } } - - static params(id: OprfID): OprfParams { - Oprf.validateID(id) - let gid = GroupID.P256 - let hash = 'SHA-256' + private static getParams(id: SuiteID): [SuiteID, GroupID, string, number] { switch (id) { - case OprfID.OPRF_P256_SHA256: - break - case OprfID.OPRF_P384_SHA384: - gid = GroupID.P384 - hash = 'SHA-384' - break - case OprfID.OPRF_P521_SHA512: - gid = GroupID.P521 - hash = 'SHA-512' - break + case Oprf.Suite.P256_SHA256: + return [id, GroupID.P256, 'SHA-256', 32] + case Oprf.Suite.P384_SHA384: + return [id, GroupID.P384, 'SHA-384', 48] + case Oprf.Suite.P521_SHA512: + return [id, GroupID.P521, 'SHA-512', 64] default: - throw new Error(`not supported ID: ${id}`) - } - const gg = new Group(gid) - return { - id, - gg, - hash, - blindedSize: 1 + gg.size, - evaluationSize: 1 + gg.size, - blindSize: gg.size + assertNever('Oprf.Suite', id) } } - - static getContextString(id: OprfID): Uint8Array { - Oprf.validateID(id) - return joinAll([new TextEncoder().encode(Oprf.version), new Uint8Array([Oprf.mode, 0, id])]) + static getGroup(suite: SuiteID): Group { + return new Group(Oprf.getParams(suite)[1]) } - - static getHashToGroupDST(id: OprfID): Uint8Array { - return joinAll([new TextEncoder().encode('HashToGroup-'), Oprf.getContextString(id)]) + static getHash(suite: SuiteID): string { + return Oprf.getParams(suite)[2] + } + static getOprfSize(suite: SuiteID): number { + return Oprf.getParams(suite)[3] + } + static getDST(mode: ModeID, suite: SuiteID, name: string): Uint8Array { + const m = Oprf.validateMode(mode) + const s = Oprf.getParams(suite)[0] + return joinAll([ + new TextEncoder().encode(name + Oprf.LABELS.Version), + new Uint8Array([m, 0, s]) + ]) } - static getHashToScalarDST(id: OprfID): Uint8Array { - return joinAll([new TextEncoder().encode('HashToScalar-'), Oprf.getContextString(id)]) + readonly mode: ModeID + readonly ID: SuiteID + readonly gg: Group + readonly hash: string + + constructor(mode: ModeID, suite: SuiteID) { + const [ID, gid, hash] = Oprf.getParams(suite) + this.ID = ID + this.gg = new Group(gid) + this.hash = hash + this.mode = Oprf.validateMode(mode) } - static getDeriveKeyPairDST(id: OprfID): Uint8Array { - return joinAll([new TextEncoder().encode('DeriveKeyPair'), Oprf.getContextString(id)]) + protected getDST(name: string): Uint8Array { + return Oprf.getDST(this.mode, this.ID, name) } protected async coreFinalize( input: Uint8Array, - unblindedElement: Uint8Array + element: Uint8Array, + info: Uint8Array ): Promise { + let hasInfo: Uint8Array[] = [] + if (this.mode === Oprf.Mode.POPRF) { + hasInfo = [to16bits(info.length), info] + } + const hashInput = joinAll([ to16bits(input.length), input, - to16bits(unblindedElement.length), - unblindedElement, - new TextEncoder().encode('Finalize') + ...hasInfo, + to16bits(element.length), + element, + new TextEncoder().encode(Oprf.LABELS.FinalizeDST) ]) - return new Uint8Array(await crypto.subtle.digest(this.params.hash, hashInput)) + return new Uint8Array(await crypto.subtle.digest(this.hash, hashInput)) } + + protected scalarFromInfo(info: Uint8Array): Promise { + if (info.length >= 1 << 16) { + throw new Error('invalid info length') + } + const te = new TextEncoder() + const framedInfo = joinAll([te.encode('Info'), to16bits(info.length), info]) + return this.gg.hashToScalar(framedInfo, this.getDST(Oprf.LABELS.HashToScalarDST)) + } +} + +export class Blind extends SerializedScalar { + readonly _BlindBrand = '' +} +export class Blinded extends SerializedElt { + readonly _BlindedBrand = '' +} +export class Evaluated extends SerializedElt { + readonly _EvaluatedBrand = '' +} + +export class Evaluation { + constructor(public readonly element: Evaluated, public readonly proof?: DLEQProof) {} +} + +export class EvaluationRequest { + constructor(public readonly blinded: Blinded) {} +} + +export class FinalizeData { + constructor( + public readonly input: Uint8Array, + public readonly blind: Blind, + public readonly evalReq: EvaluationRequest + ) {} } diff --git a/src/server.ts b/src/server.ts index be7ef34..f548696 100644 --- a/src/server.ts +++ b/src/server.ts @@ -3,68 +3,148 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Blinded, Evaluation, Oprf, OprfID } from './oprf.js' -import { Group, SerializedScalar } from './group.js' +import { Blinded, Evaluated, Evaluation, EvaluationRequest, ModeID, Oprf, SuiteID } from './oprf.js' +import { Group, Scalar, SerializedScalar } from './group.js' +import { DLEQProver } from './dleq.js' import { ctEqual } from './util.js' -export class OPRFServer extends Oprf { - private privateKey: Uint8Array +class baseServer extends Oprf { + protected privateKey: Uint8Array public supportsWebCryptoOPRF = false - constructor(id: OprfID, privateKey: Uint8Array) { - super(id) + constructor(mode: ModeID, suite: SuiteID, privateKey: Uint8Array) { + super(mode, suite) this.privateKey = privateKey } - evaluate(blindedElement: Blinded): Promise { + protected doEvaluation(bl: Blinded, key: Uint8Array): Promise { if (this.supportsWebCryptoOPRF) { - return this.evaluateWebCrypto(blindedElement) + return this.evaluateWebCrypto(bl, key) } - return Promise.resolve(this.evaluateSJCL(blindedElement)) + return Promise.resolve(this.evaluateSJCL(bl, key)) } - private async evaluateWebCrypto(blindedElement: Blinded): Promise { - const key = await crypto.subtle.importKey( + private async evaluateWebCrypto(bl: Blinded, key: Uint8Array): Promise { + const crKey = await crypto.subtle.importKey( 'raw', - this.privateKey, + key, { name: 'OPRF', - namedCurve: this.params.gg.id + namedCurve: this.gg.id }, true, ['sign'] ) // webcrypto accepts only compressed points. - let compressed = Uint8Array.from(blindedElement) - if (blindedElement[0] === 0x04) { - const P = this.params.gg.deserialize(blindedElement) - compressed = Uint8Array.from(this.params.gg.serialize(P, true)) + let compressed = Uint8Array.from(bl) + if (bl[0] === 0x04) { + const P = this.gg.deserialize(bl) + compressed = Uint8Array.from(this.gg.serialize(P, true)) } - const evaluation = await crypto.subtle.sign('OPRF', key, compressed) - return new Evaluation(evaluation) + return new Evaluated(await crypto.subtle.sign('OPRF', crKey, compressed)) } - private evaluateSJCL(blindedElement: Blinded): Evaluation { - const P = this.params.gg.deserialize(blindedElement) - const serSk = new SerializedScalar(this.privateKey) - const sk = this.params.gg.deserializeScalar(serSk) + private evaluateSJCL(bl: Blinded, key: Uint8Array): Evaluated { + const P = this.gg.deserialize(bl) + const sk = this.gg.deserializeScalar(new SerializedScalar(key)) const Z = Group.mul(sk, P) - return new Evaluation(this.params.gg.serialize(Z)) + return new Evaluated(this.gg.serialize(Z)) } + protected async secretFromInfo(info: Uint8Array): Promise<[Scalar, Scalar]> { + const m = await this.scalarFromInfo(info) + const skS = this.gg.deserializeScalar(new SerializedScalar(this.privateKey)) + const t = this.gg.addScalar(m, skS) + if (this.gg.isScalarZero(t)) { + throw new Error('inverse of zero') + } + const tInv = this.gg.invScalar(t) + return [t, tInv] + } + + protected async doFullEvaluate( + input: Uint8Array, + info = new Uint8Array(0) + ): Promise { + let secret = this.privateKey + if (this.mode === Oprf.Mode.POPRF) { + const [, evalSecret] = await this.secretFromInfo(info) + secret = this.gg.serializeScalar(evalSecret) + } + + const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST)) + if (this.gg.isIdentity(P)) { + throw new Error('InvalidInputError') + } + const blinded = new Blinded(this.gg.serialize(P)) + const evaluated = await this.doEvaluation(blinded, secret) + return this.coreFinalize(input, evaluated, info) + } +} + +export class OPRFServer extends baseServer { + constructor(suite: SuiteID, privateKey: Uint8Array) { + super(Oprf.Mode.OPRF, suite, privateKey) + } + + async evaluate(req: EvaluationRequest): Promise { + return new Evaluation(await this.doEvaluation(req.blinded, this.privateKey)) + } async fullEvaluate(input: Uint8Array): Promise { - const dst = Oprf.getHashToGroupDST(this.params.id) - const T = await this.params.gg.hashToGroup(input, dst) - const issuedElement = new Blinded(this.params.gg.serialize(T)) - const evaluation = await this.evaluate(issuedElement) - const digest = await this.coreFinalize(input, evaluation) - return digest + return this.doFullEvaluate(input) } + async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { + return ctEqual(output, await this.doFullEvaluate(input)) + } +} +export class VOPRFServer extends baseServer { + constructor(suite: SuiteID, privateKey: Uint8Array) { + super(Oprf.Mode.VOPRF, suite, privateKey) + } + async evaluate(req: EvaluationRequest): Promise { + const e = await this.doEvaluation(req.blinded, this.privateKey) + const prover = new DLEQProver({ gg: this.gg, hash: this.hash, dst: '' }) + const skS = this.gg.deserializeScalar(new SerializedScalar(this.privateKey)) + const pkS = this.gg.mulBase(skS) + const Q = this.gg.deserialize(req.blinded) + const kQ = this.gg.deserialize(e) + const proof = await prover.prove(skS, [this.gg.generator(), pkS], [Q, kQ]) + return new Evaluation(e, proof) + } + async fullEvaluate(input: Uint8Array): Promise { + return this.doFullEvaluate(input) + } async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { - const digest = await this.fullEvaluate(input) - return ctEqual(output, digest) + return ctEqual(output, await this.doFullEvaluate(input)) + } +} + +export class POPRFServer extends baseServer { + constructor(suite: SuiteID, privateKey: Uint8Array) { + super(Oprf.Mode.POPRF, suite, privateKey) + } + async evaluate(req: EvaluationRequest, info = new Uint8Array(0)): Promise { + const [keyProof, evalSecret] = await this.secretFromInfo(info) + const secret = this.gg.serializeScalar(evalSecret) + const e = await this.doEvaluation(req.blinded, secret) + const prover = new DLEQProver({ gg: this.gg, hash: this.hash, dst: '' }) + const kG = this.gg.mulBase(keyProof) + const Q = this.gg.deserialize(e) + const kQ = this.gg.deserialize(req.blinded) + const proof = await prover.prove(keyProof, [this.gg.generator(), kG], [Q, kQ]) + return new Evaluation(e, proof) + } + async fullEvaluate(input: Uint8Array, info = new Uint8Array(0)): Promise { + return this.doFullEvaluate(input, info) + } + async verifyFinalize( + input: Uint8Array, + output: Uint8Array, + info = new Uint8Array(0) + ): Promise { + return ctEqual(output, await this.doFullEvaluate(input, info)) } } diff --git a/src/util.ts b/src/util.ts index 10dbe5f..1483a4c 100644 --- a/src/util.ts +++ b/src/util.ts @@ -46,21 +46,3 @@ export function to16bits(n: number): Uint8Array { } return new Uint8Array([(n >> 8) & 0xff, n & 0xff]) } - -export function hashParams(hash: string): { - outLenBytes: number // returns the size in bytes of the output. - blockLenBytes: number // returns the size of the internal block. -} { - switch (hash) { - case 'SHA-1': - return { outLenBytes: 20, blockLenBytes: 64 } - case 'SHA-256': - return { outLenBytes: 32, blockLenBytes: 64 } - case 'SHA-384': - return { outLenBytes: 48, blockLenBytes: 128 } - case 'SHA-512': - return { outLenBytes: 64, blockLenBytes: 128 } - default: - throw new Error(`invalid hash name: ${hash}`) - } -} diff --git a/test/keys.test.ts b/test/keys.test.ts index 2189172..e89dc4f 100644 --- a/test/keys.test.ts +++ b/test/keys.test.ts @@ -5,7 +5,6 @@ import { Oprf, - OprfID, deriveKeyPair, generateKeyPair, getKeySizes, @@ -13,68 +12,65 @@ import { validatePublicKey } from '../src/index.js' -describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P521_SHA512])( - 'oprf-keys', - (id: OprfID) => { - describe(`${OprfID[id as number]}`, () => { - const { Nsk, Npk } = getKeySizes(id) - const { gg } = Oprf.params(id) +describe.each(Object.entries(Oprf.Suite))('oprf-keys', (name, id) => { + describe(`${name}`, () => { + const { Nsk, Npk } = getKeySizes(id) + const gg = Oprf.getGroup(id) - it('getKeySizes', () => { - expect(Nsk).toBe(Npk - 1) - }) + it('getKeySizes', () => { + expect(Nsk).toBe(Npk - 1) + }) - it('zeroPrivateKey', () => { - const zeroKeyBytes = new Uint8Array(Nsk) - const ret = validatePrivateKey(id, zeroKeyBytes) - expect(ret).toBe(false) - }) + it('zeroPrivateKey', () => { + const zeroKeyBytes = new Uint8Array(Nsk) + const ret = validatePrivateKey(id, zeroKeyBytes) + expect(ret).toBe(false) + }) - it('orderPrivateKey', () => { - const orderPk = gg.serializeScalar(gg.order()) - const ret = validatePrivateKey(id, orderPk) - expect(ret).toBe(false) - }) + it('orderPrivateKey', () => { + const orderPk = gg.serializeScalar(gg.order()) + const ret = validatePrivateKey(id, orderPk) + expect(ret).toBe(false) + }) - it('onesPrivateKey', () => { - const onesKeyBytes = new Uint8Array(Nsk).fill(0xff) - const ret = validatePrivateKey(id, onesKeyBytes) - expect(ret).toBe(false) - }) + it('onesPrivateKey', () => { + const onesKeyBytes = new Uint8Array(Nsk).fill(0xff) + const ret = validatePrivateKey(id, onesKeyBytes) + expect(ret).toBe(false) + }) - it('identityPublicKey', () => { - const identityKeyBytes = gg.serialize(gg.identity()) - const ret = validatePublicKey(id, identityKeyBytes) - expect(ret).toBe(false) - }) + it('identityPublicKey', () => { + const identityKeyBytes = gg.serialize(gg.identity()) + const ret = validatePublicKey(id, identityKeyBytes) + expect(ret).toBe(false) + }) - it('onesPublicKey', () => { - const onesKeyBytes = new Uint8Array(Npk).fill(0xff) - const ret = validatePublicKey(id, onesKeyBytes) - expect(ret).toBe(false) - }) + it('onesPublicKey', () => { + const onesKeyBytes = new Uint8Array(Npk).fill(0xff) + const ret = validatePublicKey(id, onesKeyBytes) + expect(ret).toBe(false) + }) - it('generateKeyPair', async () => { - for (let i = 0; i < 64; i++) { - const keys = await generateKeyPair(id) // eslint-disable-line no-await-in-loop - const sk = validatePrivateKey(id, keys.privateKey) - const pk = validatePublicKey(id, keys.publicKey) - expect(sk).toBe(true) - expect(pk).toBe(true) - } - }) + it('generateKeyPair', async () => { + for (let i = 0; i < 64; i++) { + const keys = await generateKeyPair(id) // eslint-disable-line no-await-in-loop + const sk = validatePrivateKey(id, keys.privateKey) + const pk = validatePublicKey(id, keys.publicKey) + expect(sk).toBe(true) + expect(pk).toBe(true) + } + }) - it('deriveKeyPair', async () => { - const info = new TextEncoder().encode('info used for derivation') - for (let i = 0; i < 64; i++) { - const seed = crypto.getRandomValues(new Uint8Array(Nsk)) - const keys = await deriveKeyPair(id, seed, info) // eslint-disable-line no-await-in-loop - const sk = validatePrivateKey(id, keys.privateKey) - const pk = validatePublicKey(id, keys.publicKey) - expect(sk).toBe(true) - expect(pk).toBe(true) - } - }) + it('deriveKeyPair', async () => { + const info = new TextEncoder().encode('info used for derivation') + for (let i = 0; i < 64; i++) { + const seed = crypto.getRandomValues(new Uint8Array(Nsk)) + const keys = await deriveKeyPair(Oprf.Mode.OPRF, id, seed, info) // eslint-disable-line no-await-in-loop + const sk = validatePrivateKey(id, keys.privateKey) + const pk = validatePublicKey(id, keys.publicKey) + expect(sk).toBe(true) + expect(pk).toBe(true) + } }) - } -) + }) +}) diff --git a/test/oprf.test.ts b/test/oprf.test.ts index fcb0f4f..1a5dd41 100644 --- a/test/oprf.test.ts +++ b/test/oprf.test.ts @@ -3,42 +3,63 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { OPRFClient, OPRFServer, Oprf, OprfID, randomPrivateKey } from '../src/index.js' +import { + OPRFClient, + OPRFServer, + Oprf, + POPRFClient, + POPRFServer, + VOPRFClient, + VOPRFServer, + generatePublicKey, + randomPrivateKey +} from '../src/index.js' -import { hashParams } from '../src/util.js' +describe.each(Object.entries(Oprf.Mode))('protocol', (modeName, mode) => { + describe.each(Object.entries(Oprf.Suite))(`${modeName}`, (suiteName, id) => { + let server: OPRFServer | VOPRFServer | POPRFServer + let client: OPRFClient | VOPRFClient | POPRFClient -describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P521_SHA512])( - 'oprf-workflow', - (id: OprfID) => { - it(`${OprfID[id as number]}`, async () => { - const te = new TextEncoder() - // ///////////////// - // Setup Server - // ///////////////// + beforeAll(async () => { const privateKey = await randomPrivateKey(id) - const server = new OPRFServer(id, privateKey) - // ///////////////// - // Setup Client - // ///////////////// - const client = new OPRFClient(id) - const input = te.encode('This is the client input') + const publicKey = generatePublicKey(id, privateKey) + switch (mode) { + case Oprf.Mode.OPRF: + server = new OPRFServer(id, privateKey) + client = new OPRFClient(id) + break + + case Oprf.Mode.VOPRF: + server = new VOPRFServer(id, privateKey) + client = new VOPRFClient(id, publicKey) + break + case Oprf.Mode.POPRF: + server = new POPRFServer(id, privateKey) + client = new POPRFClient(id, publicKey) + break + } + }) + + it(`${suiteName}`, async () => { + // Client Server + // ==================================================== // Client - const { blind, blindedElement } = await client.blind(input) - // Client Server - // blindedElement + // blind, blindedElement = Blind(input) + const input = new TextEncoder().encode('This is the client input') + const [finData, evalReq] = await client.blind(input) + // evalReq // ------------------>> - - // Server - const evaluatedElement = await server.evaluate(blindedElement) - // Client Server - // evaluatedElement + // Server + // evaluation = Evaluate(evalReq, info*) + const evaluation = await server.evaluate(evalReq) + // evaluation // <<------------------ - + // // Client - const output = await client.finalize(input, blind, evaluatedElement) - const { outLenBytes } = hashParams(Oprf.params(id).hash) - - expect(output).toHaveLength(outLenBytes) + // output = Finalize(finData, evaluation, info*) + // + const output = await client.finalize(finData, evaluation) + expect(output).toHaveLength(Oprf.getOprfSize(id)) const serverOutput = await server.fullEvaluate(input) expect(output).toStrictEqual(serverOutput) @@ -46,5 +67,5 @@ describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P52 const success = await server.verifyFinalize(input, output) expect(success).toBe(true) }) - } -) + }) +}) diff --git a/test/server.test.ts b/test/server.test.ts index 752cff3..8049907 100644 --- a/test/server.test.ts +++ b/test/server.test.ts @@ -4,12 +4,10 @@ // at https://opensource.org/licenses/BSD-3-Clause import { - Blinded, Group, OPRFClient, OPRFServer, Oprf, - OprfID, SerializedElt, SerializedScalar, randomPrivateKey @@ -53,35 +51,26 @@ function mockSign(...x: Parameters): ReturnType { throw new Error('bad algorithm') } -describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P521_SHA512])( - 'supportsWebCrypto', - (id: OprfID) => { - beforeAll(() => { - jest.spyOn(crypto.subtle, 'importKey').mockImplementation(mockImportKey) - jest.spyOn(crypto.subtle, 'sign').mockImplementation(mockSign) - }) +describe.each(Object.entries(Oprf.Suite))('supportsWebCrypto', (name, id) => { + beforeAll(() => { + jest.spyOn(crypto.subtle, 'importKey').mockImplementation(mockImportKey) + jest.spyOn(crypto.subtle, 'sign').mockImplementation(mockSign) + }) - it(`${OprfID[id as number]}`, async () => { - const te = new TextEncoder() - const privateKey = await randomPrivateKey(id) - const server = new OPRFServer(id, privateKey) - const client = new OPRFClient(id) - const input = te.encode('This is the client input') - const req = await client.blind(input) - const { gg } = Oprf.params(id) - const bt = gg.deserialize(req.blindedElement) + it(`${name}`, async () => { + const te = new TextEncoder() + const privateKey = await randomPrivateKey(id) + const server = new OPRFServer(id, privateKey) + const client = new OPRFClient(id) + const input = te.encode('This is the client input') + const [, reqEval] = await client.blind(input) - for (const compressed of [true, false]) { - server.supportsWebCryptoOPRF = false - let blinded = new Blinded(gg.serialize(bt, compressed)) - const ev0 = await server.evaluate(blinded) // eslint-disable-line no-await-in-loop + server.supportsWebCryptoOPRF = false + const ev0 = await server.evaluate(reqEval) - server.supportsWebCryptoOPRF = true - blinded = new Blinded(gg.serialize(bt, compressed)) - const ev1 = await server.evaluate(blinded) // eslint-disable-line no-await-in-loop + server.supportsWebCryptoOPRF = true + const ev1 = await server.evaluate(reqEval) - expect(ev0).toEqual(ev1) - } - }) - } -) + expect(ev0).toEqual(ev1) + }) +}) diff --git a/test/vectors.test.ts b/test/vectors.test.ts index f860f01..0eaea82 100644 --- a/test/vectors.test.ts +++ b/test/vectors.test.ts @@ -3,7 +3,20 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Blind, OPRFClient, OPRFServer, Oprf, OprfID, derivePrivateKey } from '../src/index.js' +import { + Blind, + ModeID, + OPRFClient, + OPRFServer, + Oprf, + POPRFClient, + POPRFServer, + SuiteID, + VOPRFClient, + VOPRFServer, + derivePrivateKey, + generatePublicKey +} from '../src/index.js' import allVectors from './testdata/allVectors_v09.json' import { jest } from '@jest/globals' @@ -16,48 +29,109 @@ function toHex(x: Uint8Array): string { return Buffer.from(x).toString('hex') } +class wrapPOPRFServer extends POPRFServer { + info!: Uint8Array + + evaluate(...r: Parameters): ReturnType { + return super.evaluate(r[0], this.info) + } + async fullEvaluate(input: Uint8Array): Promise { + return super.fullEvaluate(input, this.info) + } + async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { + return super.verifyFinalize(input, output, this.info) + } +} + +class wrapPOPRFClient extends POPRFClient { + info!: Uint8Array + + blind(input: Uint8Array): ReturnType { + return super.blind(input) + } + finalize(...r: Parameters): ReturnType { + return super.finalize(...r, this.info) + } +} + // Test vectors from https://datatracker.ietf.org/doc/draft-irtf-cfrg-voprf -// https://tools.ietf.org/html/draft-irtf-cfrg-voprf-06 +// https://tools.ietf.org/html/draft-irtf-cfrg-voprf-09 describe.each(allVectors)('test-vectors', (testVector: typeof allVectors[number]) => { - const oprfID = testVector.suiteID - if (testVector.mode === Oprf.mode && oprfID in OprfID) { - describe(`${testVector.suiteName}/Mode${testVector.mode}`, () => { - it('keygen', async () => { + const mode = testVector.mode as ModeID + const id = testVector.suiteID as SuiteID + + if (Object.values(Oprf.Suite).includes(id)) { + const txtMode = Object.entries(Oprf.Mode)[mode as number][0] + const txtSuite = Object.entries(Oprf.Suite)[Object.values(Oprf.Suite).indexOf(id)][0] + + describe(`${txtMode}, ${txtSuite}`, () => { + let skSm: Uint8Array + let server: OPRFServer | VOPRFServer | wrapPOPRFServer + let client: OPRFClient | VOPRFClient | wrapPOPRFClient + + beforeAll(async () => { const seed = fromHex(testVector.seed) - const info = fromHex(testVector.keyInfo) - const skSm = await derivePrivateKey(oprfID, seed, info) + const keyInfo = fromHex(testVector.keyInfo) + skSm = await derivePrivateKey(mode, id, seed, keyInfo) + const pkSm = generatePublicKey(id, skSm) + switch (mode) { + case Oprf.Mode.OPRF: + server = new OPRFServer(id, skSm) + client = new OPRFClient(id) + break + + case Oprf.Mode.VOPRF: + server = new VOPRFServer(id, skSm) + client = new VOPRFClient(id, pkSm) + break + + case Oprf.Mode.POPRF: + server = new wrapPOPRFServer(id, skSm) + client = new wrapPOPRFClient(id, pkSm) + break + } + }) + + it('keygen', () => { expect(toHex(skSm)).toBe(testVector.skSm) }) - const server = new OPRFServer(oprfID, fromHex(testVector.skSm)) - const client = new OPRFClient(oprfID) const { vectors } = testVector - server.supportsWebCryptoOPRF = false + describe.each(vectors)('vec$#', (vi: typeof vectors[number]) => { + if (vi.Batch === 1) { + it('protocol', async () => { + // Creates a mock for randomBlinder method to + // inject the blind value given by the test vector. + for (const c of [OPRFClient, VOPRFClient, wrapPOPRFClient]) { + jest.spyOn(c.prototype, 'randomBlinder').mockImplementation(() => { + const blind = new Blind(fromHex(vi.Blind)) + const scalar = Oprf.getGroup(id).deserializeScalar(blind) + return Promise.resolve({ scalar, blind }) + }) + } - it.each(vectors)('vec$#', async (vi: typeof vectors[number]) => { - // Creates a mock for OPRFClient.randomBlinder method to - // inject the blind value given by the test vector. - jest.spyOn(OPRFClient.prototype, 'randomBlinder').mockImplementationOnce(() => { - const blind = new Blind(fromHex(vi.Blind)) - const { gg } = Oprf.params(oprfID) - const scalar = gg.deserializeScalar(blind) - return Promise.resolve({ scalar, blind }) - }) + if (testVector.mode === Oprf.Mode.POPRF) { + const info = fromHex((vi as any).Info as string) // eslint-disable-line @typescript-eslint/no-explicit-any + ;(server as wrapPOPRFServer).info = info + ;(client as wrapPOPRFClient).info = info + } - const input = fromHex(vi.Input) - const { blind, blindedElement } = await client.blind(input) - expect(toHex(blind)).toEqual(vi.Blind) - expect(toHex(blindedElement)).toEqual(vi.BlindedElement) + const input = fromHex(vi.Input) + const [finData, evalReq] = await client.blind(input) + expect(toHex(finData.blind)).toEqual(vi.Blind) + expect(toHex(evalReq.blinded)).toEqual(vi.BlindedElement) - const evaluation = await server.evaluate(blindedElement) - expect(toHex(evaluation)).toEqual(vi.EvaluationElement) + const evaluation = await server.evaluate(evalReq) + expect(toHex(evaluation.element)).toEqual(vi.EvaluationElement) - const output = await client.finalize(input, blind, evaluation) - expect(toHex(output)).toEqual(vi.Output) + const output = await client.finalize(finData, evaluation) + expect(toHex(output)).toEqual(vi.Output) - const serverCheckOutput = await server.verifyFinalize(input, output) - expect(serverCheckOutput).toBe(true) + const serverCheckOutput = await server.verifyFinalize(input, output) + expect(serverCheckOutput).toBe(true) + }) + } }) }) }