Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 874389b

Browse files
authored
Add enwik8 with different lengths + binary read (#1895)
1 parent ae042f6 commit 874389b

File tree

1 file changed

+74
-1
lines changed

1 file changed

+74
-1
lines changed

tensor2tensor/data_generators/enwik8.py

Lines changed: 74 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def _maybe_download_corpus(tmp_dir):
5555
class Enwik8L65k(text_problems.Text2SelfProblem):
5656
"""Enwiki8, with examples up to 65,536 characters long."""
5757

58+
READ_MODE = "r"
5859
DUPE_FACTOR = 4
5960

6061
@property
@@ -92,7 +93,7 @@ def sequence_length(self):
9293

9394
def generate_samples(self, data_dir, tmp_dir, dataset_split):
9495
filepath = _maybe_download_corpus(tmp_dir)
95-
with tf.io.gfile.GFile(filepath) as f:
96+
with tf.io.gfile.GFile(filepath, mode=self.READ_MODE) as f:
9697
data = f.read()
9798

9899
tf.logging.info("Length of enwik8 = %d", len(data))
@@ -126,3 +127,75 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
126127
for sample in generator:
127128
sample["targets"] = vocab.encode(sample["targets"])
128129
yield sample
130+
131+
132+
@registry.register_problem
class Enwik8L2k(Enwik8L65k):
  """enwik8 read byte-wise, with examples up to 2048 characters long.

  The input is read in binary mode and chunked into fragments of maximum
  length 2048. Unlike the base class, byte indices are not shifted (cls and
  pad are not assumed to be used).
  """

  # Passed as the `mode` argument when `generate_samples` opens the corpus.
  READ_MODE = "rb"

  @property
  def sequence_length(self):
    """Length of each example (number of byte-wise characters)."""
    return 2 * 1024

  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
    """Return the output of `generate_samples` as-is, with no vocab encoding.

    Args:
      data_dir: forwarded unchanged to `generate_samples`.
      tmp_dir: forwarded unchanged to `generate_samples`.
      dataset_split: forwarded unchanged to `generate_samples`.

    Returns:
      Whatever `generate_samples` returns for the same arguments.
    """
    raw_samples = self.generate_samples(data_dir, tmp_dir, dataset_split)
    return raw_samples
148+
149+
150+
@registry.register_problem
class Enwik8L32k(Enwik8L2k):
  """Variant of `Enwik8L2k` whose `sequence_length` is 32768."""

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    return 32 * 1024
157+
158+
159+
@registry.register_problem
class Enwik8L16k(Enwik8L2k):
  """Variant of `Enwik8L2k` whose `sequence_length` is 16384."""

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    return 16 * 1024
166+
167+
168+
@registry.register_problem
class Enwik8L8k(Enwik8L2k):
  """Variant of `Enwik8L2k` whose `sequence_length` is 8192."""

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    return 8 * 1024
175+
176+
177+
@registry.register_problem
class Enwik8L4k(Enwik8L2k):
  """Variant of `Enwik8L2k` whose `sequence_length` is 4096."""

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    return 4 * 1024
184+
185+
186+
@registry.register_problem
class Enwik8L1k(Enwik8L2k):
  """Variant of `Enwik8L2k` whose `sequence_length` is 1024."""

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    return 1 * 1024
193+
194+
195+
@registry.register_problem
class Enwik8L512(Enwik8L2k):
  """Variant of `Enwik8L2k` whose `sequence_length` is 512."""

  @property
  def sequence_length(self):
    """Length of each example (in tokens)."""
    example_length = 512
    return example_length

0 commit comments

Comments
 (0)