Skip to content

Commit ac003bb

Browse files
authored
PIR & PNNS benchmarks include {de}serialization. (#256)
1 parent 9e4c6c2 commit ac003bb

File tree

2 files changed

+101
-50
lines changed

2 files changed

+101
-50
lines changed

Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ public struct PirBenchmarkConfig<Scalar: ScalarType> {
115115

116116
extension PrivateInformationRetrieval.Response {
117117
func scaledNoiseBudget(using secretKey: Scheme.SecretKey) throws -> Int {
118-
try Int(
119-
noiseBudget(using: secretKey, variableTime: true) * Double(
120-
noiseBudgetScale))
118+
try Int(noiseBudget(using: secretKey, variableTime: true) * Double(noiseBudgetScale))
121119
}
122120
}
123121

@@ -178,29 +176,26 @@ public func pirProcessBenchmark<PirUtil: PirUtilProtocol>(
178176
struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
179177
where Server.Scheme == Client.Scheme
180178
{
179+
typealias Scheme = Server.Scheme
181180
let processedDatabase: Server.Database
182181
let server: Server
183182
let client: Client
184-
let secretKey: SecretKey<Client.Scheme>
185-
let evaluationKey: Server.Scheme.EvaluationKey
186-
let query: Client.Query
183+
let context: Scheme.Context
187184
let evaluationKeySize: Int
188185
let evaluationKeyCount: Int
189186
let querySize: Int
190187
let queryCiphertextCount: Int
191188
let responseSize: Int
192189
let responseCiphertextCount: Int
193-
let noiseBudget: Int
194190

195191
init(
196192
server _: Server.Type,
197193
client _: Client.Type,
198194
pirConfig: IndexPirConfig,
199195
encryptionConfig: EncryptionParametersConfig) async throws
200196
{
201-
let encryptParameter: EncryptionParameters<Server.Scheme.Scalar> =
202-
try EncryptionParameters(from: encryptionConfig)
203-
let context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
197+
let encryptParameter: EncryptionParameters<Scheme.Scalar> = try EncryptionParameters(from: encryptionConfig)
198+
self.context = try Scheme.Context(encryptionParameters: encryptParameter)
204199
let indexPirParameters = Server.generateParameter(config: pirConfig, with: context)
205200
let database = getDatabaseForTesting(
206201
numberOfEntries: pirConfig.entryCount,
@@ -209,9 +204,8 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
209204

210205
self.server = try Server(parameter: indexPirParameters, context: context, database: processedDatabase)
211206
self.client = Client(parameter: indexPirParameters, context: context)
212-
self.secretKey = try context.generateSecretKey()
213-
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
214-
self.query = try client.generateQuery(at: [0], using: secretKey)
207+
let secretKey = try context.generateSecretKey()
208+
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
215209

216210
// Validate correctness
217211
let queryIndex = Int.random(in: 0..<pirConfig.entryCount)
@@ -228,7 +222,6 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
228222
self.queryCiphertextCount = query.ciphertexts.count
229223
self.responseSize = try response.size()
230224
self.responseCiphertextCount = response.ciphertexts.count
231-
self.noiseBudget = try response.scaledNoiseBudget(using: secretKey)
232225
}
233226
}
234227

@@ -238,6 +231,7 @@ public func indexPirBenchmark<PirUtil: PirUtilProtocol>(
238231
// swiftlint:disable:next force_try
239232
config: PirBenchmarkConfig<PirUtil.Scheme.Scalar> = try! .init()) -> () -> Void
240233
{
234+
// swiftlint:disable:next closure_body_length
241235
{
242236
let benchmarkName = [
243237
"IndexPir",
@@ -251,27 +245,45 @@ public func indexPirBenchmark<PirUtil: PirUtilProtocol>(
251245
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
252246
benchmark,
253247
benchmarkContext: IndexPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>) in
248+
let context = benchmarkContext.context
254249
for _ in benchmark.scaledIterations {
255-
try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query,
256-
using: benchmarkContext
257-
.evaluationKey))
250+
let secretKey = try context.generateSecretKey()
251+
let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey)
252+
let queryIndex = Int.random(in: 0..<benchmarkContext.server.parameter.entryCount)
253+
let query = try benchmarkContext.client.generateQuery(at: [queryIndex], using: secretKey)
254+
let serializedQuery = try query.proto()
255+
let serializedEvaluationKey = evaluationKey.serialize().proto()
256+
257+
benchmark.startMeasurement()
258+
259+
let deserializedQuery: Query<PirUtil.Scheme> = try serializedQuery.native(context: context)
260+
let deserializedEvalKey: PirUtil.Scheme.EvaluationKey = try serializedEvaluationKey
261+
.native(context: context)
262+
let response = try await benchmarkContext.server.computeResponse(
263+
to: deserializedQuery,
264+
using: deserializedEvalKey)
265+
try blackHole(response.proto())
266+
267+
benchmark.stopMeasurement()
268+
269+
let noiseBudget = try response.scaledNoiseBudget(using: secretKey)
270+
benchmark.measurement(.noiseBudget, noiseBudget)
258271
}
272+
259273
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
260274
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
261275
benchmark.measurement(.querySize, benchmarkContext.querySize)
262276
benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount)
263277
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
264278
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
265-
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
266-
}
267-
// swiftlint:enable closure_parameter_position
268-
setup: {
279+
} setup: {
269280
try await IndexPirBenchmarkContext(
270281
server: MulPirServer<PirUtil>.self,
271282
client: MulPirClient<PirUtil>.self,
272283
pirConfig: config.indexPirConfig,
273284
encryptionConfig: config.encryptionConfig)
274285
}
286+
// swiftlint:enable closure_parameter_position
275287
}
276288
}
277289

@@ -280,23 +292,21 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
280292
{
281293
typealias Server = KeywordPirServer<IndexServer>
282294
typealias Client = KeywordPirClient<IndexClient>
295+
typealias Scheme = IndexServer.Scheme
283296
let server: Server
284297
let client: Client
285-
let secretKey: SecretKey<Client.Scheme>
286-
let evaluationKey: Server.Scheme.EvaluationKey
287-
let query: Client.Query
298+
let context: Scheme.Context
288299
let evaluationKeySize: Int
289300
let evaluationKeyCount: Int
290301
let querySize: Int
291302
let queryCiphertextCount: Int
292303
let responseSize: Int
293304
let responseCiphertextCount: Int
294-
let noiseBudget: Int
295305

296-
init(config: PirBenchmarkConfig<Server.Scheme.Scalar>) async throws {
297-
let encryptParameter: EncryptionParameters<Server.Scheme.Scalar> =
306+
init(config: PirBenchmarkConfig<Scheme.Scalar>) async throws {
307+
let encryptParameter: EncryptionParameters<Scheme.Scalar> =
298308
try EncryptionParameters(from: config.encryptionConfig)
299-
let context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
309+
self.context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
300310
let rows = (0..<config.databaseConfig.entryCount).map { index in KeywordValuePair(
301311
keyword: [UInt8](String(index).utf8),
302312
value: (0..<config.databaseConfig.entrySizeInBytes).map { _ in UInt8.random(in: 0..<UInt8.max) })
@@ -336,9 +346,8 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
336346
keywordParameter: keywordPirConfig.parameter,
337347
pirParameter: processed.pirParameter,
338348
context: context)
339-
self.secretKey = try context.generateSecretKey()
340-
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
341-
self.query = try client.generateQuery(at: [UInt8]("0".utf8), using: secretKey)
349+
let secretKey = try context.generateSecretKey()
350+
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
342351

343352
// Validate correctness
344353
let queryIndex = Int.random(in: 0..<config.databaseConfig.entryCount)
@@ -361,7 +370,6 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
361370
self.queryCiphertextCount = query.ciphertexts.count
362371
self.responseSize = try response.size()
363372
self.responseCiphertextCount = response.ciphertexts.count
364-
self.noiseBudget = try response.scaledNoiseBudget(using: secretKey)
365373
}
366374
}
367375

@@ -371,6 +379,7 @@ public func keywordPirBenchmark<PirUtil: PirUtilProtocol>(
371379
// swiftlint:disable:next force_try
372380
config: PirBenchmarkConfig<PirUtil.Scheme.Scalar> = try! .init()) -> () -> Void
373381
{
382+
// swiftlint:disable:next closure_body_length
374383
{
375384
let benchmarkName = [
376385
"KeywordPir",
@@ -380,21 +389,46 @@ public func keywordPirBenchmark<PirUtil: PirUtilProtocol>(
380389
"entrySize=\(config.databaseConfig.entrySizeInBytes)",
381390
"keyCompression=\(config.keywordPirConfig.keyCompression)",
382391
].joined(separator: "/")
383-
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { benchmark, benchmarkContext in
392+
// swiftlint:disable closure_parameter_position
393+
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
394+
benchmark,
395+
benchmarkContext: KeywordPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>) in
396+
let context = benchmarkContext.context
384397
for _ in benchmark.scaledIterations {
385-
try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query,
386-
using: benchmarkContext.evaluationKey))
398+
let secretKey = try context.generateSecretKey()
399+
let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey)
400+
let queryIndex = Int.random(in: 0..<config.databaseConfig.entryCount)
401+
let query = try benchmarkContext.client.generateQuery(
402+
at: [UInt8](String(describing: queryIndex).utf8),
403+
using: secretKey)
404+
let serializedQuery = try query.proto()
405+
let serializedEvaluationKey = evaluationKey.serialize().proto()
406+
407+
benchmark.startMeasurement()
408+
409+
let deserializedQuery: Query<PirUtil.Scheme> = try serializedQuery.native(context: context)
410+
let deserializedEvalKey: PirUtil.Scheme.EvaluationKey = try serializedEvaluationKey
411+
.native(context: context)
412+
let response = try await benchmarkContext.server.computeResponse(
413+
to: deserializedQuery,
414+
using: deserializedEvalKey)
415+
try blackHole(response.proto())
416+
417+
benchmark.stopMeasurement()
418+
419+
let noiseBudget = try response.scaledNoiseBudget(using: secretKey)
420+
benchmark.measurement(.noiseBudget, noiseBudget)
387421
}
388422
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
389423
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
390424
benchmark.measurement(.querySize, benchmarkContext.querySize)
391425
benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount)
392426
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
393427
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
394-
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
395428
} setup: {
396429
try await KeywordPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>(
397430
config: config)
398431
}
432+
// swiftlint:enable closure_parameter_position
399433
}
400434
}

Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ public func cosineSimilarityBenchmark<Scheme: HeScheme>(_: Scheme.Type,
131131
config: PnnsBenchmarkConfig = try! .init(),
132132
queryCount: Int = 1) -> () -> Void
133133
{
134+
// swiftlint:disable:next closure_body_length
134135
{
135136
let benchmarkName = [
136137
"CosineSimilarity",
@@ -145,19 +146,39 @@ public func cosineSimilarityBenchmark<Scheme: HeScheme>(_: Scheme.Type,
145146
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
146147
benchmark,
147148
benchmarkContext: PnnsBenchmarkContext<Scheme>) in
149+
let context = benchmarkContext.server.contexts[0]
150+
let vectorDimension = benchmarkContext.server.config.vectorDimension
148151
for _ in benchmark.scaledIterations {
149-
try await blackHole(
150-
benchmarkContext.server.computeResponse(
151-
to: benchmarkContext.query,
152-
using: benchmarkContext.evaluationKey))
152+
let secretKey = try context.generateSecretKey()
153+
let evaluationKey = try benchmarkContext.client.generateEvaluationKey(using: secretKey)
154+
let serializedEvaluationKey = evaluationKey.serialize().proto()
155+
let data = getDatabaseForTesting(config: PnnsDatabaseConfig(
156+
rowCount: queryCount,
157+
vectorDimension: vectorDimension))
158+
let queryVectors = Array2d(data: data.rows.map { row in row.vector })
159+
let query = try benchmarkContext.client.generateQuery(for: queryVectors, using: secretKey)
160+
let serializedQuery = try query.proto()
161+
162+
benchmark.startMeasurement()
163+
164+
let deserializedEvalKey: EvaluationKey<Scheme> = try serializedEvaluationKey.native(context: context)
165+
let deserializedQuery: Query<Scheme> = try serializedQuery.native(context: context)
166+
let response = try await benchmarkContext.server.computeResponse(
167+
to: deserializedQuery,
168+
using: deserializedEvalKey)
169+
try blackHole(response.proto())
170+
171+
benchmark.stopMeasurement()
172+
173+
let noiseBudget = try response.scaledNoiseBudget(using: secretKey)
174+
benchmark.measurement(.noiseBudget, noiseBudget)
153175
}
154176
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
155177
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
156178
benchmark.measurement(.querySize, benchmarkContext.querySize)
157179
benchmark.measurement(.queryCiphertextCount, benchmarkContext.queryCiphertextCount)
158180
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
159181
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
160-
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
161182
} setup: {
162183
try await PnnsBenchmarkContext<Scheme>(
163184
databaseConfig: config.databaseConfig,
@@ -236,16 +257,13 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
236257
let processedDatabase: ProcessedDatabase<Scheme>
237258
let server: Server<Scheme>
238259
let client: Client<Scheme>
239-
let secretKey: SecretKey<Scheme>
240-
let evaluationKey: Scheme.EvaluationKey
260+
let contexts: [Scheme.Context]
241261
let evaluationKeyCount: Int
242-
let query: Query<Scheme>
243262
let evaluationKeySize: Int
244263
let querySize: Int
245264
let queryCiphertextCount: Int
246265
let responseSize: Int
247266
let responseCiphertextCount: Int
248-
let noiseBudget: Int
249267

250268
init(databaseConfig: PnnsDatabaseConfig,
251269
encryptionConfig: EncryptionParametersConfig,
@@ -293,18 +311,18 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
293311
databasePacking: .diagonal(babyStepGiantStep: babyStepGiantStep))
294312

295313
let database = getDatabaseForTesting(config: databaseConfig)
296-
let contexts = try clientConfig.encryptionParameters
314+
self.contexts = try clientConfig.encryptionParameters
297315
.map { encryptionParameters in try Scheme.Context(encryptionParameters: encryptionParameters) }
298316
self.processedDatabase = try await database.process(config: serverConfig, contexts: contexts)
299317
self.client = try Client(config: clientConfig, contexts: contexts)
300318
self.server = try Server(database: processedDatabase)
301-
self.secretKey = try client.generateSecretKey()
302-
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
319+
let secretKey = try client.generateSecretKey()
320+
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
303321

304322
// We query exact matches from rows in the database
305323
let databaseVectors = Array2d(data: database.rows.map { row in row.vector })
306324
let queryVectors = Array2d(data: database.rows.prefix(queryCount).map { row in row.vector })
307-
self.query = try client.generateQuery(for: queryVectors, using: secretKey)
325+
let query = try client.generateQuery(for: queryVectors, using: secretKey)
308326

309327
let response = try await server.computeResponse(to: query, using: evaluationKey)
310328
let decrypted = try client.decrypt(response: response, using: secretKey)
@@ -324,6 +342,5 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
324342
self.responseSize = try response.size()
325343
self.responseCiphertextCount = response.ciphertextMatrices
326344
.map { matrix in matrix.ciphertexts.count }.sum()
327-
self.noiseBudget = try response.scaledNoiseBudget(using: secretKey)
328345
}
329346
}

0 commit comments

Comments
 (0)