Skip to content

Commit

Permalink
Adding PlaintextMatrix-Vector Multiplication (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshaywadia authored Aug 23, 2024
1 parent 35a9493 commit df7f4f7
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 50 deletions.
48 changes: 0 additions & 48 deletions Sources/PrivateNearestNeighborsSearch/DotProduct.swift

This file was deleted.

167 changes: 167 additions & 0 deletions Sources/PrivateNearestNeighborsSearch/MatrixMultiplication.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import HomomorphicEncryption

/// Pre-computed values for matrix-vector multiplication using baby-step, giant-step algorithm.
///
/// - seealso: Section 6.3 of <https://eprint.iacr.org/2018/244.pdf>.
public struct BabyStepGiantStep: Codable, Equatable, Hashable, Sendable {
/// Dimension of the vector; "D" in the reference.
public let vectorDimension: Int
/// Baby step; "g" in the reference.
public let babyStep: Int
/// Giant step; "h" in the reference.
public let giantStep: Int

public init(vectorDimension: Int, babyStep: Int, giantStep: Int) {
self.vectorDimension = vectorDimension
self.babyStep = babyStep
self.giantStep = giantStep
}

public init(vectorDimension: Int) {
let dimension = Int32(vectorDimension).nextPowerOfTwo
let babyStep = Int32(Double(dimension).squareRoot().rounded(.up))
let giantStep = dimension.dividingCeil(babyStep, variableTime: true)

self.init(vectorDimension: Int(dimension), babyStep: Int(babyStep), giantStep: Int(giantStep))
}
}

/// Helper function to compute evaluation key used in computing multiplication with a vector.
enum MatrixMultiplication {
static func evaluationKeyConfig(
plaintextMatrixDimensions: MatrixDimensions,
encryptionParameters: EncryptionParameters<some HeScheme>) throws -> EvaluationKeyConfiguration
{
let babyStepGiantStep = BabyStepGiantStep(vectorDimension: plaintextMatrixDimensions.columnCount)
return try EvaluationKeyConfiguration(
galoisElements: [
GaloisElement.rotatingColumns(
by: -1,
degree: encryptionParameters.polyDegree),
GaloisElement.rotatingColumns(
by: -babyStepGiantStep.babyStep,
degree: encryptionParameters.polyDegree),
],
hasRelinearizationKey: false)
}
}

extension PlaintextMatrix {
/// Computes dot product of each row in the PlaintextMatrix with vector encrypted in `ciphertextVector`.
///
/// - Parameters:
/// - ciphertextVector: Encrypted dense-row packed vector.
/// - evaluationKey: Evaluation key to perform BabyStepGiantStep rotations.
/// - Returns: Encrypted dense-column packed vector containing dot products.
/// - Throws: Error upon failure to compute the inner product.
func mul(
ciphertextVector: CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>,
using evaluationKey: EvaluationKey<Scheme>) throws -> CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>
{
guard case .diagonal = packing else {
let expectedBsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount)
throw PnnsError.wrongMatrixPacking(got: packing, expected: .diagonal(babyStepGiantStep: expectedBsgs))
}
guard ciphertextVector.packing == .denseRow else {
throw PnnsError.wrongMatrixPacking(got: ciphertextVector.packing, expected: .denseRow)
}
guard ciphertextVector.context == context else {
throw PnnsError.wrongContext(got: ciphertextVector.context, expected: context)
}
guard dimensions.columnCount == ciphertextVector.dimensions.rowCount else {
throw PnnsError.invalidMatrixDimensions(ciphertextVector.dimensions)
}
guard ciphertextVector.columnCount == 1 else {
throw PnnsError.invalidMatrixDimensions(ciphertextVector.dimensions)
}
guard ciphertextVector.ciphertexts.count == 1 else {
throw PnnsError.wrongCiphertextCount(got: ciphertextVector.ciphertexts.count, expected: 1)
}

// If the plaintext data matrix is
// [[1, 2, 3, 4],
// [5, 6, 7, 8],
// [9, 10, 11, 12],
// [13, 14, 15, 16]]
// it can be packed diagonally as
// [[1, 6, 11, 16],
// [2, 7, 12, 13],
// [3, 8, 9, 14],
// [4, 5, 10, 15]]
// Then, performing a dot product with the encrypted vector [1, 2, 3, 4]
// is done by a series of ciphertxt-plaintext multiplications, ciphertext
// rotations, and ciphertext-ciphertext additions:
// [1, 6, 11, 16] * [1, 2, 3, 4] => [1, 12, 33, 64] |
// [2, 7, 12, 13] * [2, 3, 4, 1] => [4, 21, 48, 13] |
// [3, 8, 9, 14] * [3, 4, 1, 2] => [9, 32, 9, 28] |
// [4, 5, 10, 15] * [4, 1, 2, 3] => [16, 5, 20, 45] | - + -> [30, 70, 110, 150]
// We extend this basic idea using baby-step giant-step logic from Section 6.3 of
// https://eprint.iacr.org/2018/244.pdf.

let babyStepGiantStep = BabyStepGiantStep(vectorDimension: dimensions.columnCount)

// 1) Compute v_j = theta^j(v)
var rotatedCiphertexts: [Scheme.EvalCiphertext] = []
rotatedCiphertexts.reserveCapacity(babyStepGiantStep.babyStep)
var state = ciphertextVector.ciphertexts[0]
for step in 0..<babyStepGiantStep.babyStep {
try rotatedCiphertexts.append(state.convertToEvalFormat())
if step != babyStepGiantStep.babyStep - 1 {
try state.rotateColumns(by: -1, using: evaluationKey)
}
}

let resultCiphertextCount = dimensions.rowCount.dividingCeil(context.degree, variableTime: true)
let zeroCiphertext: Scheme.CanonicalCiphertext = try Ciphertext.zero(context: context)
var resultCiphertexts: [Scheme.CanonicalCiphertext] = Array(
repeating: zeroCiphertext,
count: resultCiphertextCount)

for resultCiphertextIndex in 0..<resultCiphertextCount {
for giantStepIndex in (0..<babyStepGiantStep.giantStep).reversed() {
let plaintextCount = min(
rotatedCiphertexts.count,
babyStepGiantStep.vectorDimension - babyStepGiantStep.babyStep * giantStepIndex)
let plaintextRows = try (0..<plaintextCount).map { j in
j + babyStepGiantStep.babyStep * giantStepIndex
}.map { i in
let index = resultCiphertextCount * i + resultCiphertextIndex
return try plaintexts[index].convertToEvalFormat()
}
let ciphertexts = rotatedCiphertexts[0..<plaintextRows.count]

// 2) Compute w_k
let innerProduct = try Scheme.innerProduct(ciphertexts: ciphertexts, plaintexts: plaintextRows)
.convertToCanonicalFormat()

// 3) Compute w incrementally
try resultCiphertexts[resultCiphertextIndex].rotateColumns(
by: -babyStepGiantStep.babyStep,
using: evaluationKey)
try resultCiphertexts[resultCiphertextIndex] += innerProduct
}
}
let ciphertexMatrixDimensions = try MatrixDimensions(
rowCount: resultCiphertextCount * context.encryptionParameters.polyDegree,
columnCount: 1)
return try CiphertextMatrix(
dimensions: ciphertexMatrixDimensions,
packing: .denseColumn,
ciphertexts: resultCiphertexts)
}
}
4 changes: 2 additions & 2 deletions Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,8 @@ public struct PlaintextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
let i = (plaintexts.count - chunkIndex) / plaintextsPerColumn
let rotationStep = i.previousMultiple(of: bsgs.babyStep, variableTime: true)
let middle = min(chunk.endIndex, chunk.startIndex + n / 2)
chunk[chunk.startIndex..<middle].rotate(toStartAt: chunk.startIndex + rotationStep)
chunk[middle...].rotate(toStartAt: middle + rotationStep)
chunk[chunk.startIndex..<middle].rotate(toStartAt: middle - rotationStep)
chunk[middle...].rotate(toStartAt: chunk.endIndex - rotationStep)

let plaintext = try context.encode(values: Array(chunk), format: .simd)
plaintexts.append(plaintext)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import HomomorphicEncryption
@testable import PrivateNearestNeighborsSearch
import TestUtilities
import XCTest

extension Array where Element: Collection, Element.Element: ScalarType, Element.Index == Int {
typealias BaseElement = Element.Element

func mul(_ vector: [BaseElement], modulus: BaseElement) throws -> [BaseElement] {
map { row in
precondition(row.count == vector.count)
return zip(row, vector).reduce(0) { sum, multiplicands in
let product = multiplicands.0.multiplyMod(multiplicands.1, modulus: modulus, variableTime: true)
return sum.addMod(product, modulus: modulus)
}
}
}
}

final class MatrixMultiplicationTests: XCTestCase {
func testMulVector() throws {
func checkProduct<Scheme: HeScheme>(
_: Scheme.Type,
_ plaintextRows: [[Scheme.Scalar]],
_ dimensions: MatrixDimensions,
_ queryValues: [Scheme.Scalar]) throws
{
let encryptionParameters = try EncryptionParameters<Scheme>(from: .n_4096_logq_27_28_28_logt_16)
let context = try Context(encryptionParameters: encryptionParameters)
let secretKey = try context.generateSecretKey()

var expected: [Scheme.Scalar] = try plaintextRows.mul(
queryValues,
modulus: encryptionParameters.plaintextModulus)
let n = encryptionParameters.polyDegree
if expected.count % n > 0 {
expected += Array(repeating: 0, count: n - (expected.count % n))
}

let babyStepGiantStep = BabyStepGiantStep(vectorDimension: queryValues.count)
let plaintextMatrix = try PlaintextMatrix(
context: context,
dimensions: dimensions,
packing: .diagonal(babyStepGiantStep: babyStepGiantStep),
values: plaintextRows.flatMap { $0 })

let evaluationKeyConfig = try MatrixMultiplication.evaluationKeyConfig(
plaintextMatrixDimensions: dimensions,
encryptionParameters: encryptionParameters)
let evaluationKey = try context.generateEvaluationKey(
configuration: evaluationKeyConfig,
using: secretKey)

// Query ciphertext matrix
let ciphertextDimensions = try MatrixDimensions(rowCount: queryValues.count, columnCount: 1)
let ciphertextVector = try PlaintextMatrix(
context: context,
dimensions: ciphertextDimensions,
packing: .denseRow,
values: queryValues).encrypt(using: secretKey)

let dotProduct = try plaintextMatrix.mul(ciphertextVector: ciphertextVector, using: evaluationKey)
let expectedCiphertextsCount = dimensions.rowCount.dividingCeil(
encryptionParameters.polyDegree,
variableTime: true)
XCTAssertEqual(dotProduct.ciphertexts.count, expectedCiphertextsCount)

let resultMatrix = try dotProduct.decrypt(using: secretKey)
let resultValues: [Scheme.Scalar] = try resultMatrix.unpack()
XCTAssertEqual(resultValues, expected)
}

func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
// 6x6
var values: [[Scheme.Scalar]] = []
for i in Scheme.Scalar(1)...6 {
values.append(Array(repeating: i, count: 6))
}
var dimensions = try MatrixDimensions(rowCount: 6, columnCount: 6)
var queryValues: [Scheme.Scalar] = Array(repeating: 2, count: 6)
try checkProduct(Scheme.self, values, dimensions, queryValues)

// Tall - 64x16
// values = Array(1...1024).map { $0 % 17 }
dimensions = try MatrixDimensions(rowCount: 64, columnCount: 16)
values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(17))
queryValues = Array(1...16)
try checkProduct(Scheme.self, values, dimensions, queryValues)

// Broad - 16x64
dimensions = try MatrixDimensions(rowCount: 16, columnCount: 64)
values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(70))
queryValues = Array(1...64)
queryValues.reverse()
try checkProduct(Scheme.self, values, dimensions, queryValues)

// Multiple result ciphertexts. 10240x4
dimensions = try MatrixDimensions(rowCount: 10240, columnCount: 4)
values = increasingData(dimensions: dimensions, modulus: Scheme.Scalar(17))
queryValues = Array(1...4)
try checkProduct(Scheme.self, values, dimensions, queryValues)
}

try runTest(Bfv<UInt32>.self)
try runTest(Bfv<UInt64>.self)
}
}
Loading

0 comments on commit df7f4f7

Please sign in to comment.