From 9e1e876aa120fa7cef3403fdc333d1167b228cdc Mon Sep 17 00:00:00 2001 From: Fabian Boemer Date: Thu, 22 Aug 2024 17:57:29 -0700 Subject: [PATCH] Adds PrivateNearestNeighbhorsSearch Client (#72) --- Sources/HomomorphicEncryption/Array2d.swift | 18 ++ .../Bfv/Bfv+Decrypt.swift | 13 +- .../HomomorphicEncryption/Ciphertext.swift | 2 +- .../HomomorphicEncryption/CrtComposer.swift | 116 +++++++++++ Sources/HomomorphicEncryption/HeScheme.swift | 4 +- .../PolyRq/PolyContext.swift | 2 +- .../RnsBaseConverter.swift | 71 ++----- Sources/HomomorphicEncryption/RnsTool.swift | 13 +- Sources/HomomorphicEncryption/Scalar.swift | 4 +- .../Client.swift | 187 ++++++++++++++++++ .../Config.swift | 14 ++ .../PrivateNearestNeighborsSearch/Error.swift | 3 + .../PnnsProtocol.swift | 41 ++++ .../Array2dTests.swift | 29 +++ .../HomomorphicEncryptionTests/NttTests.swift | 4 +- .../RnsBaseConverterTests.swift | 5 +- .../RnsToolTests.swift | 15 +- .../ClientTests.swift | 164 +++++++++++++++ swift-homomorphic-encryption-protobuf | 2 +- 19 files changed, 615 insertions(+), 92 deletions(-) create mode 100644 Sources/HomomorphicEncryption/CrtComposer.swift create mode 100644 Sources/PrivateNearestNeighborsSearch/Client.swift create mode 100644 Sources/PrivateNearestNeighborsSearch/PnnsProtocol.swift create mode 100644 Tests/PrivateNearestNeighborsSearchTests/ClientTests.swift diff --git a/Sources/HomomorphicEncryption/Array2d.swift b/Sources/HomomorphicEncryption/Array2d.swift index f3f31f7b..806cd2f1 100644 --- a/Sources/HomomorphicEncryption/Array2d.swift +++ b/Sources/HomomorphicEncryption/Array2d.swift @@ -14,6 +14,7 @@ /// Stores values in a 2 dimensional array. public struct Array2d: Equatable, Sendable { + /// Values stored in row-major order. @usableFromInline package var data: [T] @usableFromInline package var rowCount: Int @usableFromInline package var columnCount: Int @@ -21,6 +22,11 @@ public struct Array2d: Equatable, @usableFromInline package var shape: (Int, Int) { (rowCount, columnCount) } @usableFromInline package var count: Int { rowCount * columnCount } + @inlinable + package init(data: [[T]]) { + self.init(data: data.flatMap { $0 }, rowCount: data.count, columnCount: data[0].count) + } + @inlinable package init(data: [T], rowCount: Int, columnCount: Int) { precondition(data.count == rowCount * columnCount) @@ -175,4 +181,16 @@ extension Array2d { HomomorphicEncryption.zeroize(dataPointer.baseAddress!, zeroizeSize) } } + + /// Returns the matrix after transforming each entry with a function. + /// - Parameter transform: A mapping closure. `transform` accepts an element of the array as its parameter and + /// returns a transformed value of the same or of a different type. + /// - Returns: The transformed matrix. + @inlinable + package func map(_ transform: (T) -> (V)) -> Array2d { + Array2d( + data: data.map { value in transform(value) }, + rowCount: rowCount, + columnCount: columnCount) + } } diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift index 546dc98b..676de575 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Decrypt.swift @@ -46,16 +46,14 @@ extension Bfv { { // See Definition 1 of // https://www.microsoft.com/en-us/research/wp-content/uploads/2017/06/sealmanual_v2.2.pdf. - precondition(variableTime) var vTimesT = try Self.dotProduct(ciphertext: ciphertext, with: secretKey) vTimesT *= Array(repeating: ciphertext.context.plaintextModulus, count: vTimesT.moduli.count) let rnsTool = ciphertext.context.getRnsTool(moduliCount: vTimesT.moduli.count) - func computeNoiseBudget(of _: PolyRq, _: U.Type) throws -> Double { - let vTimesTComposed: [U] = try rnsTool.crtCompose( - poly: vTimesT, - variableTime: variableTime) - + func computeNoiseBudget(of _: PolyRq, + _: U.Type) throws -> Double + { + let vTimesTComposed: [U] = try rnsTool.crtCompose(poly: vTimesT) let q: U = vTimesT.moduli.product() let qDiv2 = (q &+ 1) &>> 1 let noiseInfinityNorm = Double(vTimesTComposed.map { coeff in @@ -78,10 +76,13 @@ extension Bfv { case 0...self) case tMax...self) default: preconditionFailure("crtMaxIntermediateValue \(crtMaxIntermediateValue) too large") diff --git a/Sources/HomomorphicEncryption/Ciphertext.swift b/Sources/HomomorphicEncryption/Ciphertext.swift index dded0f85..1fc91612 100644 --- a/Sources/HomomorphicEncryption/Ciphertext.swift +++ b/Sources/HomomorphicEncryption/Ciphertext.swift @@ -292,7 +292,7 @@ public struct Ciphertext: Equatable, Senda /// ``HeScheme/minNoiseBudget``, decryption may yield inaccurate plaintexts. /// - Parameters: /// - secretKey: Secret key. - /// - variableTime: Must be `true`, indicating the secret key coefficients are leaked through timing. + /// - variableTime: If `true`, indicates the secret key coefficients may be leaked through timing. /// - Returns: The noise budget. /// - Throws: Error upon failure to compute the noise budget. /// - Warning: Leaks `secretKey` through timing. Should be used for testing only. diff --git a/Sources/HomomorphicEncryption/CrtComposer.swift b/Sources/HomomorphicEncryption/CrtComposer.swift new file mode 100644 index 00000000..a2af850e --- /dev/null +++ b/Sources/HomomorphicEncryption/CrtComposer.swift @@ -0,0 +1,116 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// Performs Chinese remainder theorem (CRT) composition of coefficients. +@usableFromInline +package struct CrtComposer: Sendable { + /// Context for the CRT moduli `q_i`. + @usableFromInline let polyContext: PolyContext + + /// i'th entry stores `(q_i / q) % q_i`. + @usableFromInline let inversePuncturedProducts: [MultiplyConstantModulus] + + /// Creates a new ``CrtComposer``. + /// - Parameter polyContext: Context for the CRT moduli. + /// - Throws: Error upon failure to create a new ``CrtComposer``. + @inlinable + package init(polyContext: PolyContext) throws { + self.polyContext = polyContext + self.inversePuncturedProducts = try polyContext.reduceModuli.map { qi in + var puncturedProduct = T(1) + for qj in polyContext.moduli where qj != qi.modulus { + let prod = puncturedProduct.multipliedFullWidth(by: qj) + puncturedProduct = qi.reduce(T.DoubleWidth(prod)) + } + let inversePuncturedProduct = try puncturedProduct.inverseMod( + modulus: qi.modulus, + variableTime: true) + return MultiplyConstantModulus( + multiplicand: inversePuncturedProduct, + modulus: qi.modulus, + variableTime: true) + } + } + + /// Returns an upper bound on the maximum value during a `crtCompose` call. + /// - Parameter moduli: Moduli in the polynomial context + /// - Returns: The upper bound. + @inlinable + package static func composeMaxIntermediateValue(moduli: [T]) -> Double { + let moduli = moduli.map { Double($0) } + if moduli.count == 1 { + return moduli[0] + } + let q = moduli.reduce(1.0, *) + return 2.0 * q + } + + /// Performs Chinese remainder theorem (CRT) composition on a list of + /// coefficients. + /// + /// The composition yields a polynomial with coefficients in `[0, q - 1]`. + /// - Parameter data:Data to compose. Each column must contain a + /// coefficient's residues mod each modulus. + /// - Returns: The composed coefficients. Each coefficient must be able to + /// store values up to + /// `crtComposeMaxIntermediateValue`. + /// - Throws: `HeError` upon failure to compose the polynomial. + /// - Warning: `V`'s operations must be constant time to prevent leaking + /// `poly` through timing. + @inlinable + package func compose(data: Array2d) throws -> [V] + { + precondition(data.rowCount == polyContext.moduli.count) + precondition(Double(V.max) >= Self + .composeMaxIntermediateValue(moduli: polyContext.moduli)) + let q: V = polyContext.moduli.product() + let puncturedProducts = polyContext.moduli.map { qi in q / V(qi) } + + var products: [V] = Array(repeating: 0, count: data.columnCount) + for row in 0..(poly: PolyRq< + T, + Coeff + >) throws -> [V] { + try compose(data: poly.data) + } +} diff --git a/Sources/HomomorphicEncryption/HeScheme.swift b/Sources/HomomorphicEncryption/HeScheme.swift index ec68040d..3384be68 100644 --- a/Sources/HomomorphicEncryption/HeScheme.swift +++ b/Sources/HomomorphicEncryption/HeScheme.swift @@ -636,7 +636,7 @@ public protocol HeScheme { /// - Parameters: /// - ciphertext: Ciphertext whose noise budget to compute. /// - secretKey: Secret key. - /// - variableTime: Must be `true`, indicating the secret key coefficients are leaked through timing. + /// - variableTime: If `true`, indicates the secret key coefficients may be leaked through timing. /// - Returns: The noise budget. /// - Throws: Error upon failure to compute the noise budget. /// - Warning: Leaks `secretKey` through timing. Should be used for testing only. @@ -651,7 +651,7 @@ public protocol HeScheme { /// - Parameters: /// - ciphertext: Ciphertext whose noise budget to compute. /// - secretKey: Secret key. - /// - variableTime: Must be `true`, indicating the secret key coefficients are leaked through timing. + /// - variableTime: If `true`, indicates the secret key coefficients may be leaked through timing. /// - Returns: The noise budget. /// - Throws: Error upon failure to compute the noise budget. /// - Warning: Leaks `secretKey` through timing. Should be used for testing only. diff --git a/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift b/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift index 62d7d5d8..08dc22d4 100644 --- a/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift +++ b/Sources/HomomorphicEncryption/PolyRq/PolyContext.swift @@ -18,7 +18,7 @@ public final class PolyContext: Sendable { /// Number `N` of coefficients in the polynomial, must be a power of two. @usableFromInline let degree: Int /// CRT-representation of the modulus `Q = product_{i=0}^{L-1} q_i`. - @usableFromInline let moduli: [T] + @usableFromInline package let moduli: [T] /// Next context, typically formed by dropping `q_{L-1}`. @usableFromInline let next: PolyContext? /// Operations mod `q_0` up to `q_{L-1}`. diff --git a/Sources/HomomorphicEncryption/RnsBaseConverter.swift b/Sources/HomomorphicEncryption/RnsBaseConverter.swift index d9a7cb51..c0c71f03 100644 --- a/Sources/HomomorphicEncryption/RnsBaseConverter.swift +++ b/Sources/HomomorphicEncryption/RnsBaseConverter.swift @@ -22,8 +22,14 @@ struct RnsBaseConverter: Sendable { @usableFromInline let outputContext: PolyContext /// (i, j)'th entry stores `(q / q_i) % t_j`. @usableFromInline let puncturedProducts: Array2d - /// i'th entry stores `(q_i / q) % q_i``. - @usableFromInline let inversePuncturedProducts: [MultiplyConstantModulus] + + /// Composes polynomials with `inputContext`. + @usableFromInline let crtComposer: CrtComposer + + /// i'th entry stores `(q_i / q) % q_i`. + @usableFromInline var inversePuncturedProducts: [MultiplyConstantModulus] { + crtComposer.inversePuncturedProducts + } @inlinable init(from inputContext: PolyContext, to outputContext: PolyContext) throws { @@ -46,23 +52,12 @@ struct RnsBaseConverter: Sendable { rowCount: outputContext.moduli.count, columnCount: inputContext.moduli.count) - self.inversePuncturedProducts = try inputContext.reduceModuli.map { qi in - var puncturedProduct = T(1) - for qj in inputContext.moduli where qj != qi.modulus { - let prod = puncturedProduct.multipliedFullWidth(by: qj) - puncturedProduct = qi.reduce(T.DoubleWidth(prod)) - } - let inversePuncturedProduct = try puncturedProduct.inverseMod(modulus: qi.modulus, variableTime: true) - return MultiplyConstantModulus( - multiplicand: inversePuncturedProduct, - modulus: qi.modulus, - variableTime: true) - } + self.crtComposer = try CrtComposer(polyContext: inputContext) } /// Performs approximate base conversion. /// - /// Converts input polynomial with coefficients `x_i mod q` to `(x_i + a_x * q) % t` where `a_x \in [0, L-1]`, for + /// Converts input polynomial with coefficients `x_i mod q` to `(x_i + a_x * q) % t` where `a_x \in [0, L - 1]`, for /// `L` the number of moduli in the input basis `q. /// - Parameter poly: Input polynomial with base `q`. /// - Returns: Converted polynomial with base `t`. @@ -76,55 +71,17 @@ struct RnsBaseConverter: Sendable { return convertApproximate(using: poly) } - /// Returns an upper bound on the maximum value during a `crtCompose` call. - @inlinable - func crtComposeMaxIntermediateValue() -> Double { - let moduli = inputContext.moduli.map { Double($0) } - if moduli.count == 1 { - return moduli[0] - } - let q = moduli.reduce(1.0, *) - return 2.0 * q - } - /// Performs Chinese remainder theorem (CRT) composition of each coefficient in `poly`. /// /// The composition yields a polynomial with coefficients in `[0, q - 1]`. - /// - Parameters: - /// - poly: Polynomial to compose. - /// - variableTime: Must be `true`, indicating `poly`'s coefficients are leaked through timing. + /// - Parameter poly: Polynomial to compose. /// - Returns: The coefficients in the composed polynomial. Each coefficient must be able to store values up to /// `crtComposeMaxIntermediateValue`. /// - Throws: `HeError` upon failure to compose the polynomial. - /// - Warning: Leaks `poly` through timing. + /// - Warning: `V`'s operations must be constant time to prevent leaking `poly` through timing. @inlinable - func crtCompose(poly: PolyRq, variableTime: Bool) throws -> [V] { - precondition(variableTime) - precondition(Double(V.max) >= crtComposeMaxIntermediateValue()) - guard poly.context == inputContext else { - throw HeError.invalidPolyContext(poly.context) - } - if inputContext.moduli.count == 1 { - return poly.poly(rnsIndex: 0).map { V($0) } - } - let q: V = inputContext.moduli.product() - let puncturedProducts = inputContext.moduli.map { qi in - q / V(qi) - } - return poly.coeffIndices.map { coeffIndex in - var product: V = 0 - for (rnsCoeff, (puncturedProduct, inversePuncturedProduct)) in zip( - poly.coefficient(coeffIndex: coeffIndex), - zip(puncturedProducts, inversePuncturedProducts)) - { - let tmp = inversePuncturedProduct.multiplyMod(rnsCoeff) - product &+= V(tmp) &* puncturedProduct - if product >= q { - product &-= q - } - } - return product - } + func crtCompose(poly: PolyRq) throws -> [V] { + try crtComposer.compose(poly: poly) } /// Computes approximate products. diff --git a/Sources/HomomorphicEncryption/RnsTool.swift b/Sources/HomomorphicEncryption/RnsTool.swift index ef863b0a..121072f2 100644 --- a/Sources/HomomorphicEncryption/RnsTool.swift +++ b/Sources/HomomorphicEncryption/RnsTool.swift @@ -13,7 +13,7 @@ // limitations under the License. @usableFromInline -struct RnsTool: Sendable { +package struct RnsTool: Sendable { /// `Q = q_0, ..., q_{L-1}`. @usableFromInline let inputContext: PolyContext /// `t_0, ..., t_{M-1}`. @@ -396,17 +396,16 @@ struct RnsTool: Sendable { /// - poly: Polynomial whose coefficients to compose. /// - variableTime: Must be `true`, indicating the coefficients of the polynomial are leaked through timing. /// - Returns: The coefficients of `poly`, each in `[0, Q - 1]`. - /// - Warning: Leaks `poly` through timing. + /// - Warning: `V`'s operations must be constant time to prevent leaking `poly` through timing. @inlinable - func crtCompose(poly: PolyRq, variableTime: Bool) throws -> [V] { - precondition(variableTime) + package func crtCompose(poly: PolyRq) throws -> [V] { // Use arbitrary base converter that has same inputContext - return try rnsConvertQToBSk.crtCompose(poly: poly, variableTime: variableTime) + try rnsConvertQToBSk.crtCompose(poly: poly) } /// Returns an upper bound on the maximum value during a `crtCompose` call. @inlinable - func crtComposeMaxIntermediateValue() -> Double { - rnsConvertQToBSk.crtComposeMaxIntermediateValue() + package func crtComposeMaxIntermediateValue() -> Double { + CrtComposer.composeMaxIntermediateValue(moduli: inputContext.moduli) } } diff --git a/Sources/HomomorphicEncryption/Scalar.swift b/Sources/HomomorphicEncryption/Scalar.swift index 468c7fe0..baf40216 100644 --- a/Sources/HomomorphicEncryption/Scalar.swift +++ b/Sources/HomomorphicEncryption/Scalar.swift @@ -206,7 +206,7 @@ extension FixedWidthInteger { } } -extension ScalarType { +extension UnsignedInteger where Self: FixedWidthInteger { /// Computes the high `Self.bitWidth` bits of `self * rhs`. /// - Parameter rhs: Multiplicand. /// - Returns: the high `Self.bitWidth` bits of `self * rhs`. @@ -269,7 +269,9 @@ extension ScalarType { let sum = self &+ modulus &- rhs return sum.subtractIfExceeds(modulus) } +} +extension ScalarType { /// Computes modular exponentiation. /// /// Computes self raised to the power of `exponent` mod `modulus, i.e., `self^exponent mod modulus`. diff --git a/Sources/PrivateNearestNeighborsSearch/Client.swift b/Sources/PrivateNearestNeighborsSearch/Client.swift new file mode 100644 index 00000000..a0be7cfb --- /dev/null +++ b/Sources/PrivateNearestNeighborsSearch/Client.swift @@ -0,0 +1,187 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Algorithms +import Foundation +import HomomorphicEncryption + +/// Private nearest neighbors client. +struct Client { + /// Configuration. + let config: ClientConfig + + /// One context per plaintext modulus. + let contexts: [Context] + + /// Performs composition of the plaintext CRT responses. + let crtComposer: CrtComposer + + /// Context for the plaintext CRT moduli. + let plaintextContext: PolyContext + + var evaluationKeyConfiguration: HomomorphicEncryption.EvaluationKeyConfiguration { + config.evaluationKeyConfig + } + + /// Creates a new ``Client``. + /// - Parameter config: Client configuration. + /// - Throws: Error upon failure to create a new client. + @inlinable + init(config: ClientConfig) throws { + guard config.distanceMetric == .cosineSimilarity else { + throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity) + } + self.config = config + let extraEncryptionParams = try config.extraPlaintextModuli.map { plaintextModulus in + try EncryptionParameters( + polyDegree: config.encryptionParams.polyDegree, + plaintextModulus: plaintextModulus, + coefficientModuli: config.encryptionParams.coefficientModuli, + errorStdDev: config.encryptionParams.errorStdDev, + securityLevel: config.encryptionParams.securityLevel) + } + let encryptionParams = [config.encryptionParams] + extraEncryptionParams + self.contexts = try encryptionParams.map { encryptionParams in + try Context(encryptionParameters: encryptionParams) + } + self.plaintextContext = try PolyContext( + degree: config.encryptionParams.polyDegree, + moduli: config.plaintextModuli) + self.crtComposer = try CrtComposer(polyContext: plaintextContext) + } + + /// Generates a nearest neighbor search query. + /// - Parameters: + /// - vectors: Vectors. + /// - secretKey: Secret key to encrypt with. + /// - Returns: The query. + /// - Throws: Error upon failure to generate the query. + @inlinable + func generateQuery(vectors: Array2d, using secretKey: SecretKey) throws -> Query { + let scaledVectors: Array2d = vectors.normalizedRows(norm: Array2d.Norm.Lp(p: 2.0)) + .scaled(by: Float(config.scalingFactor)).rounded() + let dimensions = try MatrixDimensions(rowCount: vectors.rowCount, columnCount: vectors.columnCount) + + let matrices = try contexts.map { context in + // For a single plaintext modulus, reduction isn't necessary + let shouldReduce = contexts.count > 1 + let plaintextMatrix = try PlaintextMatrix( + context: context, + dimensions: dimensions, + packing: config.queryPacking, + signedValues: scaledVectors.data, + reduce: shouldReduce) + return try plaintextMatrix.encrypt(using: secretKey).convertToCoeffFormat() + } + return Query(ciphertextMatrices: matrices) + } + + /// Decrypts a nearest neighbors search response. + /// - Parameters: + /// - response: The response. + /// - secretKey: Secret key to decrypt with. + /// - Returns: The distances from the query vectors to the database rows. + /// - Throws: Error upon failure to decrypt the response. + @inlinable + func decrypt(response: Response, using secretKey: SecretKey) throws -> DatabaseDistances { + guard let dimensions = response.ciphertextMatrices.first?.dimensions else { + throw PnnsError.emptyCiphertextArray + } + let decoded: [[Scheme.Scalar]] = try response.ciphertextMatrices.map { ciphertextMatrix in + try ciphertextMatrix.decrypt(using: secretKey).unpack() + } + // CRT-decomposed scores + let values = Array2d(data: decoded) + // Plaintext CRT modulus must be < `UInt64.max` + let composedDistances: [UInt64] = try crtComposer.compose(data: values) + + let modulus: UInt64 = plaintextContext.moduli.product() + // Encrypted distances are scaled by config.scalingFactor^2, so we undo the scaling here. + let distanceValues = composedDistances.map { unsigned in + let signed = unsigned.remainderToCentered(modulus: modulus) + return Float(signed) / (Float(config.scalingFactor) * Float(config.scalingFactor)) + } + + let distances = Array2d( + data: distanceValues, + rowCount: dimensions.rowCount, + columnCount: dimensions.columnCount) + return DatabaseDistances( + distances: distances, + entryIDs: response.entryIDs, + entryMetadatas: response.entryMetadatas) + } + + /// Generates an ``EvaluationKey`` for use in nearest neighbors search. + /// - Parameter secretKey: Secret key used to generate the evaluation key. + /// - Returns: The evaluation key. + /// - Throws: Error upon failure to generate the evaluation key. + /// - Warning: Uses the first context to generate the evaluation key. So either the HE scheme should generate + /// evaluation keys independent of the plaintext modulus (as in BFV), or there should be just one plaintext modulus. + @inlinable + func generateEvaluationKey(using secretKey: SecretKey) throws -> EvaluationKey { + try contexts[0].generateEvaluationKey(configuration: evaluationKeyConfiguration, using: secretKey) + } +} + +extension Array2d where T == Float { + /// A mapping from vectors to non-negative real numbers. + @usableFromInline + enum Norm { + case Lp(p: Float) // sum_i (|x_i|^p)^{1/p} + } + + /// Normalizes each row in the matrix. + @inlinable + func normalizedRows(norm: Norm) -> Array2d { + switch norm { + case let Norm.Lp(p): + let normalizedValues = data.chunks(ofCount: columnCount).flatMap { row in + let sumOfPowers = row.map { pow($0, p) }.reduce(0, +) + let norm = pow(sumOfPowers, 1 / p) + return row.map { value in + if sumOfPowers.isZero { + Float.zero + } else { + value / norm + } + } + } + return Array2d( + data: normalizedValues, + rowCount: rowCount, + columnCount: columnCount) + } + } + + /// Returns the matrix where each entry is rounded to the closest integer. + @inlinable + func rounded() -> Array2d { + Array2d( + data: data.map { value in V(value.rounded()) }, + rowCount: rowCount, + columnCount: columnCount) + } + + /// Returns the matrix where each entry has been multiplied by a scaling factor. + /// - Parameter scalingFactor: The factor to multiply each entry by. + /// - Returns: The scaled matrix. + @inlinable + func scaled(by scalingFactor: Float) -> Array2d { + Array2d( + data: data.map { value in value * scalingFactor }, + rowCount: rowCount, + columnCount: columnCount) + } +} diff --git a/Sources/PrivateNearestNeighborsSearch/Config.swift b/Sources/PrivateNearestNeighborsSearch/Config.swift index 1055ff51..a16fb5ed 100644 --- a/Sources/PrivateNearestNeighborsSearch/Config.swift +++ b/Sources/PrivateNearestNeighborsSearch/Config.swift @@ -39,6 +39,11 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send /// The first plaintext modulus will be the one in ``ClientConfig/encryptionParams``. public let extraPlaintextModuli: [Scheme.Scalar] + /// The plaintext CRT moduli. + var plaintextModuli: [Scheme.Scalar] { + [encryptionParams.plaintextModulus] + extraPlaintextModuli + } + /// Creates a new ``ClientConfig``. /// - Parameters: /// - encryptionParams: Encryption parameters. @@ -66,6 +71,15 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send self.distanceMetric = distanceMetric self.extraPlaintextModuli = extraPlaintextModuli } + + static func maxScalingFactor(vectorDimension: Int, distanceMetric: DistanceMetric, + plaintextModuli: [Scheme.Scalar]) -> Int + { + precondition(distanceMetric == .cosineSimilarity) + let t = plaintextModuli.map { Float($0) }.reduce(1, *) + let scalingFactor = (((t - 1) / 2).squareRoot() - Float(vectorDimension).squareRoot() / 2).rounded(.down) + return Int(scalingFactor) + } } /// Server configuration. diff --git a/Sources/PrivateNearestNeighborsSearch/Error.swift b/Sources/PrivateNearestNeighborsSearch/Error.swift index 5e3eca0e..8e243509 100644 --- a/Sources/PrivateNearestNeighborsSearch/Error.swift +++ b/Sources/PrivateNearestNeighborsSearch/Error.swift @@ -23,6 +23,7 @@ public enum PnnsError: Error, Equatable { case simdEncodingNotSupported(_ description: String) case wrongCiphertextCount(got: Int, expected: Int) case wrongContext(gotDescription: String, expectedDescription: String) + case wrongDistanceMetric(got: DistanceMetric, expected: DistanceMetric) case wrongEncodingValuesCount(got: Int, expected: Int) case wrongMatrixPacking(got: MatrixPacking, expected: MatrixPacking) case wrongPlaintextCount(got: Int, expected: Int) @@ -55,6 +56,8 @@ extension PnnsError: LocalizedError { "Wrong ciphertext count \(got), expected \(expected)" case let .wrongContext(gotDescription, expectedDescription): "Wrong context: got \(gotDescription), expected \(expectedDescription)" + case let .wrongDistanceMetric(got, expected): + "Wrong distance metric: got \(got), expected \(expected)" case let .wrongEncodingValuesCount(got, expected): "Wrong encoding values count \(got), expected \(expected)" case let .wrongMatrixPacking(got: got, expected: expected): diff --git a/Sources/PrivateNearestNeighborsSearch/PnnsProtocol.swift b/Sources/PrivateNearestNeighborsSearch/PnnsProtocol.swift new file mode 100644 index 00000000..f00e13df --- /dev/null +++ b/Sources/PrivateNearestNeighborsSearch/PnnsProtocol.swift @@ -0,0 +1,41 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import HomomorphicEncryption + +/// A nearest neighbor search query. +struct Query: Sendable { + // Encrypted query; one matrix per plaintext CRT modulus + let ciphertextMatrices: [CiphertextMatrix] +} + +/// A nearest neighbor search response. +struct Response: Sendable { + // Encrypted response; one matrix per plaintext CRT modulus + let ciphertextMatrices: [CiphertextMatrix] + // A list of entry identifiers the server computed similarities for + let entryIDs: [UInt64] + // Metadata for each entry in the database + let entryMetadatas: [[UInt8]] +} + +/// Distances from one or more query vector to the database rows. +struct DatabaseDistances: Sendable { + /// The distance from each query vector (outer dimension) to each database row (inner dimension). + let distances: Array2d + // Identifier for each entry in the database. + let entryIDs: [UInt64] + // Metadata for each entry in the database. + let entryMetadatas: [[UInt8]] +} diff --git a/Tests/HomomorphicEncryptionTests/Array2dTests.swift b/Tests/HomomorphicEncryptionTests/Array2dTests.swift index 59528a60..7cfadb5b 100644 --- a/Tests/HomomorphicEncryptionTests/Array2dTests.swift +++ b/Tests/HomomorphicEncryptionTests/Array2dTests.swift @@ -16,6 +16,23 @@ import XCTest class Array2dTests: XCTestCase { + func testInit() { + func runTest(_: T.Type) { + let data = [T](1...6) + let array = Array2d(data: data, rowCount: 3, columnCount: 2) + + let data2d: [[T]] = [[1, 2], [3, 4], [5, 6]] + XCTAssertEqual(array, Array2d(data: data2d)) + } + + runTest(Int.self) + runTest(Int32.self) + runTest(Int32.self) + runTest(Int64.self) + runTest(UInt64.self) + runTest(DWUInt128.self) + } + func testZeroAndZeroize() { func runTest(_: T.Type) { let data = [T](1...16) @@ -113,4 +130,16 @@ class Array2dTests: XCTestCase { array.append(rows: Array(40..<56)) XCTAssertEqual(array, Array2d(data: [Int](0..<56), rowCount: 7, columnCount: 8)) } + + func testMap() { + let data = [Int](0..<32) + let array = Array2d(data: data, rowCount: 4, columnCount: 8) + + let arrayPlus1 = array.map { UInt($0) + 1 } + let expected = Array2d(data: [UInt](1..<33), rowCount: 4, columnCount: 8) + XCTAssertEqual(arrayPlus1, expected) + + let roundtripArray = arrayPlus1.map { Int($0 - 1) } + XCTAssertEqual(roundtripArray, array) + } } diff --git a/Tests/HomomorphicEncryptionTests/NttTests.swift b/Tests/HomomorphicEncryptionTests/NttTests.swift index 01b80204..0122fd6c 100644 --- a/Tests/HomomorphicEncryptionTests/NttTests.swift +++ b/Tests/HomomorphicEncryptionTests/NttTests.swift @@ -52,8 +52,8 @@ final class NttTests: XCTestCase { precondition(evalData.count == rowCount) precondition(evalData[0].count == columnCount) - let coeffData = Array2d(data: coeffData.flatMap { $0 }, rowCount: rowCount, columnCount: columnCount) - let evalData = Array2d(data: evalData.flatMap { $0 }, rowCount: rowCount, columnCount: columnCount) + let coeffData = Array2d(data: coeffData) + let evalData = Array2d(data: evalData) let context = try PolyContext(degree: columnCount, moduli: moduli) let polyCoeff = PolyRq(context: context, data: coeffData) diff --git a/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift b/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift index 19c1c017..0ea27e89 100644 --- a/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift +++ b/Tests/HomomorphicEncryptionTests/RnsBaseConverterTests.swift @@ -35,8 +35,7 @@ final class RnsBaseConverterTests: XCTestCase { let referenceX = (0..(from: inputContext, to: outputContext) let data = referenceX.map { bigInt in TestUtils.crtDecompose(value: bigInt, moduli: inputContext.moduli) } - let inputData = Array2d(data: data.flatMap { $0 }, rowCount: degree, columnCount: inputContext.moduli.count) - .transposed() + let inputData = Array2d(data: data).transposed() let input = PolyRq(context: inputContext, data: inputData) let output = try rnsBaseConverter.convertApproximate(poly: input) @@ -80,7 +79,7 @@ final class RnsBaseConverterTests: XCTestCase { let poly: PolyRq = PolyRq.random(context: inputContext) let rnsBaseConverter = try RnsBaseConverter(from: inputContext, to: outputContext) - let composed: [QuadWidth] = try rnsBaseConverter.crtCompose(poly: poly, variableTime: true) + let composed: [QuadWidth] = try rnsBaseConverter.crtCompose(poly: poly) for (coeffIndex, composed) in composed.enumerated() { let roundTripValues = TestUtils.crtDecompose(value: composed, moduli: inputContext.moduli) diff --git a/Tests/HomomorphicEncryptionTests/RnsToolTests.swift b/Tests/HomomorphicEncryptionTests/RnsToolTests.swift index 3f3c7b59..71802057 100644 --- a/Tests/HomomorphicEncryptionTests/RnsToolTests.swift +++ b/Tests/HomomorphicEncryptionTests/RnsToolTests.swift @@ -82,8 +82,7 @@ final class RnsToolTests: XCTestCase { let referenceX = (0...random(in: 0..(context: inputContext, data: inputData) let output = try rnsTool.convertApproximateBskMTilde(poly: input) @@ -177,8 +176,7 @@ final class RnsToolTests: XCTestCase { let referenceX = (0...random(in: 0.. = PolyRq(context: inputContext, data: inputData) let output = try rnsTool.liftQToQBsk(poly: input) @@ -232,11 +230,7 @@ final class RnsToolTests: XCTestCase { value: bigInt, moduli: rnsTool.qBskContext.moduli) } - let inputData = Array2d( - data: data.flatMap { $0 }, - rowCount: degree, - columnCount: rnsTool.qBskContext.moduli.count) - .transposed() + let inputData = Array2d(data: data).transposed() let input: PolyRq = PolyRq(context: rnsTool.qBskContext, data: inputData) let output = try rnsTool.approximateFloor(poly: input) @@ -284,8 +278,7 @@ final class RnsToolTests: XCTestCase { let referenceX = (0...random(in: 0..(context: bskContext, data: inputData) let output = try rnsTool.convertApproximateBskToQ(poly: input) diff --git a/Tests/PrivateNearestNeighborsSearchTests/ClientTests.swift b/Tests/PrivateNearestNeighborsSearchTests/ClientTests.swift new file mode 100644 index 00000000..8350bd5d --- /dev/null +++ b/Tests/PrivateNearestNeighborsSearchTests/ClientTests.swift @@ -0,0 +1,164 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import HomomorphicEncryption +@testable import PrivateNearestNeighborsSearch +import TestUtilities +import XCTest + +final class ClientTests: XCTestCase { + func testClientConfig() throws { + func runTest(for _: Scheme.Type) throws { + let plaintextModuli = try [ + PredefinedRlweParameters.n_4096_logq_27_28_28_logt_16, + PredefinedRlweParameters.n_4096_logq_27_28_28_logt_17, + ].map { rlweParams in + try EncryptionParameters(from: rlweParams).plaintextModulus + } + // Check scaling factor increases as we add plaintext moduli. + let maxScalingFactor1 = ClientConfig.maxScalingFactor( + vectorDimension: 128, + distanceMetric: .cosineSimilarity, + plaintextModuli: Array(plaintextModuli.prefix(1))) + let maxScalingFactor2 = ClientConfig.maxScalingFactor( + vectorDimension: 128, + distanceMetric: .cosineSimilarity, + plaintextModuli: plaintextModuli) + XCTAssertGreaterThan(maxScalingFactor2, maxScalingFactor1) + + XCTAssertNoThrow( + try ClientConfig( + encryptionParams: EncryptionParameters(from: PredefinedRlweParameters.n_4096_logq_27_28_28_logt_17), + scalingFactor: maxScalingFactor2, + queryPacking: .denseRow, + vectorDimension: 128, + evaluationKeyConfig: EvaluationKeyConfiguration(), + distanceMetric: .cosineSimilarity, + extraPlaintextModuli: [plaintextModuli[1]])) + } + + try runTest(for: Bfv.self) + try runTest(for: Bfv.self) + } + + func testNormalizeRowsAndScale() throws { + struct TestCase { + let scalingFactor: Float + let norm: Array2d.Norm + let input: [[Float]] + let normalized: [[Float]] + let scaled: [[Float]] + let rounded: [[T]] + } + + func runTestCase(testCase: TestCase) throws { + let floatMatrix = Array2d(data: testCase.input) + let normalized = floatMatrix.normalizedRows(norm: testCase.norm) + for (normalized, expected) in zip(normalized.data, testCase.normalized.flatMap { $0 }) { + XCTAssertIsClose(normalized, expected) + } + + let scaled = normalized.scaled(by: testCase.scalingFactor) + for (scaled, expected) in zip(scaled.data, testCase.scaled.flatMap { $0 }) { + XCTAssertIsClose(scaled, expected) + } + let rounded: Array2d = scaled.rounded() + XCTAssertEqual(rounded.data, testCase.rounded.flatMap { $0 }) + } + + let testCases: [TestCase] = [ + TestCase(scalingFactor: 10.0, + norm: Array2d.Norm.Lp(p: 1.0), + input: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + normalized: [[1.0 / 3.0, 2.0 / 3.0], [3.0 / 7.0, 4.0 / 7.0], [5.0 / 11.0, 6.0 / 11.0]], + scaled: [[10.0 / 3.0, 20.0 / 3.0], [30.0 / 7.0, 40.0 / 7.0], [50.0 / 11.0, 60.0 / 11.0]], + rounded: [[3, 7], [4, 6], [5, 5]]), + TestCase(scalingFactor: 100.0, + norm: Array2d.Norm.Lp(p: 2.0), + input: [[3.0, 4.0], [-5.0, 12.0]], + normalized: [[3.0 / 5.0, 4.0 / 5.0], [-5.0 / 13.0, 12.0 / 13.0]], + scaled: [[300.0 / 5.0, 400.0 / 5.0], [-500.0 / 13.0, 1200.0 / 13.0]], + rounded: [[60, 80], [-38, 92]]), + ] + for testCase in testCases { + try runTestCase(testCase: testCase) + } + } + + func testQuery() throws { + func runTest(for _: Scheme.Type) throws { + let degree = 512 + let encryptionParams = try EncryptionParameters( + polyDegree: degree, + plaintextModulus: Scheme.Scalar.generatePrimes( + significantBitCounts: [16], + preferringSmall: true, + nttDegree: degree)[0], + coefficientModuli: Scheme.Scalar.generatePrimes( + significantBitCounts: [27, 28, 28], + preferringSmall: false, + nttDegree: degree), + errorStdDev: .stdDev32, + securityLevel: .unchecked) + XCTAssert(encryptionParams.supportsSimdEncoding) + let context = try Context(encryptionParameters: encryptionParams) + let vectorDimension = 32 + let queryDimensions = try MatrixDimensions(rowCount: 1, columnCount: vectorDimension) + + let encodeValues: [[Scheme.Scalar]] = increasingData( + dimensions: queryDimensions, + modulus: context.plaintextModulus) + let queryValues: Array2d = Array2d(data: encodeValues).map { value in Float(value) } + let secretKey = try context.generateSecretKey() + let scalingFactor = 100 + + for extraPlaintextModuli in try [[], Scheme.Scalar.generatePrimes( + significantBitCounts: [17], + preferringSmall: true, nttDegree: degree)] + { + let config = ClientConfig( + encryptionParams: encryptionParams, + scalingFactor: scalingFactor, + queryPacking: .denseRow, + vectorDimension: vectorDimension, + evaluationKeyConfig: EvaluationKeyConfiguration(), + distanceMetric: .cosineSimilarity, + extraPlaintextModuli: extraPlaintextModuli) + let client = try Client(config: config) + let query = try client.generateQuery(vectors: queryValues, using: secretKey) + XCTAssertEqual(query.ciphertextMatrices.count, config.plaintextModuli.count) + + let entryIds = [UInt64(42)] + let entryMetadatas = [42.littleEndianBytes] + let response = Response( + ciphertextMatrices: query.ciphertextMatrices, + entryIDs: entryIds, entryMetadatas: entryMetadatas) + let databaseDistances = try client.decrypt(response: response, using: secretKey) + XCTAssertEqual(databaseDistances.entryIDs, entryIds) + XCTAssertEqual(databaseDistances.entryMetadatas, entryMetadatas) + + let scaledQuery: Array2d = queryValues + .normalizedRows(norm: Array2d.Norm.Lp(p: 2.0)).scaled(by: Float(config.scalingFactor)) + .rounded() + // Cosine similarity response returns result scaled by scalingFactor^2 + let expectedDistances = scaledQuery.map { value in + Float(value) / Float(config.scalingFactor * config.scalingFactor) + } + XCTAssertEqual(databaseDistances.distances, expectedDistances) + } + } + try runTest(for: Bfv.self) + try runTest(for: Bfv.self) + } +} diff --git a/swift-homomorphic-encryption-protobuf b/swift-homomorphic-encryption-protobuf index 25e8161b..9961dae8 160000 --- a/swift-homomorphic-encryption-protobuf +++ b/swift-homomorphic-encryption-protobuf @@ -1 +1 @@ -Subproject commit 25e8161b98c5c2a3532b1338bd65dee618f12f81 +Subproject commit 9961dae813dab98131e0d91f60aac4818460764e