-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding PlaintextMatrix-Vector Multiplication (#74)
- Loading branch information
1 parent
35a9493
commit df7f4f7
Showing
5 changed files
with
338 additions
and
50 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
167 changes: 167 additions & 0 deletions
167
Sources/PrivateNearestNeighborsSearch/MatrixMultiplication.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
import HomomorphicEncryption | ||
|
||
/// Pre-computed values for matrix-vector multiplication using baby-step, giant-step algorithm. | ||
/// | ||
/// - seealso: Section 6.3 of <https://eprint.iacr.org/2018/244.pdf>. | ||
public struct BabyStepGiantStep: Codable, Equatable, Hashable, Sendable { | ||
/// Dimension of the vector; "D" in the reference. | ||
public let vectorDimension: Int | ||
/// Baby step; "g" in the reference. | ||
public let babyStep: Int | ||
/// Giant step; "h" in the reference. | ||
public let giantStep: Int | ||
|
||
public init(vectorDimension: Int, babyStep: Int, giantStep: Int) { | ||
self.vectorDimension = vectorDimension | ||
self.babyStep = babyStep | ||
self.giantStep = giantStep | ||
} | ||
|
||
public init(vectorDimension: Int) { | ||
let dimension = Int32(vectorDimension).nextPowerOfTwo | ||
let babyStep = Int32(Double(dimension).squareRoot().rounded(.up)) | ||
let giantStep = dimension.dividingCeil(babyStep, variableTime: true) | ||
|
||
self.init(vectorDimension: Int(dimension), babyStep: Int(babyStep), giantStep: Int(giantStep)) | ||
} | ||
} | ||
|
||
/// Helper function to compute evaluation key used in computing multiplication with a vector. | ||
enum MatrixMultiplication { | ||
static func evaluationKeyConfig( | ||
plaintextMatrixDimensions: MatrixDimensions, | ||
encryptionParameters: EncryptionParameters<some HeScheme>) throws -> EvaluationKeyConfiguration | ||
{ | ||
let babyStepGiantStep = BabyStepGiantStep(vectorDimension: plaintextMatrixDimensions.columnCount) | ||
return try EvaluationKeyConfiguration( | ||
galoisElements: [ | ||
GaloisElement.rotatingColumns( | ||
by: -1, | ||
degree: encryptionParameters.polyDegree), | ||
GaloisElement.rotatingColumns( | ||
by: -babyStepGiantStep.babyStep, | ||
degree: encryptionParameters.polyDegree), | ||
], | ||
hasRelinearizationKey: false) | ||
} | ||
} | ||
|
||
extension PlaintextMatrix { | ||
/// Computes dot product of each row in the PlaintextMatrix with vector encrypted in `ciphertextVector`. | ||
/// | ||
/// - Parameters: | ||
/// - ciphertextVector: Encrypted dense-row packed vector. | ||
/// - evaluationKey: Evaluation key to perform BabyStepGiantStep rotations. | ||
/// - Returns: Encrypted dense-column packed vector containing dot products. | ||
/// - Throws: Error upon failure to compute the inner product. | ||
func mul( | ||
ciphertextVector: CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>, | ||
using evaluationKey: EvaluationKey<Scheme>) throws -> CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat> | ||
{ | ||
guard case .diagonal = packing else { | ||
let expectedBsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount) | ||
throw PnnsError.wrongMatrixPacking(got: packing, expected: .diagonal(babyStepGiantStep: expectedBsgs)) | ||
} | ||
guard ciphertextVector.packing == .denseRow else { | ||
throw PnnsError.wrongMatrixPacking(got: ciphertextVector.packing, expected: .denseRow) | ||
} | ||
guard ciphertextVector.context == context else { | ||
throw PnnsError.wrongContext(got: ciphertextVector.context, expected: context) | ||
} | ||
guard dimensions.columnCount == ciphertextVector.dimensions.rowCount else { | ||
throw PnnsError.invalidMatrixDimensions(ciphertextVector.dimensions) | ||
} | ||
guard ciphertextVector.columnCount == 1 else { | ||
throw PnnsError.invalidMatrixDimensions(ciphertextVector.dimensions) | ||
} | ||
guard ciphertextVector.ciphertexts.count == 1 else { | ||
throw PnnsError.wrongCiphertextCount(got: ciphertextVector.ciphertexts.count, expected: 1) | ||
} | ||
|
||
// If the plaintext data matrix is | ||
// [[1, 2, 3, 4], | ||
// [5, 6, 7, 8], | ||
// [9, 10, 11, 12], | ||
// [13, 14, 15, 16]] | ||
// it can be packed diagonally as | ||
// [[1, 6, 11, 16], | ||
// [2, 7, 12, 13], | ||
// [3, 8, 9, 14], | ||
// [4, 5, 10, 15]] | ||
// Then, performing a dot product with the encrypted vector [1, 2, 3, 4] | ||
// is done by a series of ciphertxt-plaintext multiplications, ciphertext | ||
// rotations, and ciphertext-ciphertext additions: | ||
// [1, 6, 11, 16] * [1, 2, 3, 4] => [1, 12, 33, 64] | | ||
// [2, 7, 12, 13] * [2, 3, 4, 1] => [4, 21, 48, 13] | | ||
// [3, 8, 9, 14] * [3, 4, 1, 2] => [9, 32, 9, 28] | | ||
// [4, 5, 10, 15] * [4, 1, 2, 3] => [16, 5, 20, 45] | - + -> [30, 70, 110, 150] | ||
// We extend this basic idea using baby-step giant-step logic from Section 6.3 of | ||
// https://eprint.iacr.org/2018/244.pdf. | ||
|
||
let babyStepGiantStep = BabyStepGiantStep(vectorDimension: dimensions.columnCount) | ||
|
||
// 1) Compute v_j = theta^j(v) | ||
var rotatedCiphertexts: [Scheme.EvalCiphertext] = [] | ||
rotatedCiphertexts.reserveCapacity(babyStepGiantStep.babyStep) | ||
var state = ciphertextVector.ciphertexts[0] | ||
for step in 0..<babyStepGiantStep.babyStep { | ||
try rotatedCiphertexts.append(state.convertToEvalFormat()) | ||
if step != babyStepGiantStep.babyStep - 1 { | ||
try state.rotateColumns(by: -1, using: evaluationKey) | ||
} | ||
} | ||
|
||
let resultCiphertextCount = dimensions.rowCount.dividingCeil(context.degree, variableTime: true) | ||
let zeroCiphertext: Scheme.CanonicalCiphertext = try Ciphertext.zero(context: context) | ||
var resultCiphertexts: [Scheme.CanonicalCiphertext] = Array( | ||
repeating: zeroCiphertext, | ||
count: resultCiphertextCount) | ||
|
||
for resultCiphertextIndex in 0..<resultCiphertextCount { | ||
for giantStepIndex in (0..<babyStepGiantStep.giantStep).reversed() { | ||
let plaintextCount = min( | ||
rotatedCiphertexts.count, | ||
babyStepGiantStep.vectorDimension - babyStepGiantStep.babyStep * giantStepIndex) | ||
let plaintextRows = try (0..<plaintextCount).map { j in | ||
j + babyStepGiantStep.babyStep * giantStepIndex | ||
}.map { i in | ||
let index = resultCiphertextCount * i + resultCiphertextIndex | ||
return try plaintexts[index].convertToEvalFormat() | ||
} | ||
let ciphertexts = rotatedCiphertexts[0..<plaintextRows.count] | ||
|
||
// 2) Compute w_k | ||
let innerProduct = try Scheme.innerProduct(ciphertexts: ciphertexts, plaintexts: plaintextRows) | ||
.convertToCanonicalFormat() | ||
|
||
// 3) Compute w incrementally | ||
try resultCiphertexts[resultCiphertextIndex].rotateColumns( | ||
by: -babyStepGiantStep.babyStep, | ||
using: evaluationKey) | ||
try resultCiphertexts[resultCiphertextIndex] += innerProduct | ||
} | ||
} | ||
let ciphertexMatrixDimensions = try MatrixDimensions( | ||
rowCount: resultCiphertextCount * context.encryptionParameters.polyDegree, | ||
columnCount: 1) | ||
return try CiphertextMatrix( | ||
dimensions: ciphertexMatrixDimensions, | ||
packing: .denseColumn, | ||
ciphertexts: resultCiphertexts) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
121 changes: 121 additions & 0 deletions
121
Tests/PrivateNearestNeighborsSearchTests/MatrixMultiplicationTests.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import HomomorphicEncryption | ||
@testable import PrivateNearestNeighborsSearch | ||
import TestUtilities | ||
import XCTest | ||
|
||
extension Array where Element: Collection, Element.Element: ScalarType, Element.Index == Int { | ||
typealias BaseElement = Element.Element | ||
|
||
func mul(_ vector: [BaseElement], modulus: BaseElement) throws -> [BaseElement] { | ||
map { row in | ||
precondition(row.count == vector.count) | ||
return zip(row, vector).reduce(0) { sum, multiplicands in | ||
let product = multiplicands.0.multiplyMod(multiplicands.1, modulus: modulus, variableTime: true) | ||
return sum.addMod(product, modulus: modulus) | ||
} | ||
} | ||
} | ||
} | ||
|
||
final class MatrixMultiplicationTests: XCTestCase { | ||
func testMulVector() throws { | ||
func checkProduct<Scheme: HeScheme>( | ||
_: Scheme.Type, | ||
_ plaintextRows: [[Scheme.Scalar]], | ||
_ dimensions: MatrixDimensions, | ||
_ queryValues: [Scheme.Scalar]) throws | ||
{ | ||
let encryptionParameters = try EncryptionParameters<Scheme>(from: .n_4096_logq_27_28_28_logt_16) | ||
let context = try Context(encryptionParameters: encryptionParameters) | ||
let secretKey = try context.generateSecretKey() | ||
|
||
var expected: [Scheme.Scalar] = try plaintextRows.mul( | ||
queryValues, | ||
modulus: encryptionParameters.plaintextModulus) | ||
let n = encryptionParameters.polyDegree | ||
if expected.count % n > 0 { | ||
expected += Array(repeating: 0, count: n - (expected.count % n)) | ||
} | ||
|
||
let babyStepGiantStep = BabyStepGiantStep(vectorDimension: queryValues.count) | ||
let plaintextMatrix = try PlaintextMatrix( | ||
context: context, | ||
dimensions: dimensions, | ||
packing: .diagonal(babyStepGiantStep: babyStepGiantStep), | ||
values: plaintextRows.flatMap { $0 }) | ||
|
||
let evaluationKeyConfig = try MatrixMultiplication.evaluationKeyConfig( | ||
plaintextMatrixDimensions: dimensions, | ||
encryptionParameters: encryptionParameters) | ||
let evaluationKey = try context.generateEvaluationKey( | ||
configuration: evaluationKeyConfig, | ||
using: secretKey) | ||
|
||
// Query ciphertext matrix | ||
let ciphertextDimensions = try MatrixDimensions(rowCount: queryValues.count, columnCount: 1) | ||
let ciphertextVector = try PlaintextMatrix( | ||
context: context, | ||
dimensions: ciphertextDimensions, | ||
packing: .denseRow, | ||
values: queryValues).encrypt(using: secretKey) | ||
|
||
let dotProduct = try plaintextMatrix.mul(ciphertextVector: ciphertextVector, using: evaluationKey) | ||
let expectedCiphertextsCount = dimensions.rowCount.dividingCeil( | ||
encryptionParameters.polyDegree, | ||
variableTime: true) | ||
XCTAssertEqual(dotProduct.ciphertexts.count, expectedCiphertextsCount) | ||
|
||
let resultMatrix = try dotProduct.decrypt(using: secretKey) | ||
let resultValues: [Scheme.Scalar] = try resultMatrix.unpack() | ||
XCTAssertEqual(resultValues, expected) | ||
} | ||
|
||
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws { | ||
// 6x6 | ||
var values: [[Scheme.Scalar]] = [] | ||
for i in Scheme.Scalar(1)...6 { | ||
values.append(Array(repeating: i, count: 6)) | ||
} | ||
var dimensions = try MatrixDimensions(rowCount: 6, columnCount: 6) | ||
var queryValues: [Scheme.Scalar] = Array(repeating: 2, count: 6) | ||
try checkProduct(Scheme.self, values, dimensions, queryValues) | ||
|
||
// Tall - 64x16 | ||
// values = Array(1...1024).map { $0 % 17 } | ||
dimensions = try MatrixDimensions(rowCount: 64, columnCount: 16) | ||
values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(17)) | ||
queryValues = Array(1...16) | ||
try checkProduct(Scheme.self, values, dimensions, queryValues) | ||
|
||
// Broad - 16x64 | ||
dimensions = try MatrixDimensions(rowCount: 16, columnCount: 64) | ||
values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(70)) | ||
queryValues = Array(1...64) | ||
queryValues.reverse() | ||
try checkProduct(Scheme.self, values, dimensions, queryValues) | ||
|
||
// Multiple result ciphertexts. 10240x4 | ||
dimensions = try MatrixDimensions(rowCount: 10240, columnCount: 4) | ||
values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(17)) | ||
queryValues = Array(1...4) | ||
try checkProduct(Scheme.self, values, dimensions, queryValues) | ||
} | ||
|
||
try runTest(Bfv<UInt32>.self) | ||
try runTest(Bfv<UInt64>.self) | ||
} | ||
} |
Oops, something went wrong.