From 84329d4915aef3043f993aea74b51db39c4aa3b5 Mon Sep 17 00:00:00 2001 From: Ruiyu Zhu Date: Mon, 15 Apr 2024 12:42:21 -0700 Subject: [PATCH] Add support for packing small entries into same plaintext (#90) add support for multiple small entries packed into same plaintext --- Package.swift | 3 +- Sources/Pir/IndexPirProtocol.swift | 23 ++++- Sources/Pir/MulPir.swift | 140 ++++++++++++++++++++++------- Sources/SwiftHe/Ciphertext.swift | 3 +- Tests/PirTests/IndexPirTest.swift | 31 ++++++- Tests/PirTests/MulPirTest.swift | 6 +- Tests/PirTests/PirTestUtils.swift | 21 +++-- 7 files changed, 181 insertions(+), 46 deletions(-) diff --git a/Package.swift b/Package.swift index 86e19b48..85f4fb8b 100644 --- a/Package.swift +++ b/Package.swift @@ -27,7 +27,8 @@ let package = Package( swiftSettings: [SwiftSetting.unsafeFlags(["-cross-module-optimization"])]), .target( name: "Pir", - dependencies: ["SwiftHe"], + dependencies: ["SwiftHe", + .product(name: "Numerics", package: "swift-numerics")], swiftSettings: [SwiftSetting.unsafeFlags(["-cross-module-optimization"])]), .target( name: "TestUtil", diff --git a/Sources/Pir/IndexPirProtocol.swift b/Sources/Pir/IndexPirProtocol.swift index a2a506ba..ff687380 100644 --- a/Sources/Pir/IndexPirProtocol.swift +++ b/Sources/Pir/IndexPirProtocol.swift @@ -1,8 +1,27 @@ import SwiftHe +public struct IndexPirConfig { + let entryCount: Int + let entrySizeInBytes: Int + let dimensionCount: Int + let polyDegree: Int + let plaintextModulusBitWidth: Int + var entrySizeInCoeffs: Int { + entrySizeInBytes * UInt8.bitWidth.divCeil(plaintextModulusBitWidth) + } +} + public struct IndexPirParameter { let entryCount: Int let entrySizeInBytes: Int + var entrySizeInCoeffs: Int { + entrySizeInBytes * UInt8.bitWidth.divCeil(plaintextModulusBitWidth) + } + + let plaintextCount: Int + let polyDegree: Int + let plaintextModulusBitWidth: Int + let dimensions: [Int] var dimensionCount: Int { dimensions.count } var expandedQueryCount: Int { dimensions.sum() } @@ -29,6 +48,8 @@ public protocol IndexPirProtocol { typealias Query = Pir.Query typealias Response = Pir.Response + static func generateParameter(config: IndexPirConfig) -> IndexPirParameter + static func preprocessEncodedDatabase(parameter: IndexPirParameter, with context: Context, database: [[Scheme.Scalar]]) throws -> Database @@ -75,7 +96,7 @@ extension IndexPirProtocol { { try Array(CoefficientPacking.coefficientsToBytes( coeffs: decryptResponse(parameter: parameter, response: response, at: queryIndex, using: secretKey), - bitsPerCoeff: response.ciphertexts[0].plaintextModulus.log2, + bitsPerCoeff: response.ciphertexts[0].context.plaintextModulus.log2, skipLSBs: 0).prefix(parameter.entrySizeInBytes)) } } diff --git a/Sources/Pir/MulPir.swift b/Sources/Pir/MulPir.swift index c369b43f..e7e1d71e 100644 --- a/Sources/Pir/MulPir.swift +++ b/Sources/Pir/MulPir.swift @@ -1,3 +1,5 @@ +import Foundation +import Numerics import SwiftHe class MulPir: IndexPirProtocol { @@ -6,19 +8,65 @@ class MulPir: IndexPirProtocol { typealias Query = Pir.Query typealias Response = Pir.Response typealias CanonicalCiphertext = Scheme.CanonicalCiphertext + + static func generateParameter(config: IndexPirConfig) -> IndexPirParameter { + let perChunkPlaintextCount = if config.entrySizeInCoeffs <= config.polyDegree { + config.entryCount.divCeil(config.polyDegree / config.entrySizeInCoeffs) + } else { + config.entryCount + } + let plaintextCount = perChunkPlaintextCount * config.entrySizeInCoeffs.divCeil(config.polyDegree) + let dimeonsionSize = Int(ceil(Double.root(Double(perChunkPlaintextCount), config.dimensionCount))) + let dimensions = Array(repeating: dimeonsionSize, count: config.dimensionCount) + // TODO: rdar://126490382 (Optimize MulPir's parameter) implement dimension optimizations here + + return IndexPirParameter( + entryCount: config.entryCount, + entrySizeInBytes: config.entrySizeInBytes, + plaintextCount: plaintextCount, + polyDegree: config.polyDegree, + plaintextModulusBitWidth: config.plaintextModulusBitWidth, + dimensions: dimensions) + } + + private static func entryChunksPerPlaintext(_ parameter: IndexPirParameter) -> Int { + let entrySizeInCoeff = parameter.entrySizeInCoeffs + if parameter.polyDegree >= entrySizeInCoeff { + return parameter.polyDegree / entrySizeInCoeff + } + return 1 + } + + private static func plaintextIndex(_ parameter: IndexPirParameter, + entryIndex: Int) -> Int + { + let entryPerPlaintext = entryChunksPerPlaintext(parameter) + return entryIndex / entryPerPlaintext + } + + private static func plaintextCount(_ parameter: IndexPirParameter) -> Int { + if parameter.entrySizeInCoeffs > parameter.polyDegree { + return parameter.entrySizeInCoeffs.divCeil(parameter.polyDegree) * parameter.entryCount + } + return parameter.entryCount.divCeil(entryChunksPerPlaintext(parameter)) + } + + private static func perChunkPlaintextCount(_ parameter: IndexPirParameter) -> Int { + parameter.entryCount.divCeil(entryChunksPerPlaintext(parameter)) + } } // MARK: query generation related function extension MulPir { - static func computeCoordinates(parameter: IndexPirParameter, index: Int) throws -> [Int] { + static func computeCoordinates(parameter: IndexPirParameter, at index: Int) throws -> [Int] { guard index >= 0, index < parameter.entryCount else { throw PirError.invalidIndex(index: index, numberOfEntries: parameter.entryCount) } - var index = index + var plaintextIndex = plaintextIndex(parameter, entryIndex: index) return parameter.dimensions.map { dimensionSize in - let coordinate = index % dimensionSize - index /= dimensionSize + let coordinate = plaintextIndex % dimensionSize + plaintextIndex /= dimensionSize return coordinate } } @@ -29,7 +77,7 @@ extension MulPir { at index: Int, using secretKey: SecretKey) throws -> Query { - let coordinates = try computeCoordinates(parameter: parameter, index: index) + let coordinates = try computeCoordinates(parameter: parameter, at: index) var accumulatedCoordinate = 0 let nonZeroPositions: [Int] = parameter.dimensions.enumerated().map { dimIndex, dimSize in let coordinate = accumulatedCoordinate + coordinates[dimIndex] @@ -62,16 +110,17 @@ extension MulPir { expandedRemainingQuery: ArraySlice, dataChunk: ArraySlice>) throws -> CanonicalCiphertext { - precondition(dataChunk.count == parameter.entryCount) + precondition(dataChunk.count == perChunkPlaintextCount(parameter)) var intermediateResults: [CanonicalCiphertext] = try stride( from: dataChunk.startIndex, to: dataChunk.endIndex, by: parameter.dimensions[0]).map { startIndex in - let size = min(dataChunk.count - startIndex, parameter.dimensions[0]) + let size = min(dataChunk.endIndex - startIndex, parameter.dimensions[0]) return try expandedDim0Query[0..) throws -> Response { - guard database.plaintexts.count.isMultiple(of: parameter.entryCount) else { + guard database.plaintexts.count == plaintextCount(parameter) else { throw PirError.invalidDatabase( description: """ database size, \(database.plaintexts.count), - must be a multiple of \(parameter.entryCount) + should be \(plaintextCount(parameter)) """) } let expandedQueries = try PirUtil.expandCiphertexts( @@ -110,11 +159,12 @@ extension MulPir { let firstDimensionQueries = try expandedQueries[0.., + at index: Int) -> Range + { + if parameter.polyDegree <= parameter.entrySizeInCoeffs { + return 0..) throws -> [Scheme.Scalar] { - let entrySizeInCoeff = parameter.entrySizeInBytes * UInt8.bitWidth - .divCeil(Int(response.ciphertexts[0].plaintextModulus.log2)) + let range = computeResponseRangeInCiphertext( + parameter: parameter, + context: response.ciphertexts[0].context, + at: index) return try response.ciphertexts.flatMap { ciphertext in try Scheme.decode( plaintext: Scheme.decrypt(ciphertext, using: secretKey), - format: .coefficient)[0..]] = try database.map { entry in try stride( - from: 0, - to: entrySizeInCoeff, - by: context.degree).map { startIndex in - let endIndex = min(startIndex + context.degree, entry.count) - let data: [Scheme.Scalar] - if startIndex >= endIndex { - data = [] - } else { - data = Array(entry[startIndex..= context.degree { + let numberOfChunks = entrySizeInCoeff.divCeil(context.degree) + let plaintexts: [[Plaintext]] = try database.map { entry in try stride( + from: 0, + to: entrySizeInCoeff, + by: context.degree).map { startIndex in + let endIndex = min(startIndex + context.degree, entry.count) + let data: [Scheme.Scalar] + if startIndex >= endIndex { + data = [] + } else { + data = Array(entry[startIndex..] = try stride( + from: 0, + to: parameter.entryCount, + by: entriesPerPlaintext).map { startIndex in + let endIndex = min(startIndex + entriesPerPlaintext, database.count) + return try Scheme.encode( + context: context, + values: database[startIndex..: Equatable { - @usableFromInline let context: Context + public let context: Context @usableFromInline var polys: [PolyRq] @usableFromInline var correctionFactor: Scheme.Scalar - public var plaintextModulus: Scheme.Scalar { context.plaintextModulus } @inlinable init(context: Context, polys: [PolyRq], correctionFactor: Scheme.Scalar) { diff --git a/Tests/PirTests/IndexPirTest.swift b/Tests/PirTests/IndexPirTest.swift index 44710695..1020bd0a 100644 --- a/Tests/PirTests/IndexPirTest.swift +++ b/Tests/PirTests/IndexPirTest.swift @@ -13,8 +13,10 @@ class IndexPirTests: XCTestCase { } } - private func indexPirTest(pir _: PIR.Type) throws { - let parameter = PirTestUtils.getTestParameter() + private func indexPirTestForParameter( + pir _: PIR.Type, + for parameter: IndexPirParameter) throws + { let database = getDatabaseForTesting( numberOfEntries: parameter.entryCount, entrySizeInBytes: parameter.entrySizeInBytes) @@ -43,6 +45,31 @@ class IndexPirTests: XCTestCase { } } + private func indexPirTest(pir: PIR.Type) throws { + let config1 = IndexPirConfig( + entryCount: 100, + entrySizeInBytes: 1, + dimensionCount: 2, + polyDegree: TestUtils.testPolyDegree, + plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2) + let config2 = IndexPirConfig( + entryCount: 100, + entrySizeInBytes: 8, + dimensionCount: 2, + polyDegree: TestUtils.testPolyDegree, + plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2) + let config3 = IndexPirConfig( + entryCount: 100, + entrySizeInBytes: 24, + dimensionCount: 2, + polyDegree: TestUtils.testPolyDegree, + plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2) + + try indexPirTestForParameter(pir: pir, for: PIR.generateParameter(config: config1)) + try indexPirTestForParameter(pir: pir, for: PIR.generateParameter(config: config2)) + try indexPirTestForParameter(pir: pir, for: PIR.generateParameter(config: config3)) + } + func testIndexPir() throws { try indexPirTest(pir: MulPir.self) try indexPirTest(pir: MulPir>.self) diff --git a/Tests/PirTests/MulPirTest.swift b/Tests/PirTests/MulPirTest.swift index 704d3a28..a3321e32 100644 --- a/Tests/PirTests/MulPirTest.swift +++ b/Tests/PirTests/MulPirTest.swift @@ -5,7 +5,7 @@ import XCTest class MulPirTests: XCTestCase { private func queryGenerationTest(scheme _: Scheme.Type) throws { - let parameter = PirTestUtils.getTestParameter() + let parameter = PirTestUtils.getTestParameter(pir: MulPir.self) let context: Context = try TestUtils.getTestContext() let secretKey = try Scheme.generateSecretKey(context: context) let galoisElements = MulPir.computeGaloisElements(parameter: parameter, context: context) @@ -67,10 +67,10 @@ class MulPirTests: XCTestCase { private func queryAndResponseTest(scheme _: Scheme.Type) throws { let context: Context = try TestUtils.getTestContext() let database: [[Scheme.Scalar]] = getDatabaseForTesting( - numberOfEntries: PirTestUtils.testNumberOfEntries, + numberOfEntries: PirTestUtils.testEntryCount, entrySizeInCoefficient: PirTestUtils.testEntrySizeInCoefficient, modulus: context.plaintextModulus) - let parameter = PirTestUtils.getTestParameter() + let parameter = PirTestUtils.getTestParameter(pir: MulPir.self) let secretKey = try Scheme.generateSecretKey(context: context) let galoisElements = MulPir.computeGaloisElements(parameter: parameter, context: context) let evaluationKey = try Scheme.generateEvaluationKey( diff --git a/Tests/PirTests/PirTestUtils.swift b/Tests/PirTests/PirTestUtils.swift index 84cb83ab..f73f494c 100644 --- a/Tests/PirTests/PirTestUtils.swift +++ b/Tests/PirTests/PirTestUtils.swift @@ -1,14 +1,21 @@ @testable import Pir +import TestUtil public enum PirTestUtils { - static let testNumberOfEntries = 100 + static let testEntryCount = 100 static let testEntrySizeInCoefficient = 10 - static let testDimensions = [16, 7] - static func getTestParameter() -> IndexPirParameter { - IndexPirParameter( - entryCount: testNumberOfEntries, - entrySizeInBytes: testEntrySizeInCoefficient, - dimensions: testDimensions) + static func getTestParameter( + pir _: PIR.Type, + entrySizeInCoeff: Int = testEntrySizeInCoefficient) -> IndexPirParameter + { + let entrySizeInByte = (entrySizeInCoeff * TestUtils.testPlaintextModulus.log2).divCeil(UInt8.bitWidth) + let config = IndexPirConfig( + entryCount: testEntryCount, + entrySizeInBytes: entrySizeInByte, + dimensionCount: 2, + polyDegree: TestUtils.testPolyDegree, + plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2) + return PIR.generateParameter(config: config) } }