@@ -266,6 +266,7 @@ def compile_data(tmp_dir, datasets, filename, datatypes_to_clean=None):
266266class TranslateDistillProblem (TranslateProblem ):
267267 """Base class for translation problems."""
268268
269+ @property
269270 def is_generate_per_split (self ):
270271 return True
271272
@@ -311,3 +312,37 @@ def generate_samples(self, data_dir, tmp_dir, dataset_split):
311312 return text_problems .text2text_distill_iterator (data_path + "inputs" ,
312313 data_path + "gold" ,
313314 data_path + "prediction" )
315+
316+
class TranslateWmt20Problem(TranslateProblem):
  """Base class for WMT20 datasets.

  Subclasses supply `source_data_files(dataset_split)`; its first entry is
  the path of the split's data file, read via
  `text_problems.text2text_txt_tab_iterator` (presumably one
  tab-separated input/target pair per line — confirm against subclasses).
  """

  @property
  def is_generate_per_split(self):
    """Whether examples are generated separately per split (always True)."""
    return True

  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
    """Yields samples with text fields encoded as token-id lists.

    Args:
      data_dir: directory holding (or receiving) the vocabulary.
      tmp_dir: scratch directory for downloads / intermediate files.
      dataset_split: a `problem.DatasetSplit` value selecting the split.

    Yields:
      Sample dicts whose "inputs" (when `self.has_inputs`) and "targets"
      are lists of token ids, each terminated with the EOS id.
    """
    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
    vocab = self.get_or_create_vocab(data_dir, tmp_dir)
    # For each example, encode the text and append EOS ID.
    for sample in generator:
      if self.has_inputs:
        sample["inputs"] = vocab.encode(sample["inputs"])
        sample["inputs"].append(text_encoder.EOS_ID)
      sample["targets"] = vocab.encode(sample["targets"])
      sample["targets"].append(text_encoder.EOS_ID)
      yield sample

  def generate_text_for_vocab(self, data_dir, tmp_dir):
    """Yields raw training-split text for vocabulary construction.

    Stops early once `self.max_samples_for_vocab` samples have been seen,
    when that limit is set (truthy).
    """
    for i, sample in enumerate(
        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
      if self.has_inputs:
        yield sample["inputs"]
      yield sample["targets"]
      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
        break

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Returns an iterator of {"inputs", "targets"} text samples.

    Args:
      data_dir: unused here; kept for interface compatibility.
      tmp_dir: unused here; kept for interface compatibility.
      dataset_split: a `problem.DatasetSplit` value selecting the split.

    Returns:
      The iterator produced by `text_problems.text2text_txt_tab_iterator`
      over the split's first source data file.

    Raises:
      FileNotFoundError: if the split's data file does not exist.
    """
    data_path = self.source_data_files(dataset_split)[0]
    # Validate explicitly instead of `assert`: asserts are stripped under
    # `python -O`, which would silently defer the failure to read time.
    if not tf.gfile.Exists(data_path):
      raise FileNotFoundError("WMT20 data file not found: %s" % data_path)
    return text_problems.text2text_txt_tab_iterator(data_path)
0 commit comments