Skip to content

Commit

Permalink
Add support for packing small entries into same plaintext (#90)
Browse files Browse the repository at this point in the history
add support for multiple small entries packed into same plaintext
  • Loading branch information
RuiyuZhu authored and GitHub Enterprise committed Apr 15, 2024
1 parent f665db1 commit 84329d4
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 46 deletions.
3 changes: 2 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ let package = Package(
swiftSettings: [SwiftSetting.unsafeFlags(["-cross-module-optimization"])]),
.target(
name: "Pir",
dependencies: ["SwiftHe"],
dependencies: ["SwiftHe",
.product(name: "Numerics", package: "swift-numerics")],
swiftSettings: [SwiftSetting.unsafeFlags(["-cross-module-optimization"])]),
.target(
name: "TestUtil",
Expand Down
23 changes: 22 additions & 1 deletion Sources/Pir/IndexPirProtocol.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
import SwiftHe

public struct IndexPirConfig {
let entryCount: Int
let entrySizeInBytes: Int
let dimensionCount: Int
let polyDegree: Int
let plaintextModulusBitWidth: Int
var entrySizeInCoeffs: Int {
entrySizeInBytes * UInt8.bitWidth.divCeil(plaintextModulusBitWidth)
}
}

public struct IndexPirParameter {
let entryCount: Int
let entrySizeInBytes: Int
var entrySizeInCoeffs: Int {
entrySizeInBytes * UInt8.bitWidth.divCeil(plaintextModulusBitWidth)
}

let plaintextCount: Int
let polyDegree: Int
let plaintextModulusBitWidth: Int

let dimensions: [Int]
var dimensionCount: Int { dimensions.count }
var expandedQueryCount: Int { dimensions.sum() }
Expand All @@ -29,6 +48,8 @@ public protocol IndexPirProtocol {
typealias Query = Pir.Query<Scheme>
typealias Response = Pir.Response<Scheme>

static func generateParameter(config: IndexPirConfig) -> IndexPirParameter

static func preprocessEncodedDatabase(parameter: IndexPirParameter,
with context: Context<Scheme>,
database: [[Scheme.Scalar]]) throws -> Database
Expand Down Expand Up @@ -75,7 +96,7 @@ extension IndexPirProtocol {
{
try Array(CoefficientPacking.coefficientsToBytes(
coeffs: decryptResponse(parameter: parameter, response: response, at: queryIndex, using: secretKey),
bitsPerCoeff: response.ciphertexts[0].plaintextModulus.log2,
bitsPerCoeff: response.ciphertexts[0].context.plaintextModulus.log2,
skipLSBs: 0).prefix(parameter.entrySizeInBytes))
}
}
140 changes: 110 additions & 30 deletions Sources/Pir/MulPir.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Foundation
import Numerics
import SwiftHe

class MulPir<Scheme: HeScheme>: IndexPirProtocol {
Expand All @@ -6,19 +8,65 @@ class MulPir<Scheme: HeScheme>: IndexPirProtocol {
typealias Query = Pir.Query<Scheme>
typealias Response = Pir.Response<Scheme>
typealias CanonicalCiphertext = Scheme.CanonicalCiphertext

static func generateParameter(config: IndexPirConfig) -> IndexPirParameter {
let perChunkPlaintextCount = if config.entrySizeInCoeffs <= config.polyDegree {
config.entryCount.divCeil(config.polyDegree / config.entrySizeInCoeffs)
} else {
config.entryCount
}
let plaintextCount = perChunkPlaintextCount * config.entrySizeInCoeffs.divCeil(config.polyDegree)
let dimeonsionSize = Int(ceil(Double.root(Double(perChunkPlaintextCount), config.dimensionCount)))
let dimensions = Array(repeating: dimeonsionSize, count: config.dimensionCount)
// TODO: rdar://126490382 (Optimize MulPir's parameter) implement dimension optimizations here

return IndexPirParameter(
entryCount: config.entryCount,
entrySizeInBytes: config.entrySizeInBytes,
plaintextCount: plaintextCount,
polyDegree: config.polyDegree,
plaintextModulusBitWidth: config.plaintextModulusBitWidth,
dimensions: dimensions)
}

private static func entryChunksPerPlaintext(_ parameter: IndexPirParameter) -> Int {
let entrySizeInCoeff = parameter.entrySizeInCoeffs
if parameter.polyDegree >= entrySizeInCoeff {
return parameter.polyDegree / entrySizeInCoeff
}
return 1
}

private static func plaintextIndex(_ parameter: IndexPirParameter,
entryIndex: Int) -> Int
{
let entryPerPlaintext = entryChunksPerPlaintext(parameter)
return entryIndex / entryPerPlaintext
}

private static func plaintextCount(_ parameter: IndexPirParameter) -> Int {
if parameter.entrySizeInCoeffs > parameter.polyDegree {
return parameter.entrySizeInCoeffs.divCeil(parameter.polyDegree) * parameter.entryCount
}
return parameter.entryCount.divCeil(entryChunksPerPlaintext(parameter))
}

private static func perChunkPlaintextCount(_ parameter: IndexPirParameter) -> Int {
parameter.entryCount.divCeil(entryChunksPerPlaintext(parameter))
}
}

// MARK: query generation related function

extension MulPir {
static func computeCoordinates(parameter: IndexPirParameter, index: Int) throws -> [Int] {
static func computeCoordinates(parameter: IndexPirParameter, at index: Int) throws -> [Int] {
guard index >= 0, index < parameter.entryCount else {
throw PirError.invalidIndex(index: index, numberOfEntries: parameter.entryCount)
}
var index = index
var plaintextIndex = plaintextIndex(parameter, entryIndex: index)
return parameter.dimensions.map { dimensionSize in
let coordinate = index % dimensionSize
index /= dimensionSize
let coordinate = plaintextIndex % dimensionSize
plaintextIndex /= dimensionSize
return coordinate
}
}
Expand All @@ -29,7 +77,7 @@ extension MulPir {
at index: Int,
using secretKey: SecretKey<Scheme>) throws -> Query
{
let coordinates = try computeCoordinates(parameter: parameter, index: index)
let coordinates = try computeCoordinates(parameter: parameter, at: index)
var accumulatedCoordinate = 0
let nonZeroPositions: [Int] = parameter.dimensions.enumerated().map { dimIndex, dimSize in
let coordinate = accumulatedCoordinate + coordinates[dimIndex]
Expand Down Expand Up @@ -62,16 +110,17 @@ extension MulPir {
expandedRemainingQuery: ArraySlice<CanonicalCiphertext>,
dataChunk: ArraySlice<Plaintext<Scheme, Eval>>) throws -> CanonicalCiphertext
{
precondition(dataChunk.count == parameter.entryCount)
precondition(dataChunk.count == perChunkPlaintextCount(parameter))
var intermediateResults: [CanonicalCiphertext] = try stride(
from: dataChunk.startIndex,
to: dataChunk.endIndex,
by: parameter.dimensions[0]).map { startIndex in
let size = min(dataChunk.count - startIndex, parameter.dimensions[0])
let size = min(dataChunk.endIndex - startIndex, parameter.dimensions[0])
return try expandedDim0Query[0..<size]
.innerProduct(plaintexts: dataChunk[startIndex..<startIndex + size])
.convertToCanonicalFormat()
}

var queryStartingIndex = expandedRemainingQuery.startIndex
for dimensionIndex in 1..<parameter.dimensionCount {
let dimensionSize = parameter.dimensions[dimensionIndex]
Expand All @@ -96,11 +145,11 @@ extension MulPir {
query: Query,
using evaluationKey: EvaluationKey<Scheme>) throws -> Response
{
guard database.plaintexts.count.isMultiple(of: parameter.entryCount) else {
guard database.plaintexts.count == plaintextCount(parameter) else {
throw PirError.invalidDatabase(
description: """
database size, \(database.plaintexts.count),
must be a multiple of \(parameter.entryCount)
should be \(plaintextCount(parameter))
""")
}
let expandedQueries = try PirUtil.expandCiphertexts(
Expand All @@ -110,29 +159,46 @@ extension MulPir {
let firstDimensionQueries = try expandedQueries[0..<parameter.dimensions[0]]
.map { ciphertext in try ciphertext.convertToEvalFormat() }
let remainingQueries = expandedQueries[parameter.dimensions[0]..<parameter.expandedQueryCount]
return try Response(ciphertexts: stride(from: 0, to: database.count, by: parameter.entryCount)
.map { startingIndex in try computeResponseForOneChunk(
let step = perChunkPlaintextCount(parameter)
return try Response(ciphertexts: stride(from: 0, to: database.count, by: step)
.map { startIndex in try computeResponseForOneChunk(
parameter: parameter, expandedDim0Query: firstDimensionQueries,
expandedRemainingQuery: remainingQueries,
dataChunk: database.plaintexts[startingIndex..<parameter.entryCount])
dataChunk: database.plaintexts[startIndex..<startIndex + step])
})
}
}

// MARK: query decrypt function

extension MulPir {
private static func computeResponseRangeInCiphertext(
parameter: IndexPirParameter,
context: Context<Scheme>,
at index: Int) -> Range<Int>
{
if parameter.polyDegree <= parameter.entrySizeInCoeffs {
return 0..<parameter.polyDegree
}
let entrySizeInCoeff = parameter.entrySizeInCoeffs
let entriesInPlaintext = context.degree / entrySizeInCoeff
let postion = index % entriesInPlaintext
return postion * entrySizeInCoeff..<(postion + 1) * entrySizeInCoeff
}

static func decryptResponse(
parameter: IndexPirParameter,
response: Response,
at _: Int,
at index: Int,
using secretKey: SecretKey<Scheme>) throws -> [Scheme.Scalar]
{
let entrySizeInCoeff = parameter.entrySizeInBytes * UInt8.bitWidth
.divCeil(Int(response.ciphertexts[0].plaintextModulus.log2))
let range = computeResponseRangeInCiphertext(
parameter: parameter,
context: response.ciphertexts[0].context,
at: index)
return try response.ciphertexts.flatMap { ciphertext in try Scheme.decode(
plaintext: Scheme.decrypt(ciphertext, using: secretKey),
format: .coefficient)[0..<entrySizeInCoeff]
format: .coefficient)[range]
}
}
}
Expand All @@ -146,22 +212,36 @@ extension MulPir {
{
let entrySizeInCoeff = parameter.entrySizeInBytes * UInt8.bitWidth
.divCeil(Int(context.plaintextModulus.log2))
let numberOfChunks = entrySizeInCoeff.divCeil(context.degree)
let plaintexts: [[Plaintext<Scheme, Eval>]] = try database.map { entry in try stride(
from: 0,
to: entrySizeInCoeff,
by: context.degree).map { startIndex in
let endIndex = min(startIndex + context.degree, entry.count)
let data: [Scheme.Scalar]
if startIndex >= endIndex {
data = []
} else {
data = Array(entry[startIndex..<endIndex])
if entrySizeInCoeff >= context.degree {
let numberOfChunks = entrySizeInCoeff.divCeil(context.degree)
let plaintexts: [[Plaintext<Scheme, Eval>]] = try database.map { entry in try stride(
from: 0,
to: entrySizeInCoeff,
by: context.degree).map { startIndex in
let endIndex = min(startIndex + context.degree, entry.count)
let data: [Scheme.Scalar]
if startIndex >= endIndex {
data = []
} else {
data = Array(entry[startIndex..<endIndex])
}
return try Scheme.encode(context: context, values: data, format: .coefficient)
}
return try Scheme.encode(context: context, values: data, format: .coefficient)
}
return Database(plaintexts: (0..<numberOfChunks).flatMap { index in plaintexts.map { $0[index] }
})
}
let entriesPerPlaintext = context.degree / entrySizeInCoeff
let plaintexts: [Plaintext<Scheme, Eval>] = try stride(
from: 0,
to: parameter.entryCount,
by: entriesPerPlaintext).map { startIndex in
let endIndex = min(startIndex + entriesPerPlaintext, database.count)
return try Scheme.encode(
context: context,
values: database[startIndex..<endIndex].flatMap { $0 },
format: .coefficient)
}
return Database(plaintexts: (0..<numberOfChunks).flatMap { index in plaintexts.map { $0[index] }
})
return Database(plaintexts: plaintexts)
}
}
3 changes: 1 addition & 2 deletions Sources/SwiftHe/Ciphertext.swift
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import Foundation

public struct Ciphertext<Scheme: HeScheme, Format: HeFormat>: Equatable {
@usableFromInline let context: Context<Scheme>
public let context: Context<Scheme>
@usableFromInline var polys: [PolyRq<Scheme.Scalar, Format>]
@usableFromInline var correctionFactor: Scheme.Scalar
public var plaintextModulus: Scheme.Scalar { context.plaintextModulus }

@inlinable
init(context: Context<Scheme>, polys: [PolyRq<Scheme.Scalar, Format>], correctionFactor: Scheme.Scalar) {
Expand Down
31 changes: 29 additions & 2 deletions Tests/PirTests/IndexPirTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ class IndexPirTests: XCTestCase {
}
}

private func indexPirTest<PIR: IndexPirProtocol>(pir _: PIR.Type) throws {
let parameter = PirTestUtils.getTestParameter()
private func indexPirTestForParameter<PIR: IndexPirProtocol>(
pir _: PIR.Type,
for parameter: IndexPirParameter) throws
{
let database = getDatabaseForTesting(
numberOfEntries: parameter.entryCount,
entrySizeInBytes: parameter.entrySizeInBytes)
Expand Down Expand Up @@ -43,6 +45,31 @@ class IndexPirTests: XCTestCase {
}
}

private func indexPirTest<PIR: IndexPirProtocol>(pir: PIR.Type) throws {
let config1 = IndexPirConfig(
entryCount: 100,
entrySizeInBytes: 1,
dimensionCount: 2,
polyDegree: TestUtils.testPolyDegree,
plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2)
let config2 = IndexPirConfig(
entryCount: 100,
entrySizeInBytes: 8,
dimensionCount: 2,
polyDegree: TestUtils.testPolyDegree,
plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2)
let config3 = IndexPirConfig(
entryCount: 100,
entrySizeInBytes: 24,
dimensionCount: 2,
polyDegree: TestUtils.testPolyDegree,
plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2)

try indexPirTestForParameter(pir: pir, for: PIR.generateParameter(config: config1))
try indexPirTestForParameter(pir: pir, for: PIR.generateParameter(config: config2))
try indexPirTestForParameter(pir: pir, for: PIR.generateParameter(config: config3))
}

func testIndexPir() throws {
try indexPirTest(pir: MulPir<NoOpScheme>.self)
try indexPirTest(pir: MulPir<Bfv<UInt32>>.self)
Expand Down
6 changes: 3 additions & 3 deletions Tests/PirTests/MulPirTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import XCTest

class MulPirTests: XCTestCase {
private func queryGenerationTest<Scheme: HeScheme>(scheme _: Scheme.Type) throws {
let parameter = PirTestUtils.getTestParameter()
let parameter = PirTestUtils.getTestParameter(pir: MulPir<NoOpScheme>.self)
let context: Context<Scheme> = try TestUtils.getTestContext()
let secretKey = try Scheme.generateSecretKey(context: context)
let galoisElements = MulPir<Scheme>.computeGaloisElements(parameter: parameter, context: context)
Expand Down Expand Up @@ -67,10 +67,10 @@ class MulPirTests: XCTestCase {
private func queryAndResponseTest<Scheme: HeScheme>(scheme _: Scheme.Type) throws {
let context: Context<Scheme> = try TestUtils.getTestContext()
let database: [[Scheme.Scalar]] = getDatabaseForTesting(
numberOfEntries: PirTestUtils.testNumberOfEntries,
numberOfEntries: PirTestUtils.testEntryCount,
entrySizeInCoefficient: PirTestUtils.testEntrySizeInCoefficient,
modulus: context.plaintextModulus)
let parameter = PirTestUtils.getTestParameter()
let parameter = PirTestUtils.getTestParameter(pir: MulPir<NoOpScheme>.self)
let secretKey = try Scheme.generateSecretKey(context: context)
let galoisElements = MulPir<Scheme>.computeGaloisElements(parameter: parameter, context: context)
let evaluationKey = try Scheme.generateEvaluationKey(
Expand Down
21 changes: 14 additions & 7 deletions Tests/PirTests/PirTestUtils.swift
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
@testable import Pir
import TestUtil

public enum PirTestUtils {
static let testNumberOfEntries = 100
static let testEntryCount = 100
static let testEntrySizeInCoefficient = 10
static let testDimensions = [16, 7]

static func getTestParameter() -> IndexPirParameter {
IndexPirParameter(
entryCount: testNumberOfEntries,
entrySizeInBytes: testEntrySizeInCoefficient,
dimensions: testDimensions)
static func getTestParameter<PIR: IndexPirProtocol>(
pir _: PIR.Type,
entrySizeInCoeff: Int = testEntrySizeInCoefficient) -> IndexPirParameter
{
let entrySizeInByte = (entrySizeInCoeff * TestUtils.testPlaintextModulus.log2).divCeil(UInt8.bitWidth)
let config = IndexPirConfig(
entryCount: testEntryCount,
entrySizeInBytes: entrySizeInByte,
dimensionCount: 2,
polyDegree: TestUtils.testPolyDegree,
plaintextModulusBitWidth: TestUtils.testPlaintextModulus.log2)
return PIR.generateParameter(config: config)
}
}

0 comments on commit 84329d4

Please sign in to comment.