55use std:: error:: Error ;
66use std:: io:: Cursor ;
77
8+ use log:: { debug, warn} ;
9+
810use prio:: vdaf:: prio3:: Prio3Sum ;
911use prio:: vdaf:: prio3:: Prio3SumVec ;
1012use thin_vec:: ThinVec ;
@@ -19,8 +21,6 @@ use types::Time;
1921
2022use prio:: codec:: Encode ;
2123use prio:: codec:: { decode_u16_items, encode_u32_items} ;
22- use prio:: flp:: types:: { Sum , SumVec } ;
23- use prio:: vdaf:: prio3:: Prio3 ;
2424use prio:: vdaf:: Client ;
2525use prio:: vdaf:: VdafError ;
2626
@@ -41,25 +41,24 @@ extern "C" {
4141 ) -> bool ;
4242}
4343
44- pub fn new_prio_u8 ( num_aggregators : u8 , bits : u32 ) -> Result < Prio3Sum , VdafError > {
44+ pub fn new_prio_sum ( num_aggregators : u8 , bits : usize ) -> Result < Prio3Sum , VdafError > {
4545 if bits > 64 {
4646 return Err ( VdafError :: Uncategorized ( format ! (
4747 "bit length ({}) exceeds limit for aggregate type (64)" ,
4848 bits
4949 ) ) ) ;
5050 }
5151
52- Prio3 :: new ( num_aggregators, Sum :: new ( bits as usize ) ? )
52+ Prio3Sum :: new_sum ( num_aggregators, bits)
5353}
5454
55- pub fn new_prio_vecu8 ( num_aggregators : u8 , len : usize ) -> Result < Prio3SumVec , VdafError > {
55+ pub fn new_prio_sumvec (
56+ num_aggregators : u8 ,
57+ len : usize ,
58+ bits : usize ,
59+ ) -> Result < Prio3SumVec , VdafError > {
5660 let chunk_length = prio:: vdaf:: prio3:: optimal_chunk_length ( 8 * len) ;
57- Prio3 :: new ( num_aggregators, SumVec :: new ( 8 , len, chunk_length) ?)
58- }
59-
60- pub fn new_prio_vecu16 ( num_aggregators : u8 , len : usize ) -> Result < Prio3SumVec , VdafError > {
61- let chunk_length = prio:: vdaf:: prio3:: optimal_chunk_length ( 16 * len) ;
62- Prio3 :: new ( num_aggregators, SumVec :: new ( 16 , len, chunk_length) ?)
61+ Prio3SumVec :: new_sum_vec ( num_aggregators, bits, len, chunk_length)
6362}
6463
6564enum Role {
@@ -112,14 +111,17 @@ impl Shardable for u8 {
112111 & self ,
113112 nonce : & [ u8 ; 16 ] ,
114113 ) -> Result < ( Vec < u8 > , Vec < Vec < u8 > > ) , Box < dyn std:: error:: Error > > {
115- let prio = new_prio_u8 ( 2 , 2 ) ?;
114+ let prio = new_prio_sum ( 2 , 8 ) ?;
116115
117116 let ( public_share, input_shares) = prio. shard ( & ( * self as u128 ) , nonce) ?;
118117
119118 debug_assert_eq ! ( input_shares. len( ) , 2 ) ;
120119
121- let encoded_input_shares = input_shares. iter ( ) . map ( |s| s. get_encoded ( ) ) . collect ( ) ;
122- let encoded_public_share = public_share. get_encoded ( ) ;
120+ let encoded_input_shares = input_shares
121+ . iter ( )
122+ . map ( |s| s. get_encoded ( ) )
123+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
124+ let encoded_public_share = public_share. get_encoded ( ) ?;
123125 Ok ( ( encoded_public_share, encoded_input_shares) )
124126 }
125127}
@@ -129,15 +131,18 @@ impl Shardable for ThinVec<u8> {
129131 & self ,
130132 nonce : & [ u8 ; 16 ] ,
131133 ) -> Result < ( Vec < u8 > , Vec < Vec < u8 > > ) , Box < dyn std:: error:: Error > > {
132- let prio = new_prio_vecu8 ( 2 , self . len ( ) ) ?;
134+ let prio = new_prio_sumvec ( 2 , self . len ( ) , 8 ) ?;
133135
134136 let measurement: Vec < u128 > = self . iter ( ) . map ( |e| ( * e as u128 ) ) . collect ( ) ;
135137 let ( public_share, input_shares) = prio. shard ( & measurement, nonce) ?;
136138
137139 debug_assert_eq ! ( input_shares. len( ) , 2 ) ;
138140
139- let encoded_input_shares = input_shares. iter ( ) . map ( |s| s. get_encoded ( ) ) . collect ( ) ;
140- let encoded_public_share = public_share. get_encoded ( ) ;
141+ let encoded_input_shares = input_shares
142+ . iter ( )
143+ . map ( |s| s. get_encoded ( ) )
144+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
145+ let encoded_public_share = public_share. get_encoded ( ) ?;
141146 Ok ( ( encoded_public_share, encoded_input_shares) )
142147 }
143148}
@@ -147,23 +152,26 @@ impl Shardable for ThinVec<u16> {
147152 & self ,
148153 nonce : & [ u8 ; 16 ] ,
149154 ) -> Result < ( Vec < u8 > , Vec < Vec < u8 > > ) , Box < dyn std:: error:: Error > > {
150- let prio = new_prio_vecu16 ( 2 , self . len ( ) ) ?;
155+ let prio = new_prio_sumvec ( 2 , self . len ( ) , 16 ) ?;
151156
152157 let measurement: Vec < u128 > = self . iter ( ) . map ( |e| ( * e as u128 ) ) . collect ( ) ;
153158 let ( public_share, input_shares) = prio. shard ( & measurement, nonce) ?;
154159
155160 debug_assert_eq ! ( input_shares. len( ) , 2 ) ;
156161
157- let encoded_input_shares = input_shares. iter ( ) . map ( |s| s. get_encoded ( ) ) . collect ( ) ;
158- let encoded_public_share = public_share. get_encoded ( ) ;
162+ let encoded_input_shares = input_shares
163+ . iter ( )
164+ . map ( |s| s. get_encoded ( ) )
165+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
166+ let encoded_public_share = public_share. get_encoded ( ) ?;
159167 Ok ( ( encoded_public_share, encoded_input_shares) )
160168 }
161169}
162170
163171/// Pre-fill the info part of the HPKE sealing with the constants from the standard.
164172fn make_base_info ( ) -> Vec < u8 > {
165173 let mut info = Vec :: < u8 > :: new ( ) ;
166- const START : & [ u8 ] = "dap-07 input share" . as_bytes ( ) ;
174+ const START : & [ u8 ] = "dap-09 input share" . as_bytes ( ) ;
167175 info. extend ( START ) ;
168176 const FIXED : u8 = 1 ;
169177 info. push ( FIXED ) ;
@@ -215,7 +223,8 @@ fn get_dap_report_internal<T: Shardable>(
215223 }
216224 . get_encoded ( )
217225 } )
218- . collect ( ) ;
226+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
227+ debug ! ( "Plaintext input shares computed." ) ;
219228
220229 let metadata = ReportMetadata {
221230 report_id,
@@ -230,18 +239,20 @@ fn get_dap_report_internal<T: Shardable>(
230239 let mut info = make_base_info ( ) ;
231240
232241 let mut aad = Vec :: from ( * task_id) ;
233- metadata. encode ( & mut aad) ;
234- encode_u32_items ( & mut aad, & ( ) , & encoded_public_share) ;
242+ metadata. encode ( & mut aad) ? ;
243+ encode_u32_items ( & mut aad, & ( ) , & encoded_public_share) ? ;
235244
236245 info. push ( Role :: Leader as u8 ) ;
237246
238247 let leader_payload =
239248 hpke_encrypt_wrapper ( & plaintext_input_shares[ 0 ] , & aad, & info, & leader_hpke_config) ?;
249+ debug ! ( "Leader payload encrypted." ) ;
240250
241251 * info. last_mut ( ) . unwrap ( ) = Role :: Helper as u8 ;
242252
243253 let helper_payload =
244254 hpke_encrypt_wrapper ( & plaintext_input_shares[ 1 ] , & aad, & info, & helper_hpke_config) ?;
255+ debug ! ( "Helper payload encrypted." ) ;
245256
246257 Ok ( Report {
247258 metadata,
@@ -264,20 +275,22 @@ pub extern "C" fn dapGetReportU8(
264275) -> bool {
265276 assert_eq ! ( task_id. len( ) , 32 ) ;
266277
267- if let Ok ( report) = get_dap_report_internal :: < u8 > (
278+ let Ok ( report) = get_dap_report_internal :: < u8 > (
268279 leader_hpke_config_encoded,
269280 helper_hpke_config_encoded,
270281 & measurement,
271282 & task_id. as_slice ( ) . try_into ( ) . unwrap ( ) ,
272283 time_precision,
273- ) {
274- let encoded_report = report. get_encoded ( ) ;
275- out_report. extend ( encoded_report) ;
276-
277- true
278- } else {
279- false
280- }
284+ ) else {
285+ warn ! ( "Creating report failed!" ) ;
286+ return false ;
287+ } ;
288+ let Ok ( encoded_report) = report. get_encoded ( ) else {
289+ warn ! ( "Encoding report failed!" ) ;
290+ return false ;
291+ } ;
292+ out_report. extend ( encoded_report) ;
293+ true
281294}
282295
283296#[ no_mangle]
@@ -291,20 +304,22 @@ pub extern "C" fn dapGetReportVecU8(
291304) -> bool {
292305 assert_eq ! ( task_id. len( ) , 32 ) ;
293306
294- if let Ok ( report) = get_dap_report_internal :: < ThinVec < u8 > > (
307+ let Ok ( report) = get_dap_report_internal :: < ThinVec < u8 > > (
295308 leader_hpke_config_encoded,
296309 helper_hpke_config_encoded,
297310 measurement,
298311 & task_id. as_slice ( ) . try_into ( ) . unwrap ( ) ,
299312 time_precision,
300- ) {
301- let encoded_report = report. get_encoded ( ) ;
302- out_report. extend ( encoded_report) ;
303-
304- true
305- } else {
306- false
307- }
313+ ) else {
314+ warn ! ( "Creating report failed!" ) ;
315+ return false ;
316+ } ;
317+ let Ok ( encoded_report) = report. get_encoded ( ) else {
318+ warn ! ( "Encoding report failed!" ) ;
319+ return false ;
320+ } ;
321+ out_report. extend ( encoded_report) ;
322+ true
308323}
309324
310325#[ no_mangle]
@@ -318,18 +333,20 @@ pub extern "C" fn dapGetReportVecU16(
318333) -> bool {
319334 assert_eq ! ( task_id. len( ) , 32 ) ;
320335
321- if let Ok ( report) = get_dap_report_internal :: < ThinVec < u16 > > (
336+ let Ok ( report) = get_dap_report_internal :: < ThinVec < u16 > > (
322337 leader_hpke_config_encoded,
323338 helper_hpke_config_encoded,
324339 measurement,
325340 & task_id. as_slice ( ) . try_into ( ) . unwrap ( ) ,
326341 time_precision,
327- ) {
328- let encoded_report = report. get_encoded ( ) ;
329- out_report. extend ( encoded_report) ;
330-
331- true
332- } else {
333- false
334- }
342+ ) else {
343+ warn ! ( "Creating report failed!" ) ;
344+ return false ;
345+ } ;
346+ let Ok ( encoded_report) = report. get_encoded ( ) else {
347+ warn ! ( "Encoding report failed!" ) ;
348+ return false ;
349+ } ;
350+ out_report. extend ( encoded_report) ;
351+ true
335352}
0 commit comments