Skip to content

Commit

Permalink
Client bug workaround (#11)
Browse files Browse the repository at this point in the history
* Introduce a copy of the client side bug
* Add test that triggers the client bug
* Implement the workaround
  • Loading branch information
karulont authored Jul 15, 2024
1 parent 64f30c8 commit f937776
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Sources/PrivateInformationRetrieval/CuckooTable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public struct CuckooTableConfig: Hashable, Codable, Sendable {
///
/// If enabled, this setting will store only entries using the same hash function, into the same bucket.
/// This can help improve PIR runtime.
let multipleTables: Bool
public let multipleTables: Bool

/// Initializes a ``CuckooTableConfig``.
/// - Parameters:
Expand Down
27 changes: 24 additions & 3 deletions Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ public final class KeywordPirServer<PirServer: IndexPirServer>: KeywordPirProtoc
with context: Context<Scheme>)
throws -> ProcessedDatabaseWithParameters<Scheme>
{
let cuckooTable = try CuckooTable(config: config.cuckooTableConfig, database: database)
let cuckooTableConfig = config.cuckooTableConfig
let cuckooTable = try CuckooTable(config: cuckooTableConfig, database: database)
let entryTable = try cuckooTable.serializeBuckets()
let maxEntrySize: Int
switch cuckooTable.config.bucketCount {
Expand All @@ -151,15 +152,35 @@ public final class KeywordPirServer<PirServer: IndexPirServer>: KeywordPirProtoc
}
maxEntrySize = foundMaxEntrySize
case .fixedSize:
maxEntrySize = config.cuckooTableConfig.maxSerializedBucketSize
maxEntrySize = cuckooTableConfig.maxSerializedBucketSize
}

// if we would hit the client side bug, reprocess with modified `maxSerializedBucketSize`
if maxEntrySize.isMultiple(of: context.bytesPerPlaintext)
|| context.bytesPerPlaintext.isMultiple(of: maxEntrySize)
{
let newCuckooTableConfig = try CuckooTableConfig(
hashFunctionCount: cuckooTableConfig.hashFunctionCount,
maxEvictionCount: cuckooTableConfig.maxEvictionCount,
maxSerializedBucketSize: maxEntrySize - 1,
bucketCount: cuckooTableConfig.bucketCount,
multipleTables: cuckooTableConfig.multipleTables)

let newConfig = try KeywordPirConfig(
dimensionCount: config.dimensionCount,
cuckooTableConfig: newCuckooTableConfig,
unevenDimensions: config.unevenDimensions)
return try Self.process(database: database, config: newConfig, with: context)
}

let indexPirConfig = try IndexPirConfig(
entryCount: cuckooTable.bucketPerTable,
entrySizeInBytes: maxEntrySize,
dimensionCount: config.dimensionCount,
batchSize: config.cuckooTableConfig.hashFunctionCount,
batchSize: cuckooTableConfig.hashFunctionCount,
unevenDimensions: config.unevenDimensions)
let indexPirParameter = PirServer.generateParameter(config: indexPirConfig, with: context)

let processedDb = try PirServer.Database(plaintexts: stride(
from: 0,
to: entryTable.count,
Expand Down
8 changes: 7 additions & 1 deletion Sources/PrivateInformationRetrieval/MulPir.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public enum MulPir<Scheme: HeScheme>: IndexPirProtocol {

return IndexPirParameter(
entryCount: config.entryCount,
entrySizeInBytes: config.entrySizeInBytes,
entrySizeInBytes: entrySizeInBytes,
dimensions: dimensions, batchSize: config.batchSize)
}

Expand Down Expand Up @@ -223,6 +223,12 @@ extension MulPirClient {
bitsPerCoeff: context.plaintextModulus.log2)
}

// this is a copy of the client side bug
let accessRange = computeResponseRangeInBytes(at: entryIndex)
guard accessRange.upperBound < bytes.count else {
throw PirError.validationError("Client side bug hit!")
}

return Array(bytes[computeResponseRangeInBytes(at: entryIndex)])
}
}
Expand Down
39 changes: 39 additions & 0 deletions Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,45 @@ class KeywordPirTests: XCTestCase {
MulPirServer<Bfv<UInt64>>.self, client: MulPirClient<Bfv<UInt64>>.self)
}

func testClientBugWorkaround() throws {
func runTest<PirServer: IndexPirServer, PirClient: IndexPirClient>(
rlweParams: PredefinedRlweParameters,
server _: PirServer.Type,
client _: PirClient.Type) throws where PirServer.IndexPir == PirClient.IndexPir
{
let context: Context<PirServer.Scheme> = try Context(encryptionParameters: .init(from: rlweParams))

var testRng = TestRng()
let testDatabase = PirTestUtils.getTestTable(rowCount: 1000, valueSize: 1, using: &testRng)
let config = try KeywordPirConfig(
dimensionCount: 2,
cuckooTableConfig: .defaultKeywordPir(maxSerializedBucketSize: 1024),
unevenDimensions: true)
let processed = try KeywordPirServer<PirServer>.process(
database: testDatabase,
config: config,
with: context)
let server = try KeywordPirServer<PirServer>(
context: context,
processed: processed)
let client = KeywordPirClient<PirClient>(
keywordParameter: config.parameter,
pirParameter: processed.pirParameter,
context: context)
let secretKey = try context.generateSecretKey()
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
let query = try client.generateQuery(at: [], using: secretKey)
let response = try server.computeResponse(to: query, using: evaluationKey)
let result = try client.decrypt(response: response, at: [], using: secretKey)
XCTAssertNil(result)
}
let rlweParams = PredefinedRlweParameters.n_4096_logq_27_28_28_logt_5
try runTest(rlweParams: rlweParams, server:
MulPirServer<Bfv<UInt32>>.self, client: MulPirClient<Bfv<UInt32>>.self)
try runTest(rlweParams: rlweParams, server:
MulPirServer<Bfv<UInt64>>.self, client: MulPirClient<Bfv<UInt64>>.self)
}

func testSharding() throws {
func runTest<PirServer: IndexPirServer, PirClient: IndexPirClient>(
rlweParameters: PredefinedRlweParameters,
Expand Down

0 comments on commit f937776

Please sign in to comment.