This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ae042f6

T2T Team authored and copybara-github committed
Internal
PiperOrigin-RevId: 380716688
1 parent: 3f12173 · commit: ae042f6

File tree

2 files changed: 62 additions & 0 deletions


tensor2tensor/data_generators/text_problems.py

Lines changed: 27 additions & 0 deletions
@@ -799,6 +799,32 @@ def text2real_txt_iterator(source_txt_path, target_txt_path):
     yield {"inputs": inputs, "targets": targets}
 
 
+def txt_line_sharded_iterator(txt_pattern):
+  """Iterate through the lines of a sharded text file."""
+  all_files = tf.gfile.Glob(txt_pattern)
+  for txt_path in all_files:
+    with tf.gfile.Open(txt_path) as f:
+      for line in f:
+        yield line.strip()
+
+
+def text2text_txt_sharded_iterator(source_txt_pattern, target_txt_pattern):
+  """Yield dicts for Text2TextProblem.generate_samples from lines of files.
+
+  Args:
+    source_txt_pattern: glob pattern matching the sharded source files
+    target_txt_pattern: glob pattern matching the sharded target files
+
+  Yields:
+    {"inputs": inputs, "targets": targets}
+
+  """
+  for inputs, targets in zip(
+      txt_line_sharded_iterator(source_txt_pattern),
+      txt_line_sharded_iterator(target_txt_pattern)):
+    yield {"inputs": inputs, "targets": targets}
+
+
 def text2text_txt_tab_iterator(txt_path):
   """Yield dicts for Text2TextProblem.generate_samples from lines of txt_path.
 
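The two new helpers above read a parallel corpus stored as sharded text files. Below is a minimal usage sketch; the shard patterns and paths are hypothetical, not part of this commit. Note that the pairing relies on zip, so the source and target shards must enumerate the same lines in the same order, and any surplus lines on one side are silently dropped.

# A minimal sketch, assuming hypothetical one-sentence-per-line shard
# files such as /tmp/corpus.src-00000-of-00002 and
# /tmp/corpus.tgt-00000-of-00002 (the iterators use the TF1-style
# tf.gfile API internally, so a TF1 environment is assumed).
from tensor2tensor.data_generators import text_problems

for sample in text_problems.text2text_txt_sharded_iterator(
    "/tmp/corpus.src-*-of-*", "/tmp/corpus.tgt-*-of-*"):
  # Each sample is {"inputs": <source line>, "targets": <target line>}.
  print(sample["inputs"], "=>", sample["targets"])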
@@ -848,6 +874,7 @@ class Text2textTmpdir(Text2TextProblem):
   TRAIN_FILES = ("inputs.train.txt", "targets.train.txt")
   EVAL_FILES = ("inputs.eval.txt", "targets.eval.txt")
 
+  @property
   def is_generate_per_split(self):
     return True
 
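The one-line @property additions here and in translate.py below fix the same subtle issue: the Text2TextProblem base class defines is_generate_per_split as a property, so a subclass that overrides it with a plain method shadows the property, and reading problem.is_generate_per_split yields a bound method rather than a boolean. A bound method is always truthy, which happens to match the intended True here, but the decorator makes the attribute a real boolean again. A standalone sketch of the pitfall (illustrative class names, not t2t code):

class PlainMethod:
  def is_generate_per_split(self):
    return False

class WithProperty:
  @property
  def is_generate_per_split(self):
    return False

# The bound method is truthy even though it would return False if called.
print(bool(PlainMethod().is_generate_per_split))   # True
# The property evaluates to the actual boolean value.
print(bool(WithProperty().is_generate_per_split))  # False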

tensor2tensor/data_generators/translate.py

Lines changed: 35 additions & 0 deletions
@@ -266,6 +266,7 @@ def compile_data(tmp_dir, datasets, filename, datatypes_to_clean=None):
 class TranslateDistillProblem(TranslateProblem):
   """Base class for translation problems."""
 
+  @property
   def is_generate_per_split(self):
     return True
 

@@ -311,3 +312,37 @@ def generate_samples(self, data_dir, tmp_dir, dataset_split):
     return text_problems.text2text_distill_iterator(data_path + "inputs",
                                                     data_path + "gold",
                                                     data_path + "prediction")
+
+
+class TranslateWmt20Problem(TranslateProblem):
+  """Base class for WMT20 datasets."""
+
+  @property
+  def is_generate_per_split(self):
+    return True
+
+  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
+    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
+    vocab = self.get_or_create_vocab(data_dir, tmp_dir)
+    # For each example, encode the text and append EOS_ID.
+    for sample in generator:
+      if self.has_inputs:
+        sample["inputs"] = vocab.encode(sample["inputs"])
+        sample["inputs"].append(text_encoder.EOS_ID)
+      sample["targets"] = vocab.encode(sample["targets"])
+      sample["targets"].append(text_encoder.EOS_ID)
+      yield sample
+
+  def generate_text_for_vocab(self, data_dir, tmp_dir):
+    for i, sample in enumerate(
+        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
+      if self.has_inputs:
+        yield sample["inputs"]
+      yield sample["targets"]
+      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
+        break
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):
+    data_path = self.source_data_files(dataset_split)[0]
+    assert tf.gfile.Exists(data_path)
+    return text_problems.text2text_txt_tab_iterator(data_path)
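TranslateWmt20Problem leaves source_data_files to subclasses: generate_samples reads the first file that hook returns and feeds it to text2text_txt_tab_iterator, which expects one tab-separated "source<TAB>target" pair per line. A hypothetical subclass sketch follows; the class name and path are illustrative, not part of this commit.

from tensor2tensor.data_generators import translate
from tensor2tensor.utils import registry


@registry.register_problem
class TranslateWmt20EndeExample(translate.TranslateWmt20Problem):
  """Hypothetical WMT20 En-De problem reading a local tab-separated file."""

  def source_data_files(self, dataset_split):
    # One "source<TAB>target" pair per line; a real subclass would pick
    # different files for the TRAIN and EVAL splits here.
    del dataset_split
    return ["/tmp/wmt20.en-de.train.tsv"]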
