diff --git a/Sources/PrivateInformationRetrieval/CuckooTable.swift b/Sources/PrivateInformationRetrieval/CuckooTable.swift index 089d5a6d..657c7d56 100644 --- a/Sources/PrivateInformationRetrieval/CuckooTable.swift +++ b/Sources/PrivateInformationRetrieval/CuckooTable.swift @@ -52,7 +52,7 @@ public struct CuckooTableConfig: Hashable, Codable, Sendable { /// /// If enabled, this setting will store only entries using the same hash function, into the same bucket. /// This can help improve PIR runtime. - let multipleTables: Bool + public let multipleTables: Bool /// Initializes a ``CuckooTableConfig``. /// - Parameters: diff --git a/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift b/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift index d3329cf0..be21a746 100644 --- a/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift +++ b/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift @@ -141,7 +141,8 @@ public final class KeywordPirServer: KeywordPirProtoc with context: Context) throws -> ProcessedDatabaseWithParameters { - let cuckooTable = try CuckooTable(config: config.cuckooTableConfig, database: database) + let cuckooTableConfig = config.cuckooTableConfig + let cuckooTable = try CuckooTable(config: cuckooTableConfig, database: database) let entryTable = try cuckooTable.serializeBuckets() let maxEntrySize: Int switch cuckooTable.config.bucketCount { @@ -151,15 +152,35 @@ public final class KeywordPirServer: KeywordPirProtoc } maxEntrySize = foundMaxEntrySize case .fixedSize: - maxEntrySize = config.cuckooTableConfig.maxSerializedBucketSize + maxEntrySize = cuckooTableConfig.maxSerializedBucketSize } + + // if we would hit the client side bug, reprocess with modified `maxSerializedBucketSize` + if maxEntrySize.isMultiple(of: context.bytesPerPlaintext) + || context.bytesPerPlaintext.isMultiple(of: maxEntrySize) + { + let newCuckooTableConfig = try CuckooTableConfig( + hashFunctionCount: cuckooTableConfig.hashFunctionCount, + maxEvictionCount: cuckooTableConfig.maxEvictionCount, + maxSerializedBucketSize: maxEntrySize - 1, + bucketCount: cuckooTableConfig.bucketCount, + multipleTables: cuckooTableConfig.multipleTables) + + let newConfig = try KeywordPirConfig( + dimensionCount: config.dimensionCount, + cuckooTableConfig: newCuckooTableConfig, + unevenDimensions: config.unevenDimensions) + return try Self.process(database: database, config: newConfig, with: context) + } + let indexPirConfig = try IndexPirConfig( entryCount: cuckooTable.bucketPerTable, entrySizeInBytes: maxEntrySize, dimensionCount: config.dimensionCount, - batchSize: config.cuckooTableConfig.hashFunctionCount, + batchSize: cuckooTableConfig.hashFunctionCount, unevenDimensions: config.unevenDimensions) let indexPirParameter = PirServer.generateParameter(config: indexPirConfig, with: context) + let processedDb = try PirServer.Database(plaintexts: stride( from: 0, to: entryTable.count, diff --git a/Sources/PrivateInformationRetrieval/MulPir.swift b/Sources/PrivateInformationRetrieval/MulPir.swift index 29150885..f293055a 100644 --- a/Sources/PrivateInformationRetrieval/MulPir.swift +++ b/Sources/PrivateInformationRetrieval/MulPir.swift @@ -63,7 +63,7 @@ public enum MulPir: IndexPirProtocol { return IndexPirParameter( entryCount: config.entryCount, - entrySizeInBytes: config.entrySizeInBytes, + entrySizeInBytes: entrySizeInBytes, dimensions: dimensions, batchSize: config.batchSize) } @@ -223,6 +223,12 @@ extension MulPirClient { bitsPerCoeff: context.plaintextModulus.log2) } + // this is a copy of the client side bug + let accessRange = computeResponseRangeInBytes(at: entryIndex) + guard accessRange.upperBound < bytes.count else { + throw PirError.validationError("Client side bug hit!") + } + return Array(bytes[computeResponseRangeInBytes(at: entryIndex)]) } } diff --git a/Tests/PrivateInformationRetrievalTests/CuckooTableTests.swift b/Tests/PrivateInformationRetrievalTests/CuckooTableTests.swift index 69d09bdb..a1fea8ed 100644 --- a/Tests/PrivateInformationRetrievalTests/CuckooTableTests.swift +++ b/Tests/PrivateInformationRetrievalTests/CuckooTableTests.swift @@ -80,9 +80,9 @@ class CuckooTableTests: XCTestCase { let cuckooTable = try CuckooTable(config: config, database: testDatabase, using: rng) let summary = CuckooTable.CuckooTableInformation( entryCount: 100, - bucketCount: 64, - emptyBucketCount: 4, - loadFactor: 0.645) + bucketCount: 80, + emptyBucketCount: 19, + loadFactor: 0.52) XCTAssertEqual(try cuckooTable.summarize(), summary) } diff --git a/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift b/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift index 7fedbd66..5996fe2f 100644 --- a/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift +++ b/Tests/PrivateInformationRetrievalTests/KeywordPirTests.swift @@ -260,7 +260,7 @@ class KeywordPirTests: XCTestCase { let cuckooConfig = try CuckooTableConfig( hashFunctionCount: 2, maxEvictionCount: 100, - maxSerializedBucketSize: valueSize * 4, + maxSerializedBucketSize: HashBucket.serializedSize(singleValueSize: valueSize) * 4, bucketCount: .allowExpansion(expansionFactor: 1.1, targetLoadFactor: 0.7)) let keywordConfig = try KeywordPirConfig( dimensionCount: 2, @@ -291,7 +291,7 @@ class KeywordPirTests: XCTestCase { processed: processed) let client = KeywordPirClient( keywordParameter: keywordConfig.parameter, - pirParameter: pirParameter, + pirParameter: processed.pirParameter, context: testContext) let secretKey = try testContext.generateSecretKey() let evaluationKey = try client.generateEvaluationKey(using: secretKey) @@ -320,6 +320,45 @@ class KeywordPirTests: XCTestCase { MulPirServer>.self, client: MulPirClient>.self) } + func testClientBugWorkaround() throws { + func runTest( + rlweParams: PredefinedRlweParameters, + server _: PirServer.Type, + client _: PirClient.Type) throws where PirServer.IndexPir == PirClient.IndexPir + { + let context: Context = try Context(encryptionParameters: .init(from: rlweParams)) + + var testRng = TestRng() + let testDatabase = PirTestUtils.getTestTable(rowCount: 1000, valueSize: 1, using: &testRng) + let config = try KeywordPirConfig( + dimensionCount: 2, + cuckooTableConfig: .defaultKeywordPir(maxSerializedBucketSize: 1024), + unevenDimensions: true) + let processed = try KeywordPirServer.process( + database: testDatabase, + config: config, + with: context) + let server = try KeywordPirServer( + context: context, + processed: processed) + let client = KeywordPirClient( + keywordParameter: config.parameter, + pirParameter: processed.pirParameter, + context: context) + let secretKey = try context.generateSecretKey() + let evaluationKey = try client.generateEvaluationKey(using: secretKey) + let query = try client.generateQuery(at: [], using: secretKey) + let response = try server.computeResponse(to: query, using: evaluationKey) + let result = try client.decrypt(response: response, at: [], using: secretKey) + XCTAssertNil(result) + } + let rlweParams = PredefinedRlweParameters.n_4096_logq_27_28_28_logt_5 + try runTest(rlweParams: rlweParams, server: + MulPirServer>.self, client: MulPirClient>.self) + try runTest(rlweParams: rlweParams, server: + MulPirServer>.self, client: MulPirClient>.self) + } + func testSharding() throws { func runTest( rlweParameters: PredefinedRlweParameters, diff --git a/Tests/PrivateInformationRetrievalTests/PirTestUtils.swift b/Tests/PrivateInformationRetrievalTests/PirTestUtils.swift index e47e13e0..745bb2e9 100644 --- a/Tests/PrivateInformationRetrievalTests/PirTestUtils.swift +++ b/Tests/PrivateInformationRetrievalTests/PirTestUtils.swift @@ -49,8 +49,10 @@ package enum PirTestUtils { return generateRandomData(size: size, using: &rng) } - static func generateRandomData(size: Int, using rng: inout some RandomNumberGenerator) -> [UInt8] { - (0.. [UInt8] { + var data = [UInt8](repeating: 0, count: size) + rng.fill(&data) + return data } static func getTestTable(rowCount: Int, valueSize: Int) -> [KeywordValuePair] { @@ -61,17 +63,21 @@ package enum PirTestUtils { static func getTestTable( rowCount: Int, valueSize: Int, - using rng: inout some RandomNumberGenerator, + using rng: inout some PseudoRandomNumberGenerator, keywordSize: Int = 30) -> [KeywordValuePair] { - var rows = [KeywordValuePair]() + precondition(rowCount > 0) + var keywords: Set = [] + var rows: [KeywordValuePair] = [] + rows.reserveCapacity(rowCount) repeat { let keyword = PirTestUtils.generateRandomData(size: keywordSize, using: &rng) - if !rows.contains(where: { existingPair in keyword == existingPair.keyword }) { - rows.append(KeywordValuePair( - keyword: keyword, - value: PirTestUtils.generateRandomData(size: valueSize, using: &rng))) + if keywords.contains(keyword) { + continue } + keywords.insert(keyword) + let value = PirTestUtils.generateRandomData(size: valueSize, using: &rng) + rows.append(KeywordValuePair(keyword: keyword, value: value)) } while rows.count < rowCount return rows }