Skip to content

Commit 5bd3f5b

Browse files
Fix pred invent vlm double check (#1740)
* fix vlm double checking issue * fix documentation for burger env * woops * try update mypy-extensions * fix mypy
1 parent 026844a commit 5bd3f5b

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

predicators/datasets/generate_atom_trajs_with_vlm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,10 @@ def _label_single_trajectory_with_vlm_atom_values(indexed_traj: Tuple[
183183
"/predicators/datasets/vlm_input_data_prompts/atom_labelling/" + \
184184
"double_check_prompt_prev_labels.txt"
185185
# pylint: enable=line-too-long
186-
double_check_prompt += previous_timestep_check_prompt
186+
with open(previous_timestep_check_prompt, "r",
187+
encoding="utf-8") as f:
188+
previous_timestep_check_prompt_str = f.read()
189+
double_check_prompt += previous_timestep_check_prompt_str
187190
double_check_prompt += "\n\nTruth values of predicates at " + \
188191
"the previous timestep:\n\n"
189192

predicators/envs/burger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ class BurgerEnv(BaseEnv):
4949
--approach grammar_search_invention --seed 0 --num_train_tasks 10
5050
--option_model_terminate_on_repeat False
5151
--sesame_max_skeletons_optimized 1000 --timeout 80
52-
--sesame_max_samples_per_step 1 --make_demo_videos
52+
--make_demo_videos
53+
--bilevel_plan_without_sim True
5354
--sesame_task_planner fdopt
5455
5556
Note that the default task planner is too slow -- fast downward is required.

predicators/pretrained_model_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def _sample_completions(
276276
temperature=temperature)
277277
response = self._model.generate_content(
278278
[prompt], generation_config=generation_config) # type: ignore
279-
response.resolve()
279+
response.resolve() # type: ignore
280280
return [response.text]
281281

282282
def get_id(self) -> str:
@@ -306,9 +306,9 @@ def _sample_completions(
306306
candidate_count=num_completions,
307307
temperature=temperature)
308308
response = self._model.generate_content(
309-
[prompt] + imgs,
309+
[prompt] + imgs, # type: ignore
310310
generation_config=generation_config) # type: ignore
311-
response.resolve()
311+
response.resolve() # type: ignore
312312
return [response.text]
313313

314314
def get_id(self) -> str:

setup.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@
4343
include_package_data=True,
4444
extras_require={
4545
"develop": [
46-
"pytest-cov==2.12.1",
47-
"pytest-pylint==0.18.0",
48-
"yapf==0.32.0",
49-
"docformatter==1.4",
50-
"isort==5.10.1",
46+
"pytest-cov==2.12.1", "pytest-pylint==0.18.0", "yapf==0.32.0",
47+
"docformatter==1.4", "isort==5.10.1", "mypy-extensions==1.0.0"
5148
]
5249
})

0 commit comments

Comments
 (0)