From ad5c52e71db5d3ce5f19ca20c7c05d4976a0cd63 Mon Sep 17 00:00:00 2001 From: armfazh Date: Wed, 16 Mar 2022 20:29:26 -0700 Subject: [PATCH 1/4] Refactor to support draft v09. --- src/client.ts | 52 ++++++++---- src/group.ts | 20 ++++- src/keys.ts | 40 ++++----- src/oprf.ts | 190 +++++++++++++++++++++++++------------------ src/server.ts | 57 +++++++------ src/util.ts | 18 ---- test/keys.test.ts | 110 ++++++++++++------------- test/oprf.test.ts | 83 +++++++++---------- test/server.test.ts | 49 +++++------ test/vectors.test.ts | 41 ++++++---- 10 files changed, 357 insertions(+), 303 deletions(-) diff --git a/src/client.ts b/src/client.ts index 2e9ee72..c67650e 100644 --- a/src/client.ts +++ b/src/client.ts @@ -3,34 +3,54 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Blind, Blinded, Evaluation, Oprf } from './oprf.js' +import { + Blind, + Blinded, + Evaluation, + EvaluationRequest, + FinalizeData, + ModeID, + Oprf, + SuiteID +} from './oprf.js' import { Group, Scalar } from './group.js' -export class OPRFClient extends Oprf { +class baseClient extends Oprf { + constructor(mode: ModeID, suite: SuiteID) { + super(mode, suite) + } + async randomBlinder(): Promise<{ scalar: Scalar; blind: Blind }> { - const scalar = await this.params.gg.randomScalar() - const blind = new Blind(this.params.gg.serializeScalar(scalar)) + const scalar = await this.gg.randomScalar() + const blind = new Blind(this.gg.serializeScalar(scalar)) return { scalar, blind } } - async blind(input: Uint8Array): Promise<{ blind: Blind; blindedElement: Blinded }> { + async blind(input: Uint8Array): Promise<[FinalizeData, EvaluationRequest]> { const { scalar, blind } = await this.randomBlinder() - const dst = Oprf.getHashToGroupDST(this.params.id) - const P = await this.params.gg.hashToGroup(input, dst) - if (this.params.gg.isIdentity(P)) { + const dst = this.getDST(Oprf.LABELS.HashToGroupDST) + const P = await this.gg.hashToGroup(input, dst) + if (this.gg.isIdentity(P)) { throw new Error('InvalidInputError') } const Q = Group.mul(scalar, P) - const blindedElement = new Blinded(this.params.gg.serialize(Q)) - return { blind, blindedElement } + const evalReq = new EvaluationRequest(new Blinded(this.gg.serialize(Q))) + const finData = new FinalizeData(input, blind, evalReq) + return [finData, evalReq] } - finalize(input: Uint8Array, blind: Blind, evaluation: Evaluation): Promise { - const blindScalar = this.params.gg.deserializeScalar(blind) - const blindScalarInv = this.params.gg.invScalar(blindScalar) - const Z = this.params.gg.deserialize(evaluation) + finalize(finData: FinalizeData, evaluation: Evaluation): Promise { + const blindScalar = this.gg.deserializeScalar(finData.blind) + const blindScalarInv = this.gg.invScalar(blindScalar) + const Z = this.gg.deserialize(evaluation.element) const N = Group.mul(blindScalarInv, Z) - const unblinded = this.params.gg.serialize(N) - return this.coreFinalize(input, unblinded) + const unblinded = this.gg.serialize(N) + return this.coreFinalize(finData.input, unblinded) + } +} + +export class OPRFClient extends baseClient { + constructor(suite: SuiteID) { + super(Oprf.Mode.OPRF, suite) } } diff --git a/src/group.ts b/src/group.ts index 8a6a043..b3cc7e7 100644 --- a/src/group.ts +++ b/src/group.ts @@ -3,7 +3,7 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { hashParams, joinAll, xor } from './util.js' +import { joinAll, xor } from './util.js' import sjcl from './sjcl/index.js' @@ -20,6 +20,24 @@ export type Scalar = sjcl.bn export type Curve = sjcl.ecc.curve export type FieldElt = sjcl.bn +function hashParams(hash: string): { + outLenBytes: number // returns the size in bytes of the output. + blockLenBytes: number // returns the size of the internal block. +} { + switch (hash) { + case 'SHA-1': + return { outLenBytes: 20, blockLenBytes: 64 } + case 'SHA-256': + return { outLenBytes: 32, blockLenBytes: 64 } + case 'SHA-384': + return { outLenBytes: 48, blockLenBytes: 128 } + case 'SHA-512': + return { outLenBytes: 64, blockLenBytes: 128 } + default: + throw new Error(`invalid hash name: ${hash}`) + } +} + async function expandXMD( hash: string, msg: Uint8Array, diff --git a/src/keys.ts b/src/keys.ts index 4399a3d..53c5a6d 100644 --- a/src/keys.ts +++ b/src/keys.ts @@ -3,18 +3,18 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Oprf, OprfID } from './oprf.js' -import { SerializedElt, SerializedScalar } from './group.js' +import { ModeID, Oprf, SuiteID } from './oprf.js' +import { Scalar, SerializedElt, SerializedScalar } from './group.js' import { joinAll, to16bits } from './util.js' -export function getKeySizes(id: OprfID): { Nsk: number; Npk: number } { - const { gg } = Oprf.params(id) +export function getKeySizes(id: SuiteID): { Nsk: number; Npk: number } { + const gg = Oprf.getGroup(id) return { Nsk: gg.size, Npk: 1 + gg.size } } -export function validatePrivateKey(id: OprfID, privateKey: Uint8Array): boolean { +export function validatePrivateKey(id: SuiteID, privateKey: Uint8Array): boolean { try { - const { gg } = Oprf.params(id) + const gg = Oprf.getGroup(id) const s = gg.deserializeScalar(new SerializedScalar(privateKey)) return !s.equals(0) } catch (_) { @@ -22,9 +22,9 @@ export function validatePrivateKey(id: OprfID, privateKey: Uint8Array): boolean } } -export function validatePublicKey(id: OprfID, publicKey: Uint8Array): boolean { +export function validatePublicKey(id: SuiteID, publicKey: Uint8Array): boolean { try { - const { gg } = Oprf.params(id) + const gg = Oprf.getGroup(id) const P = gg.deserialize(new SerializedElt(publicKey)) return !P.isIdentity } catch (_) { @@ -32,43 +32,44 @@ export function validatePublicKey(id: OprfID, publicKey: Uint8Array): boolean { } } -export async function randomPrivateKey(id: OprfID): Promise { - const { gg } = Oprf.params(id) +export async function randomPrivateKey(id: SuiteID): Promise { + const gg = Oprf.getGroup(id) const priv = await gg.randomScalar() return new Uint8Array(gg.serializeScalar(priv)) } export async function derivePrivateKey( - id: OprfID, + mode: ModeID, + id: SuiteID, seed: Uint8Array, info: Uint8Array ): Promise { - const { gg } = Oprf.params(id) + const gg = Oprf.getGroup(id) const deriveInput = joinAll([seed, to16bits(info.length), info]) let counter = 0 - let priv + let priv: Scalar do { if (counter > 255) { throw new Error('DeriveKeyPairError') } const hashInput = joinAll([deriveInput, Uint8Array.from([counter])]) - priv = await gg.hashToScalar(hashInput, Oprf.getDeriveKeyPairDST(id)) + priv = await gg.hashToScalar(hashInput, Oprf.getDST(mode, id, Oprf.LABELS.DeriveKeyPairDST)) counter++ } while (gg.isScalarZero(priv)) return new Uint8Array(gg.serializeScalar(priv)) } -export function generatePublicKey(id: OprfID, privateKey: Uint8Array): Uint8Array { - const { gg } = Oprf.params(id) +export function generatePublicKey(id: SuiteID, privateKey: Uint8Array): Uint8Array { + const gg = Oprf.getGroup(id) const priv = gg.deserializeScalar(new SerializedScalar(privateKey)) const pub = gg.mulBase(priv) return new Uint8Array(gg.serialize(pub)) } export async function generateKeyPair( - id: OprfID + id: SuiteID ): Promise<{ privateKey: Uint8Array; publicKey: Uint8Array }> { const privateKey = await randomPrivateKey(id) const publicKey = generatePublicKey(id, privateKey) @@ -76,11 +77,12 @@ export async function generateKeyPair( } export async function deriveKeyPair( - id: OprfID, + mode: ModeID, + id: SuiteID, seed: Uint8Array, info: Uint8Array ): Promise<{ privateKey: Uint8Array; publicKey: Uint8Array }> { - const privateKey = await derivePrivateKey(id, seed, info) + const privateKey = await derivePrivateKey(mode, id, seed, info) const publicKey = generatePublicKey(id, privateKey) return { privateKey, publicKey } } diff --git a/src/oprf.ts b/src/oprf.ts index 62299af..ebc922d 100644 --- a/src/oprf.ts +++ b/src/oprf.ts @@ -6,111 +6,145 @@ import { Group, GroupID, SerializedElt, SerializedScalar } from './group.js' import { joinAll, to16bits } from './util.js' -export class Blind extends SerializedScalar { - readonly _BlindBrand = '' -} -export class Blinded extends SerializedElt { - readonly _BlindedBrand = '' -} +import { DLEQProof } from './dleq.js' -export class Evaluation extends SerializedElt { - readonly _EvaluationBrand = '' -} - -export enum OprfID { // eslint-disable-line no-shadow - OPRF_P256_SHA256 = 3, - OPRF_P384_SHA384 = 4, - OPRF_P521_SHA512 = 5 -} +export type ModeID = typeof Oprf.Mode[keyof typeof Oprf.Mode] +export type SuiteID = typeof Oprf.Suite[keyof typeof Oprf.Suite] -export interface OprfParams { - readonly id: OprfID - readonly gg: Group - readonly hash: string - readonly blindedSize: number - readonly evaluationSize: number - readonly blindSize: number +function assertNever(name: string, x: never): never { + throw new Error(`unexpected ${name} identifier: ${x}`) } export abstract class Oprf { - static readonly mode = 0 - - static readonly version = 'VOPRF09-' - - readonly params: OprfParams - - constructor(id: OprfID) { - this.params = Oprf.params(id) - } - - static validateID(id: OprfID): boolean { - switch (id) { - case OprfID.OPRF_P256_SHA256: - case OprfID.OPRF_P384_SHA384: - case OprfID.OPRF_P521_SHA512: - return true + static Mode = { + OPRF: 0, + VOPRF: 1, + POPRF: 2 + } as const + + static Suite = { + P256_SHA256: 3, + P384_SHA384: 4, + P521_SHA512: 5 + } as const + + static LABELS = { + Version: 'VOPRF09-', + FinalizeDST: 'Finalize', + HashToGroupDST: 'HashToGroup-', + HashToScalarDST: 'HashToScalar-', + DeriveKeyPairDST: 'DeriveKeyPair', + InfoLabel: 'Info' + } as const + + private static validateMode(m: ModeID): ModeID { + switch (m) { + case Oprf.Mode.OPRF: + case Oprf.Mode.VOPRF: + case Oprf.Mode.POPRF: + return m default: - throw new Error(`not supported ID: ${id}`) + assertNever('Oprf.Mode', m) } } - - static params(id: OprfID): OprfParams { - Oprf.validateID(id) - let gid = GroupID.P256 - let hash = 'SHA-256' + private static getParams(id: SuiteID): [SuiteID, GroupID, string, number] { switch (id) { - case OprfID.OPRF_P256_SHA256: - break - case OprfID.OPRF_P384_SHA384: - gid = GroupID.P384 - hash = 'SHA-384' - break - case OprfID.OPRF_P521_SHA512: - gid = GroupID.P521 - hash = 'SHA-512' - break + case Oprf.Suite.P256_SHA256: + return [id, GroupID.P256, 'SHA-256', 32] + case Oprf.Suite.P384_SHA384: + return [id, GroupID.P384, 'SHA-384', 48] + case Oprf.Suite.P521_SHA512: + return [id, GroupID.P521, 'SHA-512', 64] default: - throw new Error(`not supported ID: ${id}`) - } - const gg = new Group(gid) - return { - id, - gg, - hash, - blindedSize: 1 + gg.size, - evaluationSize: 1 + gg.size, - blindSize: gg.size + assertNever('Oprf.Suite', id) } } - - static getContextString(id: OprfID): Uint8Array { - Oprf.validateID(id) - return joinAll([new TextEncoder().encode(Oprf.version), new Uint8Array([Oprf.mode, 0, id])]) + static getGroup(suite: SuiteID): Group { + return new Group(Oprf.getParams(suite)[1]) } - - static getHashToGroupDST(id: OprfID): Uint8Array { - return joinAll([new TextEncoder().encode('HashToGroup-'), Oprf.getContextString(id)]) + static getHash(suite: SuiteID): string { + return Oprf.getParams(suite)[2] + } + static getOprfSize(suite: SuiteID): number { + return Oprf.getParams(suite)[3] } + static getDST(mode: ModeID, suite: SuiteID, name: string): Uint8Array { + const m = Oprf.validateMode(mode) + const s = Oprf.getParams(suite)[0] + return joinAll([ + new TextEncoder().encode(name + Oprf.LABELS.Version), + new Uint8Array([m, 0, s]) + ]) + } + + readonly mode: ModeID + readonly ID: SuiteID + readonly gg: Group + readonly hash: string - static getHashToScalarDST(id: OprfID): Uint8Array { - return joinAll([new TextEncoder().encode('HashToScalar-'), Oprf.getContextString(id)]) + constructor(mode: ModeID, suite: SuiteID) { + const [ID, gid, hash] = Oprf.getParams(suite) + this.ID = ID + this.gg = new Group(gid) + this.hash = hash + this.mode = Oprf.validateMode(mode) } - static getDeriveKeyPairDST(id: OprfID): Uint8Array { - return joinAll([new TextEncoder().encode('DeriveKeyPair'), Oprf.getContextString(id)]) + protected getDST(name: string): Uint8Array { + return Oprf.getDST(this.mode, this.ID, name) } protected async coreFinalize( input: Uint8Array, - unblindedElement: Uint8Array + unblindedElement: Uint8Array, + info?: Uint8Array ): Promise { + let hasInfo: Uint8Array[] = [] + if (this.mode === Oprf.Mode.POPRF) { + if (info) { + hasInfo = [to16bits(info.length), info] + } else { + hasInfo = [new Uint8Array(0)] + } + } + const hashInput = joinAll([ to16bits(input.length), input, + ...hasInfo, to16bits(unblindedElement.length), unblindedElement, - new TextEncoder().encode('Finalize') + new TextEncoder().encode(Oprf.LABELS.FinalizeDST) ]) - return new Uint8Array(await crypto.subtle.digest(this.params.hash, hashInput)) + return new Uint8Array(await crypto.subtle.digest(this.hash, hashInput)) } } + +export class Blind extends SerializedScalar { + readonly _BlindBrand = '' +} +export class Blinded extends SerializedElt { + readonly _BlindedBrand = '' +} +export class Evaluated extends SerializedElt { + readonly _EvaluatedBrand = '' +} + +export class Evaluation { + constructor( + public readonly element: Evaluated, + public readonly proof?: DLEQProof + ) { } +} + +export class EvaluationRequest { + constructor(public readonly blinded: Blinded) { } +} + +export class FinalizeData { + constructor( + public readonly input: Uint8Array, + public readonly blind: Blind, + public readonly evalReq: EvaluationRequest + ) { } +} diff --git a/src/server.ts b/src/server.ts index be7ef34..fa55eff 100644 --- a/src/server.ts +++ b/src/server.ts @@ -3,68 +3,75 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Blinded, Evaluation, Oprf, OprfID } from './oprf.js' +import { Blinded, Evaluated, Evaluation, EvaluationRequest, ModeID, Oprf, SuiteID } from './oprf.js' import { Group, SerializedScalar } from './group.js' import { ctEqual } from './util.js' -export class OPRFServer extends Oprf { +class baseServer extends Oprf { private privateKey: Uint8Array public supportsWebCryptoOPRF = false - constructor(id: OprfID, privateKey: Uint8Array) { - super(id) + constructor(mode: ModeID, suite: SuiteID, privateKey: Uint8Array) { + super(mode, suite) this.privateKey = privateKey } - evaluate(blindedElement: Blinded): Promise { + evaluate(req: EvaluationRequest): Promise { if (this.supportsWebCryptoOPRF) { - return this.evaluateWebCrypto(blindedElement) + return this.evaluateWebCrypto(req) } - return Promise.resolve(this.evaluateSJCL(blindedElement)) + return Promise.resolve(this.evaluateSJCL(req)) } - private async evaluateWebCrypto(blindedElement: Blinded): Promise { + private async evaluateWebCrypto(req: EvaluationRequest): Promise { const key = await crypto.subtle.importKey( 'raw', this.privateKey, { name: 'OPRF', - namedCurve: this.params.gg.id + namedCurve: this.gg.id }, true, ['sign'] ) // webcrypto accepts only compressed points. - let compressed = Uint8Array.from(blindedElement) - if (blindedElement[0] === 0x04) { - const P = this.params.gg.deserialize(blindedElement) - compressed = Uint8Array.from(this.params.gg.serialize(P, true)) + let compressed = Uint8Array.from(req.blinded) + if (req.blinded[0] === 0x04) { + const P = this.gg.deserialize(req.blinded) + compressed = Uint8Array.from(this.gg.serialize(P, true)) } const evaluation = await crypto.subtle.sign('OPRF', key, compressed) - return new Evaluation(evaluation) + return new Evaluation(new Evaluated(evaluation)) } - private evaluateSJCL(blindedElement: Blinded): Evaluation { - const P = this.params.gg.deserialize(blindedElement) + private evaluateSJCL(req: EvaluationRequest): Evaluation { + const P = this.gg.deserialize(req.blinded) const serSk = new SerializedScalar(this.privateKey) - const sk = this.params.gg.deserializeScalar(serSk) + const sk = this.gg.deserializeScalar(serSk) const Z = Group.mul(sk, P) - return new Evaluation(this.params.gg.serialize(Z)) + return new Evaluation(new Evaluated(this.gg.serialize(Z))) } async fullEvaluate(input: Uint8Array): Promise { - const dst = Oprf.getHashToGroupDST(this.params.id) - const T = await this.params.gg.hashToGroup(input, dst) - const issuedElement = new Blinded(this.params.gg.serialize(T)) + const dst = this.getDST(Oprf.LABELS.HashToGroupDST) + const P = await this.gg.hashToGroup(input, dst) + if (this.gg.isIdentity(P)) { + throw new Error('InvalidInputError') + } + const issuedElement = new EvaluationRequest(new Blinded(this.gg.serialize(P))) const evaluation = await this.evaluate(issuedElement) - const digest = await this.coreFinalize(input, evaluation) - return digest + return this.coreFinalize(input, evaluation.element) } async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { - const digest = await this.fullEvaluate(input) - return ctEqual(output, digest) + return ctEqual(output, await this.fullEvaluate(input)) + } +} + +export class OPRFServer extends baseServer { + constructor(suite: SuiteID, privateKey: Uint8Array) { + super(Oprf.Mode.OPRF, suite, privateKey) } } diff --git a/src/util.ts b/src/util.ts index 10dbe5f..1483a4c 100644 --- a/src/util.ts +++ b/src/util.ts @@ -46,21 +46,3 @@ export function to16bits(n: number): Uint8Array { } return new Uint8Array([(n >> 8) & 0xff, n & 0xff]) } - -export function hashParams(hash: string): { - outLenBytes: number // returns the size in bytes of the output. - blockLenBytes: number // returns the size of the internal block. -} { - switch (hash) { - case 'SHA-1': - return { outLenBytes: 20, blockLenBytes: 64 } - case 'SHA-256': - return { outLenBytes: 32, blockLenBytes: 64 } - case 'SHA-384': - return { outLenBytes: 48, blockLenBytes: 128 } - case 'SHA-512': - return { outLenBytes: 64, blockLenBytes: 128 } - default: - throw new Error(`invalid hash name: ${hash}`) - } -} diff --git a/test/keys.test.ts b/test/keys.test.ts index 2189172..e89dc4f 100644 --- a/test/keys.test.ts +++ b/test/keys.test.ts @@ -5,7 +5,6 @@ import { Oprf, - OprfID, deriveKeyPair, generateKeyPair, getKeySizes, @@ -13,68 +12,65 @@ import { validatePublicKey } from '../src/index.js' -describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P521_SHA512])( - 'oprf-keys', - (id: OprfID) => { - describe(`${OprfID[id as number]}`, () => { - const { Nsk, Npk } = getKeySizes(id) - const { gg } = Oprf.params(id) +describe.each(Object.entries(Oprf.Suite))('oprf-keys', (name, id) => { + describe(`${name}`, () => { + const { Nsk, Npk } = getKeySizes(id) + const gg = Oprf.getGroup(id) - it('getKeySizes', () => { - expect(Nsk).toBe(Npk - 1) - }) + it('getKeySizes', () => { + expect(Nsk).toBe(Npk - 1) + }) - it('zeroPrivateKey', () => { - const zeroKeyBytes = new Uint8Array(Nsk) - const ret = validatePrivateKey(id, zeroKeyBytes) - expect(ret).toBe(false) - }) + it('zeroPrivateKey', () => { + const zeroKeyBytes = new Uint8Array(Nsk) + const ret = validatePrivateKey(id, zeroKeyBytes) + expect(ret).toBe(false) + }) - it('orderPrivateKey', () => { - const orderPk = gg.serializeScalar(gg.order()) - const ret = validatePrivateKey(id, orderPk) - expect(ret).toBe(false) - }) + it('orderPrivateKey', () => { + const orderPk = gg.serializeScalar(gg.order()) + const ret = validatePrivateKey(id, orderPk) + expect(ret).toBe(false) + }) - it('onesPrivateKey', () => { - const onesKeyBytes = new Uint8Array(Nsk).fill(0xff) - const ret = validatePrivateKey(id, onesKeyBytes) - expect(ret).toBe(false) - }) + it('onesPrivateKey', () => { + const onesKeyBytes = new Uint8Array(Nsk).fill(0xff) + const ret = validatePrivateKey(id, onesKeyBytes) + expect(ret).toBe(false) + }) - it('identityPublicKey', () => { - const identityKeyBytes = gg.serialize(gg.identity()) - const ret = validatePublicKey(id, identityKeyBytes) - expect(ret).toBe(false) - }) + it('identityPublicKey', () => { + const identityKeyBytes = gg.serialize(gg.identity()) + const ret = validatePublicKey(id, identityKeyBytes) + expect(ret).toBe(false) + }) - it('onesPublicKey', () => { - const onesKeyBytes = new Uint8Array(Npk).fill(0xff) - const ret = validatePublicKey(id, onesKeyBytes) - expect(ret).toBe(false) - }) + it('onesPublicKey', () => { + const onesKeyBytes = new Uint8Array(Npk).fill(0xff) + const ret = validatePublicKey(id, onesKeyBytes) + expect(ret).toBe(false) + }) - it('generateKeyPair', async () => { - for (let i = 0; i < 64; i++) { - const keys = await generateKeyPair(id) // eslint-disable-line no-await-in-loop - const sk = validatePrivateKey(id, keys.privateKey) - const pk = validatePublicKey(id, keys.publicKey) - expect(sk).toBe(true) - expect(pk).toBe(true) - } - }) + it('generateKeyPair', async () => { + for (let i = 0; i < 64; i++) { + const keys = await generateKeyPair(id) // eslint-disable-line no-await-in-loop + const sk = validatePrivateKey(id, keys.privateKey) + const pk = validatePublicKey(id, keys.publicKey) + expect(sk).toBe(true) + expect(pk).toBe(true) + } + }) - it('deriveKeyPair', async () => { - const info = new TextEncoder().encode('info used for derivation') - for (let i = 0; i < 64; i++) { - const seed = crypto.getRandomValues(new Uint8Array(Nsk)) - const keys = await deriveKeyPair(id, seed, info) // eslint-disable-line no-await-in-loop - const sk = validatePrivateKey(id, keys.privateKey) - const pk = validatePublicKey(id, keys.publicKey) - expect(sk).toBe(true) - expect(pk).toBe(true) - } - }) + it('deriveKeyPair', async () => { + const info = new TextEncoder().encode('info used for derivation') + for (let i = 0; i < 64; i++) { + const seed = crypto.getRandomValues(new Uint8Array(Nsk)) + const keys = await deriveKeyPair(Oprf.Mode.OPRF, id, seed, info) // eslint-disable-line no-await-in-loop + const sk = validatePrivateKey(id, keys.privateKey) + const pk = validatePublicKey(id, keys.publicKey) + expect(sk).toBe(true) + expect(pk).toBe(true) + } }) - } -) + }) +}) diff --git a/test/oprf.test.ts b/test/oprf.test.ts index fcb0f4f..ae8f8f3 100644 --- a/test/oprf.test.ts +++ b/test/oprf.test.ts @@ -3,48 +3,41 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { OPRFClient, OPRFServer, Oprf, OprfID, randomPrivateKey } from '../src/index.js' - -import { hashParams } from '../src/util.js' - -describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P521_SHA512])( - 'oprf-workflow', - (id: OprfID) => { - it(`${OprfID[id as number]}`, async () => { - const te = new TextEncoder() - // ///////////////// - // Setup Server - // ///////////////// - const privateKey = await randomPrivateKey(id) - const server = new OPRFServer(id, privateKey) - // ///////////////// - // Setup Client - // ///////////////// - const client = new OPRFClient(id) - const input = te.encode('This is the client input') - // Client - const { blind, blindedElement } = await client.blind(input) - // Client Server - // blindedElement - // ------------------>> - - // Server - const evaluatedElement = await server.evaluate(blindedElement) - // Client Server - // evaluatedElement - // <<------------------ - - // Client - const output = await client.finalize(input, blind, evaluatedElement) - const { outLenBytes } = hashParams(Oprf.params(id).hash) - - expect(output).toHaveLength(outLenBytes) - - const serverOutput = await server.fullEvaluate(input) - expect(output).toStrictEqual(serverOutput) - - const success = await server.verifyFinalize(input, output) - expect(success).toBe(true) - }) - } -) +import { OPRFClient, OPRFServer, Oprf, randomPrivateKey } from '../src/index.js' + +describe.each(Object.entries(Oprf.Suite))('oprf-workflow', (name, id) => { + it(`${name}`, async () => { + const te = new TextEncoder() + // ///////////////// + // Setup Server + // ///////////////// + const privateKey = await randomPrivateKey(id) + const server = new OPRFServer(id, privateKey) + // ///////////////// + // Setup Client + // ///////////////// + const client = new OPRFClient(id) + const input = te.encode('This is the client input') + // Client + const [finData, evalReq] = await client.blind(input) + // Client Server + // evalReq + // ------------------>> + + // Server + const evaluation = await server.evaluate(evalReq) + // Client Server + // evaluation + // <<------------------ + + // Client + const output = await client.finalize(finData, evaluation) + expect(output).toHaveLength(Oprf.getOprfSize(id)) + + const serverOutput = await server.fullEvaluate(input) + expect(output).toStrictEqual(serverOutput) + + const success = await server.verifyFinalize(input, output) + expect(success).toBe(true) + }) +}) diff --git a/test/server.test.ts b/test/server.test.ts index 752cff3..8049907 100644 --- a/test/server.test.ts +++ b/test/server.test.ts @@ -4,12 +4,10 @@ // at https://opensource.org/licenses/BSD-3-Clause import { - Blinded, Group, OPRFClient, OPRFServer, Oprf, - OprfID, SerializedElt, SerializedScalar, randomPrivateKey @@ -53,35 +51,26 @@ function mockSign(...x: Parameters): ReturnType { throw new Error('bad algorithm') } -describe.each([OprfID.OPRF_P256_SHA256, OprfID.OPRF_P384_SHA384, OprfID.OPRF_P521_SHA512])( - 'supportsWebCrypto', - (id: OprfID) => { - beforeAll(() => { - jest.spyOn(crypto.subtle, 'importKey').mockImplementation(mockImportKey) - jest.spyOn(crypto.subtle, 'sign').mockImplementation(mockSign) - }) +describe.each(Object.entries(Oprf.Suite))('supportsWebCrypto', (name, id) => { + beforeAll(() => { + jest.spyOn(crypto.subtle, 'importKey').mockImplementation(mockImportKey) + jest.spyOn(crypto.subtle, 'sign').mockImplementation(mockSign) + }) - it(`${OprfID[id as number]}`, async () => { - const te = new TextEncoder() - const privateKey = await randomPrivateKey(id) - const server = new OPRFServer(id, privateKey) - const client = new OPRFClient(id) - const input = te.encode('This is the client input') - const req = await client.blind(input) - const { gg } = Oprf.params(id) - const bt = gg.deserialize(req.blindedElement) + it(`${name}`, async () => { + const te = new TextEncoder() + const privateKey = await randomPrivateKey(id) + const server = new OPRFServer(id, privateKey) + const client = new OPRFClient(id) + const input = te.encode('This is the client input') + const [, reqEval] = await client.blind(input) - for (const compressed of [true, false]) { - server.supportsWebCryptoOPRF = false - let blinded = new Blinded(gg.serialize(bt, compressed)) - const ev0 = await server.evaluate(blinded) // eslint-disable-line no-await-in-loop + server.supportsWebCryptoOPRF = false + const ev0 = await server.evaluate(reqEval) - server.supportsWebCryptoOPRF = true - blinded = new Blinded(gg.serialize(bt, compressed)) - const ev1 = await server.evaluate(blinded) // eslint-disable-line no-await-in-loop + server.supportsWebCryptoOPRF = true + const ev1 = await server.evaluate(reqEval) - expect(ev0).toEqual(ev1) - } - }) - } -) + expect(ev0).toEqual(ev1) + }) +}) diff --git a/test/vectors.test.ts b/test/vectors.test.ts index f860f01..7f17a7a 100644 --- a/test/vectors.test.ts +++ b/test/vectors.test.ts @@ -3,7 +3,15 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Blind, OPRFClient, OPRFServer, Oprf, OprfID, derivePrivateKey } from '../src/index.js' +import { + Blind, + ModeID, + OPRFClient, + OPRFServer, + Oprf, + SuiteID, + derivePrivateKey +} from '../src/index.js' import allVectors from './testdata/allVectors_v09.json' import { jest } from '@jest/globals' @@ -19,18 +27,23 @@ function toHex(x: Uint8Array): string { // Test vectors from https://datatracker.ietf.org/doc/draft-irtf-cfrg-voprf // https://tools.ietf.org/html/draft-irtf-cfrg-voprf-06 describe.each(allVectors)('test-vectors', (testVector: typeof allVectors[number]) => { - const oprfID = testVector.suiteID - if (testVector.mode === Oprf.mode && oprfID in OprfID) { - describe(`${testVector.suiteName}/Mode${testVector.mode}`, () => { + const mode = testVector.mode as ModeID + const id = testVector.suiteID as SuiteID + + if (mode === Oprf.Mode.OPRF && Object.values(Oprf.Suite).includes(id)) { + const txtMode = Object.entries(Oprf.Mode)[mode as number][0] + const txtSuite = Object.entries(Oprf.Suite)[Object.values(Oprf.Suite).indexOf(id)][0] + + describe(`${txtMode}, ${txtSuite}`, () => { it('keygen', async () => { const seed = fromHex(testVector.seed) const info = fromHex(testVector.keyInfo) - const skSm = await derivePrivateKey(oprfID, seed, info) + const skSm = await derivePrivateKey(mode, id, seed, info) expect(toHex(skSm)).toBe(testVector.skSm) }) - const server = new OPRFServer(oprfID, fromHex(testVector.skSm)) - const client = new OPRFClient(oprfID) + const server = new OPRFServer(id, fromHex(testVector.skSm)) + const client = new OPRFClient(id) const { vectors } = testVector server.supportsWebCryptoOPRF = false @@ -40,20 +53,20 @@ describe.each(allVectors)('test-vectors', (testVector: typeof allVectors[number] // inject the blind value given by the test vector. jest.spyOn(OPRFClient.prototype, 'randomBlinder').mockImplementationOnce(() => { const blind = new Blind(fromHex(vi.Blind)) - const { gg } = Oprf.params(oprfID) + const gg = Oprf.getGroup(id) const scalar = gg.deserializeScalar(blind) return Promise.resolve({ scalar, blind }) }) const input = fromHex(vi.Input) - const { blind, blindedElement } = await client.blind(input) - expect(toHex(blind)).toEqual(vi.Blind) - expect(toHex(blindedElement)).toEqual(vi.BlindedElement) + const [finData, evalReq] = await client.blind(input) + expect(toHex(finData.blind)).toEqual(vi.Blind) + expect(toHex(evalReq.blinded)).toEqual(vi.BlindedElement) - const evaluation = await server.evaluate(blindedElement) - expect(toHex(evaluation)).toEqual(vi.EvaluationElement) + const evaluation = await server.evaluate(evalReq) + expect(toHex(evaluation.element)).toEqual(vi.EvaluationElement) - const output = await client.finalize(input, blind, evaluation) + const output = await client.finalize(finData, evaluation) expect(toHex(output)).toEqual(vi.Output) const serverCheckOutput = await server.verifyFinalize(input, output) From 198bbd4ae2f1e45603cc5f1aeccae65a99d9497b Mon Sep 17 00:00:00 2001 From: armfazh Date: Thu, 17 Mar 2022 17:32:43 -0700 Subject: [PATCH 2/4] Including verifiable modes. --- src/client.ts | 67 ++++++++++++++++++++++- src/oprf.ts | 34 ++++++------ src/server.ts | 127 ++++++++++++++++++++++++++++++++++--------- test/oprf.test.ts | 90 +++++++++++++++++++----------- test/vectors.test.ts | 117 +++++++++++++++++++++++++++++---------- 5 files changed, 330 insertions(+), 105 deletions(-) diff --git a/src/client.ts b/src/client.ts index c67650e..91a8dce 100644 --- a/src/client.ts +++ b/src/client.ts @@ -13,7 +13,7 @@ import { Oprf, SuiteID } from './oprf.js' -import { Group, Scalar } from './group.js' +import { Elt, Group, Scalar, SerializedElt } from './group.js' class baseClient extends Oprf { constructor(mode: ModeID, suite: SuiteID) { @@ -39,13 +39,17 @@ class baseClient extends Oprf { return [finData, evalReq] } - finalize(finData: FinalizeData, evaluation: Evaluation): Promise { + doFinalize( + finData: FinalizeData, + evaluation: Evaluation, + info = new Uint8Array(0) + ): Promise { const blindScalar = this.gg.deserializeScalar(finData.blind) const blindScalarInv = this.gg.invScalar(blindScalar) const Z = this.gg.deserialize(evaluation.element) const N = Group.mul(blindScalarInv, Z) const unblinded = this.gg.serialize(N) - return this.coreFinalize(finData.input, unblinded) + return this.coreFinalize(finData.input, unblinded, info) } } @@ -53,4 +57,61 @@ export class OPRFClient extends baseClient { constructor(suite: SuiteID) { super(Oprf.Mode.OPRF, suite) } + finalize(finData: FinalizeData, evaluation: Evaluation): Promise { + return super.doFinalize(finData, evaluation) + } +} + +export class VOPRFClient extends baseClient { + constructor(suite: SuiteID, private readonly pubKeyServer: Uint8Array) { + super(Oprf.Mode.VOPRF, suite) + } + + finalize(finData: FinalizeData, evaluation: Evaluation): Promise { + if (!evaluation.proof) { + throw new Error('no proof provided') + } + const pkS = this.gg.deserialize(new SerializedElt(this.pubKeyServer)) + const Q = this.gg.deserialize(finData.evalReq.blinded) + const kQ = this.gg.deserialize(evaluation.element) + if (!evaluation.proof.verify([this.gg.generator(), pkS], [Q, kQ])) { + throw new Error('proof failed') + } + + return super.doFinalize(finData, evaluation) + } +} + +export class POPRFClient extends baseClient { + constructor(suite: SuiteID, private readonly pubKeyServer: Uint8Array) { + super(Oprf.Mode.POPRF, suite) + } + + private async pointFromInfo(info: Uint8Array): Promise { + const m = await this.scalarFromInfo(info) + const T = this.gg.mulBase(m) + const pkS = this.gg.deserialize(new SerializedElt(this.pubKeyServer)) + const tw = Group.add(T, pkS) + if (tw.isIdentity) { + throw new Error('invalid info') + } + return tw + } + + async finalize( + finData: FinalizeData, + evaluation: Evaluation, + info = new Uint8Array(0) + ): Promise { + if (!evaluation.proof) { + throw new Error('no proof provided') + } + const tw = await this.pointFromInfo(info) + const Q = this.gg.deserialize(evaluation.element) + const kQ = this.gg.deserialize(finData.evalReq.blinded) + if (!evaluation.proof.verify([this.gg.generator(), tw], [Q, kQ])) { + throw new Error('proof failed') + } + return super.doFinalize(finData, evaluation, info) + } } diff --git a/src/oprf.ts b/src/oprf.ts index ebc922d..7501589 100644 --- a/src/oprf.ts +++ b/src/oprf.ts @@ -3,7 +3,7 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { Group, GroupID, SerializedElt, SerializedScalar } from './group.js' +import { Group, GroupID, Scalar, SerializedElt, SerializedScalar } from './group.js' import { joinAll, to16bits } from './util.js' import { DLEQProof } from './dleq.js' @@ -96,28 +96,33 @@ export abstract class Oprf { protected async coreFinalize( input: Uint8Array, - unblindedElement: Uint8Array, - info?: Uint8Array + element: Uint8Array, + info: Uint8Array ): Promise { let hasInfo: Uint8Array[] = [] if (this.mode === Oprf.Mode.POPRF) { - if (info) { - hasInfo = [to16bits(info.length), info] - } else { - hasInfo = [new Uint8Array(0)] - } + hasInfo = [to16bits(info.length), info] } const hashInput = joinAll([ to16bits(input.length), input, ...hasInfo, - to16bits(unblindedElement.length), - unblindedElement, + to16bits(element.length), + element, new TextEncoder().encode(Oprf.LABELS.FinalizeDST) ]) return new Uint8Array(await crypto.subtle.digest(this.hash, hashInput)) } + + protected scalarFromInfo(info: Uint8Array): Promise { + if (info.length >= 1 << 16) { + throw new Error('invalid info length') + } + const te = new TextEncoder() + const framedInfo = joinAll([te.encode('Info'), to16bits(info.length), info]) + return this.gg.hashToScalar(framedInfo, this.getDST(Oprf.LABELS.HashToScalarDST)) + } } export class Blind extends SerializedScalar { @@ -131,14 +136,11 @@ export class Evaluated extends SerializedElt { } export class Evaluation { - constructor( - public readonly element: Evaluated, - public readonly proof?: DLEQProof - ) { } + constructor(public readonly element: Evaluated, public readonly proof?: DLEQProof) {} } export class EvaluationRequest { - constructor(public readonly blinded: Blinded) { } + constructor(public readonly blinded: Blinded) {} } export class FinalizeData { @@ -146,5 +148,5 @@ export class FinalizeData { public readonly input: Uint8Array, public readonly blind: Blind, public readonly evalReq: EvaluationRequest - ) { } + ) {} } diff --git a/src/server.ts b/src/server.ts index fa55eff..f548696 100644 --- a/src/server.ts +++ b/src/server.ts @@ -4,12 +4,13 @@ // at https://opensource.org/licenses/BSD-3-Clause import { Blinded, Evaluated, Evaluation, EvaluationRequest, ModeID, Oprf, SuiteID } from './oprf.js' -import { Group, SerializedScalar } from './group.js' +import { Group, Scalar, SerializedScalar } from './group.js' +import { DLEQProver } from './dleq.js' import { ctEqual } from './util.js' class baseServer extends Oprf { - private privateKey: Uint8Array + protected privateKey: Uint8Array public supportsWebCryptoOPRF = false @@ -18,17 +19,17 @@ class baseServer extends Oprf { this.privateKey = privateKey } - evaluate(req: EvaluationRequest): Promise { + protected doEvaluation(bl: Blinded, key: Uint8Array): Promise { if (this.supportsWebCryptoOPRF) { - return this.evaluateWebCrypto(req) + return this.evaluateWebCrypto(bl, key) } - return Promise.resolve(this.evaluateSJCL(req)) + return Promise.resolve(this.evaluateSJCL(bl, key)) } - private async evaluateWebCrypto(req: EvaluationRequest): Promise { - const key = await crypto.subtle.importKey( + private async evaluateWebCrypto(bl: Blinded, key: Uint8Array): Promise { + const crKey = await crypto.subtle.importKey( 'raw', - this.privateKey, + key, { name: 'OPRF', namedCurve: this.gg.id @@ -37,41 +38,113 @@ class baseServer extends Oprf { ['sign'] ) // webcrypto accepts only compressed points. - let compressed = Uint8Array.from(req.blinded) - if (req.blinded[0] === 0x04) { - const P = this.gg.deserialize(req.blinded) + let compressed = Uint8Array.from(bl) + if (bl[0] === 0x04) { + const P = this.gg.deserialize(bl) compressed = Uint8Array.from(this.gg.serialize(P, true)) } - const evaluation = await crypto.subtle.sign('OPRF', key, compressed) - return new Evaluation(new Evaluated(evaluation)) + return new Evaluated(await crypto.subtle.sign('OPRF', crKey, compressed)) } - private evaluateSJCL(req: EvaluationRequest): Evaluation { - const P = this.gg.deserialize(req.blinded) - const serSk = new SerializedScalar(this.privateKey) - const sk = this.gg.deserializeScalar(serSk) + private evaluateSJCL(bl: Blinded, key: Uint8Array): Evaluated { + const P = this.gg.deserialize(bl) + const sk = this.gg.deserializeScalar(new SerializedScalar(key)) const Z = Group.mul(sk, P) - return new Evaluation(new Evaluated(this.gg.serialize(Z))) + return new Evaluated(this.gg.serialize(Z)) } - async fullEvaluate(input: Uint8Array): Promise { - const dst = this.getDST(Oprf.LABELS.HashToGroupDST) - const P = await this.gg.hashToGroup(input, dst) + protected async secretFromInfo(info: Uint8Array): Promise<[Scalar, Scalar]> { + const m = await this.scalarFromInfo(info) + const skS = this.gg.deserializeScalar(new SerializedScalar(this.privateKey)) + const t = this.gg.addScalar(m, skS) + if (this.gg.isScalarZero(t)) { + throw new Error('inverse of zero') + } + const tInv = this.gg.invScalar(t) + return [t, tInv] + } + + protected async doFullEvaluate( + input: Uint8Array, + info = new Uint8Array(0) + ): Promise { + let secret = this.privateKey + if (this.mode === Oprf.Mode.POPRF) { + const [, evalSecret] = await this.secretFromInfo(info) + secret = this.gg.serializeScalar(evalSecret) + } + + const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST)) if (this.gg.isIdentity(P)) { throw new Error('InvalidInputError') } - const issuedElement = new EvaluationRequest(new Blinded(this.gg.serialize(P))) - const evaluation = await this.evaluate(issuedElement) - return this.coreFinalize(input, evaluation.element) + const blinded = new Blinded(this.gg.serialize(P)) + const evaluated = await this.doEvaluation(blinded, secret) + return this.coreFinalize(input, evaluated, info) } +} +export class OPRFServer extends baseServer { + constructor(suite: SuiteID, privateKey: Uint8Array) { + super(Oprf.Mode.OPRF, suite, privateKey) + } + + async evaluate(req: EvaluationRequest): Promise { + return new Evaluation(await this.doEvaluation(req.blinded, this.privateKey)) + } + async fullEvaluate(input: Uint8Array): Promise { + return this.doFullEvaluate(input) + } async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { - return ctEqual(output, await this.fullEvaluate(input)) + return ctEqual(output, await this.doFullEvaluate(input)) } } -export class OPRFServer extends baseServer { +export class VOPRFServer extends baseServer { constructor(suite: SuiteID, privateKey: Uint8Array) { - super(Oprf.Mode.OPRF, suite, privateKey) + super(Oprf.Mode.VOPRF, suite, privateKey) + } + async evaluate(req: EvaluationRequest): Promise { + const e = await this.doEvaluation(req.blinded, this.privateKey) + const prover = new DLEQProver({ gg: this.gg, hash: this.hash, dst: '' }) + const skS = this.gg.deserializeScalar(new SerializedScalar(this.privateKey)) + const pkS = this.gg.mulBase(skS) + const Q = this.gg.deserialize(req.blinded) + const kQ = this.gg.deserialize(e) + const proof = await prover.prove(skS, [this.gg.generator(), pkS], [Q, kQ]) + return new Evaluation(e, proof) + } + async fullEvaluate(input: Uint8Array): Promise { + return this.doFullEvaluate(input) + } + async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { + return ctEqual(output, await this.doFullEvaluate(input)) + } +} + +export class POPRFServer extends baseServer { + constructor(suite: SuiteID, privateKey: Uint8Array) { + super(Oprf.Mode.POPRF, suite, privateKey) + } + async evaluate(req: EvaluationRequest, info = new Uint8Array(0)): Promise { + const [keyProof, evalSecret] = await this.secretFromInfo(info) + const secret = this.gg.serializeScalar(evalSecret) + const e = await this.doEvaluation(req.blinded, secret) + const prover = new DLEQProver({ gg: this.gg, hash: this.hash, dst: '' }) + const kG = this.gg.mulBase(keyProof) + const Q = this.gg.deserialize(e) + const kQ = this.gg.deserialize(req.blinded) + const proof = await prover.prove(keyProof, [this.gg.generator(), kG], [Q, kQ]) + return new Evaluation(e, proof) + } + async fullEvaluate(input: Uint8Array, info = new Uint8Array(0)): Promise { + return this.doFullEvaluate(input, info) + } + async verifyFinalize( + input: Uint8Array, + output: Uint8Array, + info = new Uint8Array(0) + ): Promise { + return ctEqual(output, await this.doFullEvaluate(input, info)) } } diff --git a/test/oprf.test.ts b/test/oprf.test.ts index ae8f8f3..1a5dd41 100644 --- a/test/oprf.test.ts +++ b/test/oprf.test.ts @@ -3,41 +3,69 @@ // Licensed under the BSD-3-Clause license found in the LICENSE file or // at https://opensource.org/licenses/BSD-3-Clause -import { OPRFClient, OPRFServer, Oprf, randomPrivateKey } from '../src/index.js' +import { + OPRFClient, + OPRFServer, + Oprf, + POPRFClient, + POPRFServer, + VOPRFClient, + VOPRFServer, + generatePublicKey, + randomPrivateKey +} from '../src/index.js' -describe.each(Object.entries(Oprf.Suite))('oprf-workflow', (name, id) => { - it(`${name}`, async () => { - const te = new TextEncoder() - // ///////////////// - // Setup Server - // ///////////////// - const privateKey = await randomPrivateKey(id) - const server = new OPRFServer(id, privateKey) - // ///////////////// - // Setup Client - // ///////////////// - const client = new OPRFClient(id) - const input = te.encode('This is the client input') - // Client - const [finData, evalReq] = await client.blind(input) - // Client Server - // evalReq - // ------------------>> +describe.each(Object.entries(Oprf.Mode))('protocol', (modeName, mode) => { + describe.each(Object.entries(Oprf.Suite))(`${modeName}`, (suiteName, id) => { + let server: OPRFServer | VOPRFServer | POPRFServer + let client: OPRFClient | VOPRFClient | POPRFClient - // Server - const evaluation = await server.evaluate(evalReq) - // Client Server - // evaluation - // <<------------------ + beforeAll(async () => { + const privateKey = await randomPrivateKey(id) + const publicKey = generatePublicKey(id, privateKey) + switch (mode) { + case Oprf.Mode.OPRF: + server = new OPRFServer(id, privateKey) + client = new OPRFClient(id) + break - // Client - const output = await client.finalize(finData, evaluation) - expect(output).toHaveLength(Oprf.getOprfSize(id)) + case Oprf.Mode.VOPRF: + server = new VOPRFServer(id, privateKey) + client = new VOPRFClient(id, publicKey) + break + case Oprf.Mode.POPRF: + server = new POPRFServer(id, privateKey) + client = new POPRFClient(id, publicKey) + break + } + }) - const serverOutput = await server.fullEvaluate(input) - expect(output).toStrictEqual(serverOutput) + it(`${suiteName}`, async () => { + // Client Server + // ==================================================== + // Client + // blind, blindedElement = Blind(input) + const input = new TextEncoder().encode('This is the client input') + const [finData, evalReq] = await client.blind(input) + // evalReq + // ------------------>> + // Server + // evaluation = Evaluate(evalReq, info*) + const evaluation = await server.evaluate(evalReq) + // evaluation + // <<------------------ + // + // Client + // output = Finalize(finData, evaluation, info*) + // + const output = await client.finalize(finData, evaluation) + expect(output).toHaveLength(Oprf.getOprfSize(id)) - const success = await server.verifyFinalize(input, output) - expect(success).toBe(true) + const serverOutput = await server.fullEvaluate(input) + expect(output).toStrictEqual(serverOutput) + + const success = await server.verifyFinalize(input, output) + expect(success).toBe(true) + }) }) }) diff --git a/test/vectors.test.ts b/test/vectors.test.ts index 7f17a7a..0eaea82 100644 --- a/test/vectors.test.ts +++ b/test/vectors.test.ts @@ -9,8 +9,13 @@ import { OPRFClient, OPRFServer, Oprf, + POPRFClient, + POPRFServer, SuiteID, - derivePrivateKey + VOPRFClient, + VOPRFServer, + derivePrivateKey, + generatePublicKey } from '../src/index.js' import allVectors from './testdata/allVectors_v09.json' @@ -24,53 +29,109 @@ function toHex(x: Uint8Array): string { return Buffer.from(x).toString('hex') } +class wrapPOPRFServer extends POPRFServer { + info!: Uint8Array + + evaluate(...r: Parameters): ReturnType { + return super.evaluate(r[0], this.info) + } + async fullEvaluate(input: Uint8Array): Promise { + return super.fullEvaluate(input, this.info) + } + async verifyFinalize(input: Uint8Array, output: Uint8Array): Promise { + return super.verifyFinalize(input, output, this.info) + } +} + +class wrapPOPRFClient extends POPRFClient { + info!: Uint8Array + + blind(input: Uint8Array): ReturnType { + return super.blind(input) + } + finalize(...r: Parameters): ReturnType { + return super.finalize(...r, this.info) + } +} + // Test vectors from https://datatracker.ietf.org/doc/draft-irtf-cfrg-voprf -// https://tools.ietf.org/html/draft-irtf-cfrg-voprf-06 +// https://tools.ietf.org/html/draft-irtf-cfrg-voprf-09 describe.each(allVectors)('test-vectors', (testVector: typeof allVectors[number]) => { const mode = testVector.mode as ModeID const id = testVector.suiteID as SuiteID - if (mode === Oprf.Mode.OPRF && Object.values(Oprf.Suite).includes(id)) { + if (Object.values(Oprf.Suite).includes(id)) { const txtMode = Object.entries(Oprf.Mode)[mode as number][0] const txtSuite = Object.entries(Oprf.Suite)[Object.values(Oprf.Suite).indexOf(id)][0] describe(`${txtMode}, ${txtSuite}`, () => { - it('keygen', async () => { + let skSm: Uint8Array + let server: OPRFServer | VOPRFServer | wrapPOPRFServer + let client: OPRFClient | VOPRFClient | wrapPOPRFClient + + beforeAll(async () => { const seed = fromHex(testVector.seed) - const info = fromHex(testVector.keyInfo) - const skSm = await derivePrivateKey(mode, id, seed, info) + const keyInfo = fromHex(testVector.keyInfo) + skSm = await derivePrivateKey(mode, id, seed, keyInfo) + const pkSm = generatePublicKey(id, skSm) + switch (mode) { + case Oprf.Mode.OPRF: + server = new OPRFServer(id, skSm) + client = new OPRFClient(id) + break + + case Oprf.Mode.VOPRF: + server = new VOPRFServer(id, skSm) + client = new VOPRFClient(id, pkSm) + break + + case Oprf.Mode.POPRF: + server = new wrapPOPRFServer(id, skSm) + client = new wrapPOPRFClient(id, pkSm) + break + } + }) + + it('keygen', () => { expect(toHex(skSm)).toBe(testVector.skSm) }) - const server = new OPRFServer(id, fromHex(testVector.skSm)) - const client = new OPRFClient(id) const { vectors } = testVector - server.supportsWebCryptoOPRF = false + describe.each(vectors)('vec$#', (vi: typeof vectors[number]) => { + if (vi.Batch === 1) { + it('protocol', async () => { + // Creates a mock for randomBlinder method to + // inject the blind value given by the test vector. + for (const c of [OPRFClient, VOPRFClient, wrapPOPRFClient]) { + jest.spyOn(c.prototype, 'randomBlinder').mockImplementation(() => { + const blind = new Blind(fromHex(vi.Blind)) + const scalar = Oprf.getGroup(id).deserializeScalar(blind) + return Promise.resolve({ scalar, blind }) + }) + } - it.each(vectors)('vec$#', async (vi: typeof vectors[number]) => { - // Creates a mock for OPRFClient.randomBlinder method to - // inject the blind value given by the test vector. - jest.spyOn(OPRFClient.prototype, 'randomBlinder').mockImplementationOnce(() => { - const blind = new Blind(fromHex(vi.Blind)) - const gg = Oprf.getGroup(id) - const scalar = gg.deserializeScalar(blind) - return Promise.resolve({ scalar, blind }) - }) + if (testVector.mode === Oprf.Mode.POPRF) { + const info = fromHex((vi as any).Info as string) // eslint-disable-line @typescript-eslint/no-explicit-any + ;(server as wrapPOPRFServer).info = info + ;(client as wrapPOPRFClient).info = info + } - const input = fromHex(vi.Input) - const [finData, evalReq] = await client.blind(input) - expect(toHex(finData.blind)).toEqual(vi.Blind) - expect(toHex(evalReq.blinded)).toEqual(vi.BlindedElement) + const input = fromHex(vi.Input) + const [finData, evalReq] = await client.blind(input) + expect(toHex(finData.blind)).toEqual(vi.Blind) + expect(toHex(evalReq.blinded)).toEqual(vi.BlindedElement) - const evaluation = await server.evaluate(evalReq) - expect(toHex(evaluation.element)).toEqual(vi.EvaluationElement) + const evaluation = await server.evaluate(evalReq) + expect(toHex(evaluation.element)).toEqual(vi.EvaluationElement) - const output = await client.finalize(finData, evaluation) - expect(toHex(output)).toEqual(vi.Output) + const output = await client.finalize(finData, evaluation) + expect(toHex(output)).toEqual(vi.Output) - const serverCheckOutput = await server.verifyFinalize(input, output) - expect(serverCheckOutput).toBe(true) + const serverCheckOutput = await server.verifyFinalize(input, output) + expect(serverCheckOutput).toBe(true) + }) + } }) }) } From 8119b562d22252ac56fe01a502e9093b33c85687 Mon Sep 17 00:00:00 2001 From: Armando Faz Date: Thu, 24 Mar 2022 08:50:07 -0700 Subject: [PATCH 3/4] Apply Luke's suggestions Co-authored-by: Luke Valenta --- src/client.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client.ts b/src/client.ts index 91a8dce..3f364c9 100644 --- a/src/client.ts +++ b/src/client.ts @@ -29,7 +29,7 @@ class baseClient extends Oprf { async blind(input: Uint8Array): Promise<[FinalizeData, EvaluationRequest]> { const { scalar, blind } = await this.randomBlinder() const dst = this.getDST(Oprf.LABELS.HashToGroupDST) - const P = await this.gg.hashToGroup(input, dst) + const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST)) if (this.gg.isIdentity(P)) { throw new Error('InvalidInputError') } From bae590a6548f0743ee3f28c39c5e096a2dea31db Mon Sep 17 00:00:00 2001 From: Armando Faz Date: Thu, 24 Mar 2022 09:27:16 -0700 Subject: [PATCH 4/4] Remove dst variable Co-authored-by: Luke Valenta --- src/client.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/client.ts b/src/client.ts index 3f364c9..fcaed6b 100644 --- a/src/client.ts +++ b/src/client.ts @@ -28,7 +28,6 @@ class baseClient extends Oprf { async blind(input: Uint8Array): Promise<[FinalizeData, EvaluationRequest]> { const { scalar, blind } = await this.randomBlinder() - const dst = this.getDST(Oprf.LABELS.HashToGroupDST) const P = await this.gg.hashToGroup(input, this.getDST(Oprf.LABELS.HashToGroupDST)) if (this.gg.isIdentity(P)) { throw new Error('InvalidInputError')