Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 79b920f

Browse files
0x0539 authored and copybara-github committed
Parity tests for pack_sequences_2 and pack_sequences_k.
PiperOrigin-RevId: 404790223
1 parent c22a226 commit 79b920f

File tree

1 file changed

+273
-46
lines changed

1 file changed

+273
-46
lines changed

tensor2tensor/data_generators/ops/pack_sequences_ops_test.py

Lines changed: 273 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,117 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import numpy as np
2223
from tensor2tensor.data_generators.ops import pack_sequences_ops
2324
import tensorflow.compat.v1 as tf
2425

2526

27+
def _pack_sequences_k(inputs, targets, input_max_length, target_max_length):
28+
"""Wrapper for pack_sequences_k with same interface as pack_sequences_2."""
29+
inputs = tf.convert_to_tensor(inputs, tf.int32)
30+
targets = tf.convert_to_tensor(targets, tf.int32)
31+
input_max_length = tf.convert_to_tensor(input_max_length, dtype=tf.int32)
32+
target_max_length = tf.convert_to_tensor(target_max_length, dtype=tf.int32)
33+
(packed, segmentation, position) = pack_sequences_ops.pack_sequences_k(
34+
[inputs, targets], [input_max_length, target_max_length])
35+
(inputs_packed, targets_packed) = packed
36+
(inputs_segmentation, targets_segmentation) = segmentation
37+
(inputs_position, targets_position) = position
38+
return (inputs_packed, inputs_segmentation, inputs_position, targets_packed,
39+
targets_segmentation, targets_position)
40+
41+
2642
class PackSequencesOpsTest(tf.test.TestCase):
2743

28-
def test_pack_sequences2(self):
44+
def do_test_pack_sequences_length3(self, pack_fn):
45+
inputs = [
46+
[1, 2, 3],
47+
[4, 5, 0],
48+
[6, 0, 0],
49+
]
50+
targets = [
51+
[10, 0, 0],
52+
[20, 30, 40],
53+
[50, 60, 0],
54+
]
55+
inputs_max_length = 3
56+
targets_max_length = 3
57+
(inputs_packed, inputs_segmentation, inputs_position, targets_packed,
58+
targets_segmentation, targets_position) = (
59+
pack_fn(inputs, targets, inputs_max_length, targets_max_length))
60+
self.assertAllEqual(inputs_packed, [
61+
[1, 2, 3],
62+
[4, 5, 0],
63+
[6, 0, 0],
64+
])
65+
self.assertAllEqual(inputs_segmentation, [
66+
[1, 1, 1],
67+
[1, 1, 0],
68+
[1, 0, 0],
69+
])
70+
self.assertAllEqual(inputs_position, [
71+
[0, 1, 2],
72+
[0, 1, 0],
73+
[0, 0, 0],
74+
])
75+
self.assertAllEqual(targets_packed, [
76+
[10, 0, 0],
77+
[20, 30, 40],
78+
[50, 60, 0],
79+
])
80+
self.assertAllEqual(targets_segmentation, [
81+
[1, 0, 0],
82+
[1, 1, 1],
83+
[1, 1, 0],
84+
])
85+
self.assertAllEqual(targets_position, [
86+
[0, 0, 0],
87+
[0, 1, 2],
88+
[0, 1, 0],
89+
])
90+
91+
def do_test_pack_sequences_length4(self, pack_fn):
92+
inputs = [
93+
[1, 2, 3],
94+
[4, 5, 0],
95+
[6, 0, 0],
96+
]
97+
targets = [
98+
[10, 0, 0],
99+
[20, 30, 40],
100+
[50, 60, 0],
101+
]
102+
inputs_max_length = 4
103+
targets_max_length = 4
104+
(inputs_packed, inputs_segmentation, inputs_position, targets_packed,
105+
targets_segmentation, targets_position) = (
106+
pack_fn(inputs, targets, inputs_max_length, targets_max_length))
107+
self.assertAllEqual(inputs_packed, [
108+
[1, 2, 3, 6],
109+
[4, 5, 0, 0],
110+
])
111+
self.assertAllEqual(inputs_segmentation, [
112+
[1, 1, 1, 2],
113+
[1, 1, 0, 0],
114+
])
115+
self.assertAllEqual(inputs_position, [
116+
[0, 1, 2, 0],
117+
[0, 1, 0, 0],
118+
])
119+
self.assertAllEqual(targets_packed, [
120+
[10, 50, 60, 0],
121+
[20, 30, 40, 0],
122+
])
123+
self.assertAllEqual(targets_segmentation, [
124+
[1, 2, 2, 0],
125+
[1, 1, 1, 0],
126+
])
127+
self.assertAllEqual(targets_position, [
128+
[0, 0, 1, 0],
129+
[0, 1, 2, 0],
130+
])
131+
132+
def do_test_pack_sequences_length5(self, pack_fn):
29133
inputs = [
30134
[1, 2, 3],
31135
[4, 5, 0],
@@ -37,10 +141,9 @@ def test_pack_sequences2(self):
37141
[50, 60, 0],
38142
]
39143
max_length = 5
40-
(inputs_packed, inputs_segmentation, inputs_position,
41-
targets_packed, targets_segmentation, targets_position) = (
42-
pack_sequences_ops.pack_sequences2(
43-
inputs, targets, max_length, max_length))
144+
(inputs_packed, inputs_segmentation, inputs_position, targets_packed,
145+
targets_segmentation, targets_position) = (
146+
pack_fn(inputs, targets, max_length, max_length))
44147
self.assertAllEqual(
45148
inputs_packed, [
46149
[1, 2, 3, 4, 5],
@@ -72,53 +175,177 @@ def test_pack_sequences2(self):
72175
[0, 1, 0, 0, 0],
73176
])
74177

75-
def test_pack_sequences_k(self):
76-
inputs = tf.convert_to_tensor([
178+
def do_test_pack_sequences_length6(self, pack_fn):
179+
inputs = [
77180
[1, 2, 3],
78181
[4, 5, 0],
79182
[6, 0, 0],
80-
], dtype=tf.int32)
81-
targets = tf.convert_to_tensor([
183+
]
184+
targets = [
82185
[10, 0, 0],
83186
[20, 30, 40],
84187
[50, 60, 0],
85-
], dtype=tf.int32)
86-
max_length = tf.convert_to_tensor(5, dtype=tf.int32)
87-
(packed, segmentation, position) = pack_sequences_ops.pack_sequences_k(
88-
[inputs, targets], [max_length, max_length])
89-
(inputs_packed, targets_packed) = packed
90-
(inputs_segmentation, targets_segmentation) = segmentation
91-
(inputs_position, targets_position) = position
92-
self.assertAllEqual(
93-
inputs_packed, [
94-
[1, 2, 3, 4, 5],
95-
[6, 0, 0, 0, 0],
96-
])
97-
self.assertAllEqual(
98-
inputs_segmentation, [
99-
[1, 1, 1, 2, 2],
100-
[1, 0, 0, 0, 0],
101-
])
102-
self.assertAllEqual(
103-
inputs_position, [
104-
[0, 1, 2, 0, 1],
105-
[0, 0, 0, 0, 0],
106-
])
107-
self.assertAllEqual(
108-
targets_packed, [
109-
[10, 20, 30, 40, 0],
110-
[50, 60, 0, 0, 0],
111-
])
112-
self.assertAllEqual(
113-
targets_segmentation, [
114-
[1, 2, 2, 2, 0],
115-
[1, 1, 0, 0, 0],
116-
])
117-
self.assertAllEqual(
118-
targets_position, [
119-
[0, 0, 1, 2, 0],
120-
[0, 1, 0, 0, 0],
121-
])
188+
]
189+
max_length = 6
190+
(inputs_packed, inputs_segmentation, inputs_position, targets_packed,
191+
targets_segmentation, targets_position) = (
192+
pack_fn(inputs, targets, max_length, max_length))
193+
self.assertAllEqual(inputs_packed, [
194+
[1, 2, 3, 4, 5, 6],
195+
])
196+
self.assertAllEqual(inputs_segmentation, [
197+
[1, 1, 1, 2, 2, 3],
198+
])
199+
self.assertAllEqual(inputs_position, [
200+
[0, 1, 2, 0, 1, 0],
201+
])
202+
self.assertAllEqual(targets_packed, [
203+
[10, 20, 30, 40, 50, 60],
204+
])
205+
self.assertAllEqual(targets_segmentation, [
206+
[1, 2, 2, 2, 3, 3],
207+
])
208+
self.assertAllEqual(targets_position, [
209+
[0, 0, 1, 2, 0, 1],
210+
])
211+
212+
def do_test_pack_sequences_length7(self, pack_fn):
213+
inputs = [
214+
[1, 2, 3],
215+
[4, 5, 0],
216+
[6, 0, 0],
217+
]
218+
targets = [
219+
[10, 0, 0],
220+
[20, 30, 40],
221+
[50, 60, 0],
222+
]
223+
max_length = 7
224+
(inputs_packed, inputs_segmentation, inputs_position, targets_packed,
225+
targets_segmentation, targets_position) = (
226+
pack_fn(inputs, targets, max_length, max_length))
227+
self.assertAllEqual(inputs_packed, [
228+
[1, 2, 3, 4, 5, 6, 0],
229+
])
230+
self.assertAllEqual(inputs_segmentation, [
231+
[1, 1, 1, 2, 2, 3, 0],
232+
])
233+
self.assertAllEqual(inputs_position, [
234+
[0, 1, 2, 0, 1, 0, 0],
235+
])
236+
self.assertAllEqual(targets_packed, [
237+
[10, 20, 30, 40, 50, 60, 0],
238+
])
239+
self.assertAllEqual(targets_segmentation, [
240+
[1, 2, 2, 2, 3, 3, 0],
241+
])
242+
self.assertAllEqual(targets_position, [
243+
[0, 0, 1, 2, 0, 1, 0],
244+
])
245+
246+
def do_test_pack_sequences_length_different_lengths(self, pack_fn):
247+
inputs = [
248+
[1, 2, 3],
249+
[4, 5, 0],
250+
[6, 0, 0],
251+
]
252+
targets = [
253+
[10, 0, 0],
254+
[20, 30, 40],
255+
[50, 60, 0],
256+
]
257+
input_max_length = 3
258+
target_max_length = 4
259+
(inputs_packed, inputs_segmentation, inputs_position, targets_packed,
260+
targets_segmentation, targets_position) = (
261+
pack_fn(inputs, targets, input_max_length, target_max_length))
262+
self.assertAllEqual(inputs_packed, [
263+
[1, 2, 3],
264+
[4, 5, 0],
265+
[6, 0, 0],
266+
])
267+
self.assertAllEqual(inputs_segmentation, [
268+
[1, 1, 1],
269+
[1, 1, 0],
270+
[1, 0, 0],
271+
])
272+
self.assertAllEqual(inputs_position, [
273+
[0, 1, 2],
274+
[0, 1, 0],
275+
[0, 0, 0],
276+
])
277+
self.assertAllEqual(targets_packed, [
278+
[10, 0, 0, 0],
279+
[20, 30, 40, 0],
280+
[50, 60, 0, 0],
281+
])
282+
self.assertAllEqual(targets_segmentation, [
283+
[1, 0, 0, 0],
284+
[1, 1, 1, 0],
285+
[1, 1, 0, 0],
286+
])
287+
self.assertAllEqual(targets_position, [
288+
[0, 0, 0, 0],
289+
[0, 1, 2, 0],
290+
[0, 1, 0, 0],
291+
])
292+
293+
def test_pack_sequences2(self):
294+
self.do_test_pack_sequences_length3(pack_sequences_ops.pack_sequences2)
295+
self.do_test_pack_sequences_length4(pack_sequences_ops.pack_sequences2)
296+
self.do_test_pack_sequences_length5(pack_sequences_ops.pack_sequences2)
297+
self.do_test_pack_sequences_length6(pack_sequences_ops.pack_sequences2)
298+
self.do_test_pack_sequences_length7(pack_sequences_ops.pack_sequences2)
299+
self.do_test_pack_sequences_length_different_lengths(
300+
pack_sequences_ops.pack_sequences2)
301+
302+
def test_pack_sequences_k(self):
303+
self.do_test_pack_sequences_length3(_pack_sequences_k)
304+
self.do_test_pack_sequences_length4(_pack_sequences_k)
305+
self.do_test_pack_sequences_length5(_pack_sequences_k)
306+
self.do_test_pack_sequences_length6(_pack_sequences_k)
307+
self.do_test_pack_sequences_length7(_pack_sequences_k)
308+
self.do_test_pack_sequences_length_different_lengths(_pack_sequences_k)
309+
310+
def test_random_inputs(self):
311+
for _ in range(10):
312+
batch_size = np.random.randint(900, 1100, size=[])
313+
input_seqlen = np.random.randint(1, 10, size=[])
314+
target_seqlen = np.random.randint(1, 10, size=[])
315+
inputs_list = []
316+
targets_list = []
317+
for _ in range(batch_size):
318+
input_num_pads = np.random.randint(0, input_seqlen, size=[])
319+
input_pads = np.full([input_num_pads], 0, dtype=np.int32)
320+
inputs = np.random.randint(1, 10, size=[input_seqlen - input_num_pads])
321+
inputs = np.concatenate([inputs, input_pads], axis=0)
322+
323+
target_num_pads = np.random.randint(0, target_seqlen, size=[])
324+
target_pads = np.full([target_num_pads], 0, dtype=np.int32)
325+
targets = np.random.randint(
326+
1, 10, size=[target_seqlen - target_num_pads])
327+
targets = np.concatenate([targets, target_pads], axis=0)
328+
329+
inputs_list.append(inputs)
330+
targets_list.append(targets)
331+
input_maxlen = np.random.randint(input_seqlen, input_seqlen + 10, size=[])
332+
target_maxlen = np.random.randint(
333+
target_seqlen, target_seqlen + 10, size=[])
334+
(inputs_packed2, inputs_segmentation2, inputs_positions2, targets_packed2,
335+
targets_segmentation2, targets_positions2) = (
336+
pack_sequences_ops.pack_sequences2(inputs_list, targets_list,
337+
input_maxlen, target_maxlen))
338+
(inputs_packed_k, inputs_segmentation_k, inputs_positions_k,
339+
targets_packed_k, targets_segmentation_k, targets_positions_k) = (
340+
_pack_sequences_k(inputs_list, targets_list, input_maxlen,
341+
target_maxlen))
342+
343+
self.assertAllEqual(inputs_packed2, inputs_packed_k)
344+
self.assertAllEqual(inputs_segmentation2, inputs_segmentation_k)
345+
self.assertAllEqual(inputs_positions2, inputs_positions_k)
346+
self.assertAllEqual(targets_packed2, targets_packed_k)
347+
self.assertAllEqual(targets_segmentation2, targets_segmentation_k)
348+
self.assertAllEqual(targets_positions2, targets_positions_k)
122349

123350
def test_pack_sequences_k_multi_input(self):
124351
input_tokens = tf.convert_to_tensor([

0 commit comments

Comments (0)