Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 21dba2c

Browse files
T2T Team authored and copybara-github committed
Add Seq2Edits (go/seq2edits-paper) to T2T.
PiperOrigin-RevId: 342622759
1 parent 5f9dd2d commit 21dba2c

File tree

4 files changed

+811
-0
lines changed

4 files changed

+811
-0
lines changed

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"tensor2tensor.data_generators.quora_qpairs",
7474
"tensor2tensor.data_generators.rte",
7575
"tensor2tensor.data_generators.scitail",
76+
"tensor2tensor.data_generators.seq2edits",
7677
"tensor2tensor.data_generators.snli",
7778
"tensor2tensor.data_generators.stanford_nli",
7879
"tensor2tensor.data_generators.style_transfer",
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# coding=utf-8
2+
# Copyright 2020 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Problems for Seq2Edits (see models/research/transformer_seq2edits.py)."""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
24+
from tensor2tensor.data_generators import text_encoder
25+
from tensor2tensor.data_generators import text_problems
26+
from tensor2tensor.layers import modalities
27+
from tensor2tensor.utils import registry
28+
29+
import tensorflow.compat.v1 as tf
30+
31+
32+
@modalities.is_pointwise
def pointer_top(body_output, targets, model_hparams, vocab_size):
  """Like identity_top() with is_pointwise annotation.

  Args:
    body_output: Value handed back to the caller unchanged.
    targets: Unused.
    model_hparams: Unused.
    vocab_size: Unused.

  Returns:
    `body_output`, unmodified.
  """
  del targets, model_hparams, vocab_size  # unused arg
  return body_output
37+
38+
39+
def pointer_bottom(x, model_hparams, vocab_size):
  """Like identity_bottom() without converting to float.

  Args:
    x: Value handed back to the caller unchanged (no cast is applied).
    model_hparams: Unused.
    vocab_size: Unused.

  Returns:
    `x`, unmodified.
  """
  del model_hparams, vocab_size  # unused arg
  return x
43+
44+
45+
@registry.register_problem
class Seq2editsGec(text_problems.Text2TextProblem):
  """Seq2Edits for grammatical error correction."""

  def dataset_filename(self):
    """Returns the base name of the dataset files for this problem."""
    return "edit_ops_gec"

  @property
  def vocab_file(self):
    """File name (relative to data_dir) of the subword vocabulary."""
    return "vocab.subwords"

  @property
  def vocab_filename(self):
    """Same file name as `vocab_file`."""
    return "vocab.subwords"

  @property
  def error_tag_vocab_file(self):
    """File name (relative to data_dir) of the error tag vocabulary."""
    return "vocab.error_tags"

  def feature_encoders(self, data_dir):
    """Builds the text encoders for this problem.

    Args:
      data_dir: Directory that contains the vocabulary files.

    Returns:
      Dict mapping feature name to encoder. "inputs" and "targets" share one
      subword encoder; "targets_error_tag" uses a token encoder.
    """
    subword_encoder = text_encoder.SubwordTextEncoder(
        os.path.join(data_dir, self.vocab_file))
    error_tag_encoder = text_encoder.TokenTextEncoder(
        os.path.join(data_dir, self.error_tag_vocab_file))
    return {
        "inputs": subword_encoder,
        "targets": subword_encoder,
        "targets_error_tag": error_tag_encoder
    }

  def hparams(self, defaults, model_hparams):
    """Registers pointer/tag modalities and Seq2Edits model hparams.

    Args:
      defaults: Problem hparams; `modality` and `vocab_size` are updated in
        place for the extra target features.
      model_hparams: Model hparams; `bottom`/`top` are updated in place and
        missing Seq2Edits hparams are added with their default values.
    """
    super(Seq2editsGec, self).hparams(defaults, model_hparams)

    for pointer_feat in ["targets_start_token", "targets_end_token"]:
      defaults.modality[pointer_feat] = modalities.ModalityType.IDENTITY
      defaults.vocab_size[pointer_feat] = None
      model_hparams.bottom[pointer_feat] = pointer_bottom
      model_hparams.top[pointer_feat] = pointer_top

    # (name, default) pairs; a value already present in model_hparams wins.
    optional_hparams = (
        # Whether to use tags.
        ("use_error_tags", True),
        # If true, span and tag prediction is in the middle of the decoder
        # layer stack. Otherwise, they are at the end of the decoder layer
        # stack.
        ("middle_prediction", True),
        # If middle_prediction=True, divide num_decoder_layers by this to get
        # the number of layers before and after the middle prediction.
        ("middle_prediction_layer_factor", 2),
        # Whether to predict the targets_start_token feature. If this is
        # false, use the previous end token as implicit start token.
        ("use_start_token", False),
        # Whether to feed back targets_end_token to the next time step. If
        # false, only feed back targets_start_token.
        ("feedback_end_token", False),
        # Number of feedforward layers between prediction layers in the
        # cascade.
        ("ffn_in_prediction_cascade", 1),
        # Embedding size for error tags.
        ("error_tag_embed_size", 6),
    )
    for name, default_value in optional_hparams:
      if name not in model_hparams:
        model_hparams.add_hparam(name, default_value)

    if model_hparams.use_error_tags:
      defaults.modality["targets_error_tag"] = modalities.ModalityType.SYMBOL
      error_tag_vocab_size = self._encoders["targets_error_tag"].vocab_size
      defaults.vocab_size["targets_error_tag"] = error_tag_vocab_size
      model_hparams.top["targets_error_tag"] = pointer_top

  def example_reading_spec(self):
    """Extends the parent reading spec with the edit-operation features.

    Returns:
      Tuple `(data_fields, None)` where `data_fields` is the parent's spec
      plus int64 variable-length features for the start token, end token and
      error tag targets.
    """
    data_fields, _ = super(Seq2editsGec, self).example_reading_spec()
    data_fields["targets_start_token"] = tf.VarLenFeature(tf.int64)
    data_fields["targets_end_token"] = tf.VarLenFeature(tf.int64)
    data_fields["targets_error_tag"] = tf.VarLenFeature(tf.int64)
    return data_fields, None
120+
121+
122+
@registry.register_problem
class Seq2editsGecPacked256(Seq2editsGec):
  """Packed version for TPU."""

  def dataset_filename(self):
    """Returns the base name of the packed dataset files."""
    return "edit_ops_gec_packed256"

  @property
  def packed_length(self):
    """Packed example length."""
    return 256

  @property
  def max_segment_length(self):
    """Maximum segment length; equal to `packed_length` here."""
    return 256
136+
137+
138+
@registry.register_problem
class Seq2editsGecNoTags(Seq2editsGec):
  """Seq2Edits for grammatical error correction without tags."""

  def dataset_filename(self):
    """Returns the base name of the dataset files for this problem."""
    return "edit_ops_gec"

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then turns error tags off.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `use_error_tags` is forced to False.
    """
    super(Seq2editsGecNoTags, self).hparams(defaults, model_hparams)
    # NOTE(review): the parent call runs before this assignment, so it acts on
    # whatever use_error_tags was at that point - confirm this is intended.
    model_hparams.use_error_tags = False
148+
149+
150+
@registry.register_problem
class Seq2editsGecNoTagsPacked256(Seq2editsGecPacked256):
  """Packed version for TPU."""

  def dataset_filename(self):
    """Returns the base name of the packed dataset files."""
    return "edit_ops_gec_packed256"

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then turns error tags off.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `use_error_tags` is forced to False.
    """
    super(Seq2editsGecNoTagsPacked256, self).hparams(defaults, model_hparams)
    # NOTE(review): the parent call runs before this assignment, so it acts on
    # whatever use_error_tags was at that point - confirm this is intended.
    model_hparams.use_error_tags = False
160+
161+
162+
@registry.register_problem
class Seq2editsGecDeep(Seq2editsGec):
  """Seq2Edits for grammatical error correction with deeper decoder."""

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then overrides the layer factor.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `middle_prediction_layer_factor` is set
        to 1.5.
    """
    super(Seq2editsGecDeep, self).hparams(defaults, model_hparams)
    model_hparams.middle_prediction_layer_factor = 1.5
169+
170+
171+
@registry.register_problem
class Seq2editsGecDeepPacked256(Seq2editsGecPacked256):
  """Packed version for TPU."""

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then overrides the layer factor.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `middle_prediction_layer_factor` is set
        to 1.5.
    """
    super(Seq2editsGecDeepPacked256, self).hparams(defaults, model_hparams)
    model_hparams.middle_prediction_layer_factor = 1.5
178+
179+
180+
@registry.register_problem
class Seq2editsGecDeepNoTags(Seq2editsGec):
  """Deep Seq2Edits model for grammatical error correction without tags."""

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then sets the deep / no-tags overrides.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `middle_prediction_layer_factor` is set
        to 1.5 and `use_error_tags` to False.
    """
    super(Seq2editsGecDeepNoTags, self).hparams(defaults, model_hparams)
    model_hparams.middle_prediction_layer_factor = 1.5
    model_hparams.use_error_tags = False
188+
189+
190+
@registry.register_problem
class Seq2editsGecDeepNoTagsPacked256(Seq2editsGecPacked256):
  """Packed version for TPU."""

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then sets the deep / no-tags overrides.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `middle_prediction_layer_factor` is set
        to 1.5 and `use_error_tags` to False.
    """
    super(Seq2editsGecDeepNoTagsPacked256, self).hparams(
        defaults, model_hparams)
    model_hparams.middle_prediction_layer_factor = 1.5
    model_hparams.use_error_tags = False
199+
200+
201+
@registry.register_problem
class Seq2editsTextnorm(Seq2editsGec):
  """Seq2Edits for text normalization."""

  def dataset_filename(self):
    """Returns the base name of the dataset files for this problem."""
    return "edit_ops_textnorm"

  @property
  def source_vocab_file(self):
    """File name (relative to data_dir) of the source vocabulary."""
    return "vocab.source"

  @property
  def target_vocab_file(self):
    """File name (relative to data_dir) of the target vocabulary."""
    return "vocab.target"

  @property
  def error_tag_vocab_file(self):
    """File name (relative to data_dir) of the error tag vocabulary."""
    return "vocab.error_tags"

  def feature_encoders(self, data_dir):
    """Builds the text encoders for this problem.

    Args:
      data_dir: Directory that contains the vocabulary files.

    Returns:
      Dict mapping feature name to encoder. "inputs", "targets" and
      "targets_error_tag" each get their own token encoder.
    """
    source_encoder = text_encoder.TokenTextEncoder(
        os.path.join(data_dir, self.source_vocab_file))
    target_encoder = text_encoder.TokenTextEncoder(
        os.path.join(data_dir, self.target_vocab_file))
    error_tag_encoder = text_encoder.TokenTextEncoder(
        os.path.join(data_dir, self.error_tag_vocab_file))
    return {
        "inputs": source_encoder,
        "targets": target_encoder,
        "targets_error_tag": error_tag_encoder
    }
232+
233+
234+
@registry.register_problem
class Seq2editsTextnormPacked256(Seq2editsTextnorm):
  """Packed version for TPU."""

  def dataset_filename(self):
    """Returns the base name of the packed dataset files."""
    return "edit_ops_textnorm_packed256"

  @property
  def packed_length(self):
    """Packed example length."""
    return 256

  @property
  def max_segment_length(self):
    """Maximum segment length; equal to `packed_length` here."""
    return 256
248+
249+
250+
@registry.register_problem
class Seq2editsTextnormNoTags(Seq2editsTextnorm):
  """Seq2Edits for text normalization without tags."""

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then turns error tags off.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `use_error_tags` is forced to False.
    """
    super(Seq2editsTextnormNoTags, self).hparams(defaults, model_hparams)
    model_hparams.use_error_tags = False
257+
258+
259+
@registry.register_problem
class Seq2editsTextnormNoTagsPacked256(Seq2editsTextnormPacked256):
  """Packed version for TPU."""

  def hparams(self, defaults, model_hparams):
    """Applies the parent hparams, then turns error tags off.

    Args:
      defaults: Problem hparams, forwarded to the parent implementation.
      model_hparams: Model hparams; `use_error_tags` is forced to False.
    """
    super(Seq2editsTextnormNoTagsPacked256, self).hparams(
        defaults, model_hparams)
    model_hparams.use_error_tags = False

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from tensor2tensor.models.research import transformer_nat
5757
from tensor2tensor.models.research import transformer_parallel
5858
from tensor2tensor.models.research import transformer_revnet
59+
from tensor2tensor.models.research import transformer_seq2edits
5960
from tensor2tensor.models.research import transformer_sketch
6061
from tensor2tensor.models.research import transformer_symshard
6162
from tensor2tensor.models.research import transformer_vae

0 commit comments

Comments
 (0)