1919from __future__ import division
2020from __future__ import print_function
2121
22+ import numpy as np
2223from tensor2tensor .data_generators .ops import pack_sequences_ops
2324import tensorflow .compat .v1 as tf
2425
2526
27+ def _pack_sequences_k (inputs , targets , input_max_length , target_max_length ):
28+ """Wrapper for pack_sequences_k with same interface as pack_sequences_2."""
29+ inputs = tf .convert_to_tensor (inputs , tf .int32 )
30+ targets = tf .convert_to_tensor (targets , tf .int32 )
31+ input_max_length = tf .convert_to_tensor (input_max_length , dtype = tf .int32 )
32+ target_max_length = tf .convert_to_tensor (target_max_length , dtype = tf .int32 )
33+ (packed , segmentation , position ) = pack_sequences_ops .pack_sequences_k (
34+ [inputs , targets ], [input_max_length , target_max_length ])
35+ (inputs_packed , targets_packed ) = packed
36+ (inputs_segmentation , targets_segmentation ) = segmentation
37+ (inputs_position , targets_position ) = position
38+ return (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
39+ targets_segmentation , targets_position )
40+
41+
2642class PackSequencesOpsTest (tf .test .TestCase ):
2743
28- def test_pack_sequences2 (self ):
44+ def do_test_pack_sequences_length3 (self , pack_fn ):
45+ inputs = [
46+ [1 , 2 , 3 ],
47+ [4 , 5 , 0 ],
48+ [6 , 0 , 0 ],
49+ ]
50+ targets = [
51+ [10 , 0 , 0 ],
52+ [20 , 30 , 40 ],
53+ [50 , 60 , 0 ],
54+ ]
55+ inputs_max_length = 3
56+ targets_max_length = 3
57+ (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
58+ targets_segmentation , targets_position ) = (
59+ pack_fn (inputs , targets , inputs_max_length , targets_max_length ))
60+ self .assertAllEqual (inputs_packed , [
61+ [1 , 2 , 3 ],
62+ [4 , 5 , 0 ],
63+ [6 , 0 , 0 ],
64+ ])
65+ self .assertAllEqual (inputs_segmentation , [
66+ [1 , 1 , 1 ],
67+ [1 , 1 , 0 ],
68+ [1 , 0 , 0 ],
69+ ])
70+ self .assertAllEqual (inputs_position , [
71+ [0 , 1 , 2 ],
72+ [0 , 1 , 0 ],
73+ [0 , 0 , 0 ],
74+ ])
75+ self .assertAllEqual (targets_packed , [
76+ [10 , 0 , 0 ],
77+ [20 , 30 , 40 ],
78+ [50 , 60 , 0 ],
79+ ])
80+ self .assertAllEqual (targets_segmentation , [
81+ [1 , 0 , 0 ],
82+ [1 , 1 , 1 ],
83+ [1 , 1 , 0 ],
84+ ])
85+ self .assertAllEqual (targets_position , [
86+ [0 , 0 , 0 ],
87+ [0 , 1 , 2 ],
88+ [0 , 1 , 0 ],
89+ ])
90+
91+ def do_test_pack_sequences_length4 (self , pack_fn ):
92+ inputs = [
93+ [1 , 2 , 3 ],
94+ [4 , 5 , 0 ],
95+ [6 , 0 , 0 ],
96+ ]
97+ targets = [
98+ [10 , 0 , 0 ],
99+ [20 , 30 , 40 ],
100+ [50 , 60 , 0 ],
101+ ]
102+ inputs_max_length = 4
103+ targets_max_length = 4
104+ (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
105+ targets_segmentation , targets_position ) = (
106+ pack_fn (inputs , targets , inputs_max_length , targets_max_length ))
107+ self .assertAllEqual (inputs_packed , [
108+ [1 , 2 , 3 , 6 ],
109+ [4 , 5 , 0 , 0 ],
110+ ])
111+ self .assertAllEqual (inputs_segmentation , [
112+ [1 , 1 , 1 , 2 ],
113+ [1 , 1 , 0 , 0 ],
114+ ])
115+ self .assertAllEqual (inputs_position , [
116+ [0 , 1 , 2 , 0 ],
117+ [0 , 1 , 0 , 0 ],
118+ ])
119+ self .assertAllEqual (targets_packed , [
120+ [10 , 50 , 60 , 0 ],
121+ [20 , 30 , 40 , 0 ],
122+ ])
123+ self .assertAllEqual (targets_segmentation , [
124+ [1 , 2 , 2 , 0 ],
125+ [1 , 1 , 1 , 0 ],
126+ ])
127+ self .assertAllEqual (targets_position , [
128+ [0 , 0 , 1 , 0 ],
129+ [0 , 1 , 2 , 0 ],
130+ ])
131+
132+ def do_test_pack_sequences_length5 (self , pack_fn ):
29133 inputs = [
30134 [1 , 2 , 3 ],
31135 [4 , 5 , 0 ],
@@ -37,10 +141,9 @@ def test_pack_sequences2(self):
37141 [50 , 60 , 0 ],
38142 ]
39143 max_length = 5
40- (inputs_packed , inputs_segmentation , inputs_position ,
41- targets_packed , targets_segmentation , targets_position ) = (
42- pack_sequences_ops .pack_sequences2 (
43- inputs , targets , max_length , max_length ))
144+ (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
145+ targets_segmentation , targets_position ) = (
146+ pack_fn (inputs , targets , max_length , max_length ))
44147 self .assertAllEqual (
45148 inputs_packed , [
46149 [1 , 2 , 3 , 4 , 5 ],
@@ -72,53 +175,177 @@ def test_pack_sequences2(self):
72175 [0 , 1 , 0 , 0 , 0 ],
73176 ])
74177
75- def test_pack_sequences_k (self ):
76- inputs = tf . convert_to_tensor ( [
178+ def do_test_pack_sequences_length6 (self , pack_fn ):
179+ inputs = [
77180 [1 , 2 , 3 ],
78181 [4 , 5 , 0 ],
79182 [6 , 0 , 0 ],
80- ], dtype = tf . int32 )
81- targets = tf . convert_to_tensor ( [
183+ ]
184+ targets = [
82185 [10 , 0 , 0 ],
83186 [20 , 30 , 40 ],
84187 [50 , 60 , 0 ],
85- ], dtype = tf .int32 )
86- max_length = tf .convert_to_tensor (5 , dtype = tf .int32 )
87- (packed , segmentation , position ) = pack_sequences_ops .pack_sequences_k (
88- [inputs , targets ], [max_length , max_length ])
89- (inputs_packed , targets_packed ) = packed
90- (inputs_segmentation , targets_segmentation ) = segmentation
91- (inputs_position , targets_position ) = position
92- self .assertAllEqual (
93- inputs_packed , [
94- [1 , 2 , 3 , 4 , 5 ],
95- [6 , 0 , 0 , 0 , 0 ],
96- ])
97- self .assertAllEqual (
98- inputs_segmentation , [
99- [1 , 1 , 1 , 2 , 2 ],
100- [1 , 0 , 0 , 0 , 0 ],
101- ])
102- self .assertAllEqual (
103- inputs_position , [
104- [0 , 1 , 2 , 0 , 1 ],
105- [0 , 0 , 0 , 0 , 0 ],
106- ])
107- self .assertAllEqual (
108- targets_packed , [
109- [10 , 20 , 30 , 40 , 0 ],
110- [50 , 60 , 0 , 0 , 0 ],
111- ])
112- self .assertAllEqual (
113- targets_segmentation , [
114- [1 , 2 , 2 , 2 , 0 ],
115- [1 , 1 , 0 , 0 , 0 ],
116- ])
117- self .assertAllEqual (
118- targets_position , [
119- [0 , 0 , 1 , 2 , 0 ],
120- [0 , 1 , 0 , 0 , 0 ],
121- ])
188+ ]
189+ max_length = 6
190+ (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
191+ targets_segmentation , targets_position ) = (
192+ pack_fn (inputs , targets , max_length , max_length ))
193+ self .assertAllEqual (inputs_packed , [
194+ [1 , 2 , 3 , 4 , 5 , 6 ],
195+ ])
196+ self .assertAllEqual (inputs_segmentation , [
197+ [1 , 1 , 1 , 2 , 2 , 3 ],
198+ ])
199+ self .assertAllEqual (inputs_position , [
200+ [0 , 1 , 2 , 0 , 1 , 0 ],
201+ ])
202+ self .assertAllEqual (targets_packed , [
203+ [10 , 20 , 30 , 40 , 50 , 60 ],
204+ ])
205+ self .assertAllEqual (targets_segmentation , [
206+ [1 , 2 , 2 , 2 , 3 , 3 ],
207+ ])
208+ self .assertAllEqual (targets_position , [
209+ [0 , 0 , 1 , 2 , 0 , 1 ],
210+ ])
211+
212+ def do_test_pack_sequences_length7 (self , pack_fn ):
213+ inputs = [
214+ [1 , 2 , 3 ],
215+ [4 , 5 , 0 ],
216+ [6 , 0 , 0 ],
217+ ]
218+ targets = [
219+ [10 , 0 , 0 ],
220+ [20 , 30 , 40 ],
221+ [50 , 60 , 0 ],
222+ ]
223+ max_length = 7
224+ (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
225+ targets_segmentation , targets_position ) = (
226+ pack_fn (inputs , targets , max_length , max_length ))
227+ self .assertAllEqual (inputs_packed , [
228+ [1 , 2 , 3 , 4 , 5 , 6 , 0 ],
229+ ])
230+ self .assertAllEqual (inputs_segmentation , [
231+ [1 , 1 , 1 , 2 , 2 , 3 , 0 ],
232+ ])
233+ self .assertAllEqual (inputs_position , [
234+ [0 , 1 , 2 , 0 , 1 , 0 , 0 ],
235+ ])
236+ self .assertAllEqual (targets_packed , [
237+ [10 , 20 , 30 , 40 , 50 , 60 , 0 ],
238+ ])
239+ self .assertAllEqual (targets_segmentation , [
240+ [1 , 2 , 2 , 2 , 3 , 3 , 0 ],
241+ ])
242+ self .assertAllEqual (targets_position , [
243+ [0 , 0 , 1 , 2 , 0 , 1 , 0 ],
244+ ])
245+
246+ def do_test_pack_sequences_length_different_lengths (self , pack_fn ):
247+ inputs = [
248+ [1 , 2 , 3 ],
249+ [4 , 5 , 0 ],
250+ [6 , 0 , 0 ],
251+ ]
252+ targets = [
253+ [10 , 0 , 0 ],
254+ [20 , 30 , 40 ],
255+ [50 , 60 , 0 ],
256+ ]
257+ input_max_length = 3
258+ target_max_length = 4
259+ (inputs_packed , inputs_segmentation , inputs_position , targets_packed ,
260+ targets_segmentation , targets_position ) = (
261+ pack_fn (inputs , targets , input_max_length , target_max_length ))
262+ self .assertAllEqual (inputs_packed , [
263+ [1 , 2 , 3 ],
264+ [4 , 5 , 0 ],
265+ [6 , 0 , 0 ],
266+ ])
267+ self .assertAllEqual (inputs_segmentation , [
268+ [1 , 1 , 1 ],
269+ [1 , 1 , 0 ],
270+ [1 , 0 , 0 ],
271+ ])
272+ self .assertAllEqual (inputs_position , [
273+ [0 , 1 , 2 ],
274+ [0 , 1 , 0 ],
275+ [0 , 0 , 0 ],
276+ ])
277+ self .assertAllEqual (targets_packed , [
278+ [10 , 0 , 0 , 0 ],
279+ [20 , 30 , 40 , 0 ],
280+ [50 , 60 , 0 , 0 ],
281+ ])
282+ self .assertAllEqual (targets_segmentation , [
283+ [1 , 0 , 0 , 0 ],
284+ [1 , 1 , 1 , 0 ],
285+ [1 , 1 , 0 , 0 ],
286+ ])
287+ self .assertAllEqual (targets_position , [
288+ [0 , 0 , 0 , 0 ],
289+ [0 , 1 , 2 , 0 ],
290+ [0 , 1 , 0 , 0 ],
291+ ])
292+
293+ def test_pack_sequences2 (self ):
294+ self .do_test_pack_sequences_length3 (pack_sequences_ops .pack_sequences2 )
295+ self .do_test_pack_sequences_length4 (pack_sequences_ops .pack_sequences2 )
296+ self .do_test_pack_sequences_length5 (pack_sequences_ops .pack_sequences2 )
297+ self .do_test_pack_sequences_length6 (pack_sequences_ops .pack_sequences2 )
298+ self .do_test_pack_sequences_length7 (pack_sequences_ops .pack_sequences2 )
299+ self .do_test_pack_sequences_length_different_lengths (
300+ pack_sequences_ops .pack_sequences2 )
301+
302+ def test_pack_sequences_k (self ):
303+ self .do_test_pack_sequences_length3 (_pack_sequences_k )
304+ self .do_test_pack_sequences_length4 (_pack_sequences_k )
305+ self .do_test_pack_sequences_length5 (_pack_sequences_k )
306+ self .do_test_pack_sequences_length6 (_pack_sequences_k )
307+ self .do_test_pack_sequences_length7 (_pack_sequences_k )
308+ self .do_test_pack_sequences_length_different_lengths (_pack_sequences_k )
309+
310+ def test_random_inputs (self ):
311+ for _ in range (10 ):
312+ batch_size = np .random .randint (900 , 1100 , size = [])
313+ input_seqlen = np .random .randint (1 , 10 , size = [])
314+ target_seqlen = np .random .randint (1 , 10 , size = [])
315+ inputs_list = []
316+ targets_list = []
317+ for _ in range (batch_size ):
318+ input_num_pads = np .random .randint (0 , input_seqlen , size = [])
319+ input_pads = np .full ([input_num_pads ], 0 , dtype = np .int32 )
320+ inputs = np .random .randint (1 , 10 , size = [input_seqlen - input_num_pads ])
321+ inputs = np .concatenate ([inputs , input_pads ], axis = 0 )
322+
323+ target_num_pads = np .random .randint (0 , target_seqlen , size = [])
324+ target_pads = np .full ([target_num_pads ], 0 , dtype = np .int32 )
325+ targets = np .random .randint (
326+ 1 , 10 , size = [target_seqlen - target_num_pads ])
327+ targets = np .concatenate ([targets , target_pads ], axis = 0 )
328+
329+ inputs_list .append (inputs )
330+ targets_list .append (targets )
331+ input_maxlen = np .random .randint (input_seqlen , input_seqlen + 10 , size = [])
332+ target_maxlen = np .random .randint (
333+ target_seqlen , target_seqlen + 10 , size = [])
334+ (inputs_packed2 , inputs_segmentation2 , inputs_positions2 , targets_packed2 ,
335+ targets_segmentation2 , targets_positions2 ) = (
336+ pack_sequences_ops .pack_sequences2 (inputs_list , targets_list ,
337+ input_maxlen , target_maxlen ))
338+ (inputs_packed_k , inputs_segmentation_k , inputs_positions_k ,
339+ targets_packed_k , targets_segmentation_k , targets_positions_k ) = (
340+ _pack_sequences_k (inputs_list , targets_list , input_maxlen ,
341+ target_maxlen ))
342+
343+ self .assertAllEqual (inputs_packed2 , inputs_packed_k )
344+ self .assertAllEqual (inputs_segmentation2 , inputs_segmentation_k )
345+ self .assertAllEqual (inputs_positions2 , inputs_positions_k )
346+ self .assertAllEqual (targets_packed2 , targets_packed_k )
347+ self .assertAllEqual (targets_segmentation2 , targets_segmentation_k )
348+ self .assertAllEqual (targets_positions2 , targets_positions_k )
122349
123350 def test_pack_sequences_k_multi_input (self ):
124351 input_tokens = tf .convert_to_tensor ([
0 commit comments