Skip to content

Commit f93af82

Browse files
authored
Merge pull request #730 from luotao1/demo
fix protobuf size limit of seq2seq demo
2 parents 44d4be6 + 11b7625 commit f93af82

File tree

3 files changed

+74
-19
lines changed

3 files changed

+74
-19
lines changed

demo/seqToseq/dataprovider.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,44 @@
1919
END = "<e>"
2020

2121

22-
def hook(settings, src_dict, trg_dict, file_list, **kwargs):
22+
def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list,
23+
**kwargs):
2324
# job_mode = 1: training mode
2425
# job_mode = 0: generating mode
25-
settings.job_mode = trg_dict is not None
26-
settings.src_dict = src_dict
26+
settings.job_mode = not is_generating
27+
settings.src_dict = dict()
28+
with open(src_dict_path, "r") as fin:
29+
settings.src_dict = {
30+
line.strip(): line_count
31+
for line_count, line in enumerate(fin)
32+
}
33+
settings.trg_dict = dict()
34+
with open(trg_dict_path, "r") as fin:
35+
settings.trg_dict = {
36+
line.strip(): line_count
37+
for line_count, line in enumerate(fin)
38+
}
39+
2740
settings.logger.info("src dict len : %d" % (len(settings.src_dict)))
2841
settings.sample_count = 0
2942

3043
if settings.job_mode:
31-
settings.trg_dict = trg_dict
32-
settings.slots = [
44+
settings.slots = {
45+
'source_language_word':
3346
integer_value_sequence(len(settings.src_dict)),
47+
'target_language_word':
3448
integer_value_sequence(len(settings.trg_dict)),
49+
'target_language_next_word':
3550
integer_value_sequence(len(settings.trg_dict))
36-
]
51+
}
3752
settings.logger.info("trg dict len : %d" % (len(settings.trg_dict)))
3853
else:
39-
settings.slots = [
54+
settings.slots = {
55+
'source_language_word':
4056
integer_value_sequence(len(settings.src_dict)),
57+
'sent_id':
4158
integer_value_sequence(len(open(file_list[0], "r").readlines()))
42-
]
59+
}
4360

4461

4562
def _get_ids(s, dictionary):
@@ -69,6 +86,10 @@ def process(settings, file_name):
6986
continue
7087
trg_ids_next = trg_ids + [settings.trg_dict[END]]
7188
trg_ids = [settings.trg_dict[START]] + trg_ids
72-
yield src_ids, trg_ids, trg_ids_next
89+
yield {
90+
'source_language_word': src_ids,
91+
'target_language_word': trg_ids,
92+
'target_language_next_word': trg_ids_next
93+
}
7394
else:
74-
yield src_ids, [line_count]
95+
yield {'source_language_word': src_ids, 'sent_id': [line_count]}

demo/seqToseq/seqToseq_net.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,10 @@ def seq_to_seq_data(data_dir,
3737
"""
3838
src_lang_dict = os.path.join(data_dir, 'src.dict')
3939
trg_lang_dict = os.path.join(data_dir, 'trg.dict')
40-
src_dict = dict()
41-
for line_count, line in enumerate(open(src_lang_dict, "r")):
42-
src_dict[line.strip()] = line_count
43-
trg_dict = dict()
44-
for line_count, line in enumerate(open(trg_lang_dict, "r")):
45-
trg_dict[line.strip()] = line_count
4640

4741
if is_generating:
4842
train_list = None
4943
test_list = os.path.join(data_dir, gen_list)
50-
trg_dict = None
5144
else:
5245
train_list = os.path.join(data_dir, train_list)
5346
test_list = os.path.join(data_dir, test_list)
@@ -57,8 +50,11 @@ def seq_to_seq_data(data_dir,
5750
test_list,
5851
module="dataprovider",
5952
obj="process",
60-
args={"src_dict": src_dict,
61-
"trg_dict": trg_dict})
53+
args={
54+
"src_dict_path": src_lang_dict,
55+
"trg_dict_path": trg_lang_dict,
56+
"is_generating": is_generating
57+
})
6258

6359
return {
6460
"src_dict_path": src_lang_dict,

doc_cn/faq/index.rst

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,41 @@ PaddlePaddle的参数使用名字 :code:`name` 作为参数的ID,相同名字
214214
cmake .. -DPYTHON_EXECUTABLE=<exc_path> -DPYTHON_LIBRARY=<lib_path> -DPYTHON_INCLUDE_DIR=<inc_path>
215215
216216
用户需要指定本机上Python的路径:``<exc_path>``, ``<lib_path>``, ``<inc_path>``
217+
218+
10. A protocol message was rejected because it was too big
219+
----------------------------------------------------------
220+
221+
如果在训练NLP相关模型时,出现以下错误:
222+
223+
.. code-block:: bash
224+
225+
[libprotobuf ERROR google/protobuf/io/coded_stream.cc:171] A protocol message was rejected because it was too big (more than 67108864 bytes). To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
226+
F1205 14:59:50.295174 14703 TrainerConfigHelper.cpp:59] Check failed: m->conf.ParseFromString(configProtoStr)
227+
228+
可能的原因是:传给dataprovider的某一个args过大,一般是由于直接传递大字典导致的。错误的define_py_data_sources2用法类似:
229+
230+
.. code-block:: python
231+
232+
src_dict = dict()
233+
for line_count, line in enumerate(open(src_dict_path, "r")):
234+
src_dict[line.strip()] = line_count
235+
236+
define_py_data_sources2(
237+
train_list,
238+
test_list,
239+
module="dataprovider",
240+
obj="process",
241+
args={"src_dict": src_dict})
242+
243+
解决方案是:将字典的路径作为args传给dataprovider,然后在dataprovider里面根据该路径加载字典。即define_py_data_sources2应改为:
244+
245+
.. code-block:: python
246+
247+
define_py_data_sources2(
248+
train_list,
249+
test_list,
250+
module="dataprovider",
251+
obj="process",
252+
args={"src_dict_path": src_dict_path})
253+
254+
完整源码可参考 `seqToseq <https://github.com/PaddlePaddle/Paddle/tree/develop/demo/seqToseq>`_ 示例。

0 commit comments

Comments
 (0)