Skip to content

Commit f95d778

Browse files
committed
yapf
1 parent 9f36174 commit f95d778

29 files changed

+826
-767
lines changed

predicators/approaches/__init__.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from predicators import utils
99
from predicators.approaches.base_approach import ApproachFailure, \
1010
ApproachTimeout, BaseApproach, BaseApproachWrapper
11-
from predicators.structs import ParameterizedOption, Predicate, Task, Type,\
12-
ConceptPredicate
11+
from predicators.structs import ConceptPredicate, ParameterizedOption, \
12+
Predicate, Task, Type
1313

1414
__all__ = ["BaseApproach", "ApproachTimeout", "ApproachFailure"]
1515

@@ -47,14 +47,21 @@ def create_approach(name: str, initial_predicates: Set[Predicate],
4747
train_tasks)
4848
# Find wrapper.
4949
wrapper_cls = _get_wrapper_cls_from_name(wrapper_name)
50-
return wrapper_cls(base_approach, initial_predicates,
51-
initial_options,
52-
types, action_space, train_tasks,
53-
)
50+
return wrapper_cls(
51+
base_approach,
52+
initial_predicates,
53+
initial_options,
54+
types,
55+
action_space,
56+
train_tasks,
57+
)
5458

5559
# Handle main approaches.
5660
cls = _get_approach_cls_from_name(name)
57-
return cls(initial_predicates,
58-
initial_options, types, action_space,
59-
train_tasks,
60-
)
61+
return cls(
62+
initial_predicates,
63+
initial_options,
64+
types,
65+
action_space,
66+
train_tasks,
67+
)

predicators/approaches/base_approach.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,24 @@
88
from gym.spaces import Box
99

1010
from predicators.settings import CFG
11-
from predicators.structs import Action, Dataset, InteractionRequest, \
12-
InteractionResult, Metrics, ParameterizedOption, Predicate, State, Task, \
13-
Type, ConceptPredicate
11+
from predicators.structs import Action, ConceptPredicate, Dataset, \
12+
InteractionRequest, InteractionResult, Metrics, ParameterizedOption, \
13+
Predicate, State, Task, Type
1414
from predicators.utils import ExceptionWithInfo
1515

1616

1717
class BaseApproach(abc.ABC):
1818
"""Base approach."""
1919

20-
def __init__(self, initial_predicates: Set[Predicate],
21-
initial_options: Set[ParameterizedOption], types: Set[Type],
22-
action_space: Box, train_tasks: List[Task],
23-
initial_concept_predicates: Set[ConceptPredicate],
24-
) -> None:
20+
def __init__(
21+
self,
22+
initial_predicates: Set[Predicate],
23+
initial_options: Set[ParameterizedOption],
24+
types: Set[Type],
25+
action_space: Box,
26+
train_tasks: List[Task],
27+
initial_concept_predicates: Set[ConceptPredicate],
28+
) -> None:
2529
"""All approaches are initialized with only the necessary information
2630
about the environment."""
2731
self._initial_predicates = initial_predicates

predicators/approaches/bilevel_planning_approach.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,33 @@
1717
from predicators.planning import PlanningFailure, PlanningTimeout, \
1818
run_task_plan_once, sesame_plan
1919
from predicators.settings import CFG
20-
from predicators.structs import NSRT, Action, GroundAtom, Metrics, \
21-
ParameterizedOption, Predicate, State, Task, Type, _GroundNSRT, _Option,\
22-
ConceptPredicate
20+
from predicators.structs import NSRT, Action, ConceptPredicate, GroundAtom, \
21+
Metrics, ParameterizedOption, Predicate, State, Task, Type, _GroundNSRT, \
22+
_Option
2323

2424

2525
class BilevelPlanningApproach(BaseApproach):
2626
"""Bilevel planning approach."""
2727

28-
def __init__(self,
29-
initial_predicates: Set[Predicate],
30-
initial_options: Set[ParameterizedOption],
31-
types: Set[Type],
32-
action_space: Box,
33-
train_tasks: List[Task],
34-
task_planning_heuristic: str = "default",
35-
max_skeletons_optimized: int = -1,
36-
bilevel_plan_without_sim: Optional[bool] = None,
37-
option_model: Optional[_OptionModelBase] = None,
38-
initial_concept_predicates: Set[ConceptPredicate] = set(),
39-
) -> None:
40-
super().__init__(initial_predicates,
41-
initial_options, types,
42-
action_space, train_tasks,
43-
initial_concept_predicates=initial_concept_predicates)
28+
def __init__(
29+
self,
30+
initial_predicates: Set[Predicate],
31+
initial_options: Set[ParameterizedOption],
32+
types: Set[Type],
33+
action_space: Box,
34+
train_tasks: List[Task],
35+
task_planning_heuristic: str = "default",
36+
max_skeletons_optimized: int = -1,
37+
bilevel_plan_without_sim: Optional[bool] = None,
38+
option_model: Optional[_OptionModelBase] = None,
39+
initial_concept_predicates: Set[ConceptPredicate] = set(),
40+
) -> None:
41+
super().__init__(initial_predicates,
42+
initial_options,
43+
types,
44+
action_space,
45+
train_tasks,
46+
initial_concept_predicates=initial_concept_predicates)
4447
if task_planning_heuristic == "default":
4548
task_planning_heuristic = CFG.sesame_task_planning_heuristic
4649
if max_skeletons_optimized == -1:
@@ -191,12 +194,13 @@ def _get_current_predicates(self) -> Set[Predicate]:
191194
Defaults to initial predicates.
192195
"""
193196
return self._initial_predicates | self._initial_concept_predicates
194-
197+
195198
def _get_current_concept_predicates(self) -> Set[ConceptPredicate]:
196-
"""Get the current set of concept predicates.
197-
"""
198-
cnpt_preds = set([pred for pred in self._get_current_predicates() if
199-
isinstance(pred, ConceptPredicate)])
199+
"""Get the current set of concept predicates."""
200+
cnpt_preds = set([
201+
pred for pred in self._get_current_predicates()
202+
if isinstance(pred, ConceptPredicate)
203+
])
200204
return cnpt_preds
201205

202206
def get_option_model(self) -> _OptionModelBase:

predicators/approaches/nsrt_learning_approach.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
import logging
88
import time
9-
from typing import Any, Dict, List, Optional, Set
109
from itertools import chain
10+
from typing import Any, Dict, List, Optional, Set
1111

1212
import dill as pkl
1313
from gym.spaces import Box
@@ -18,9 +18,9 @@
1818
from predicators.nsrt_learning.nsrt_learning_main import learn_nsrts_from_data
1919
from predicators.planning import task_plan, task_plan_grounding
2020
from predicators.settings import CFG
21-
from predicators.structs import NSRT, Dataset, GroundAtomTrajectory, \
22-
LowLevelTrajectory, ParameterizedOption, Predicate, Segment, Task, Type,\
23-
ConceptPredicate
21+
from predicators.structs import NSRT, ConceptPredicate, Dataset, \
22+
GroundAtomTrajectory, LowLevelTrajectory, ParameterizedOption, Predicate, \
23+
Segment, Task, Type
2424

2525

2626
class NSRTLearningApproach(BilevelPlanningApproach):
@@ -106,7 +106,7 @@ def _learn_nsrts(self, trajectories: List[LowLevelTrajectory],
106106
# TODO: make sure it expands until no more new predicates are added
107107
aux_preds = set(chain.from_iterable(p.auxiliary_concepts for p
108108
in (self._get_current_predicates() |\
109-
self._get_current_concept_predicates()) if
109+
self._get_current_concept_predicates()) if
110110
isinstance(p, ConceptPredicate) and\
111111
p.auxiliary_concepts))
112112
self._nsrts, self._segmented_trajs, self._seg_to_nsrt = \

predicators/approaches/oracle_approach.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,27 @@
1515
from predicators.ground_truth_models import get_gt_nsrts
1616
from predicators.option_model import _OptionModelBase
1717
from predicators.settings import CFG
18-
from predicators.structs import NSRT, ParameterizedOption, Predicate, Task, \
19-
Type, ConceptPredicate
18+
from predicators.structs import NSRT, ConceptPredicate, ParameterizedOption, \
19+
Predicate, Task, Type
2020

2121

2222
class OracleApproach(BilevelPlanningApproach):
2323
"""A bilevel planning approach that uses hand-specified NSRTs."""
2424

25-
def __init__(self,
26-
initial_predicates: Set[Predicate],
27-
initial_options: Set[ParameterizedOption],
28-
types: Set[Type],
29-
action_space: Box,
30-
train_tasks: List[Task],
31-
task_planning_heuristic: str = "default",
32-
max_skeletons_optimized: int = -1,
33-
bilevel_plan_without_sim: Optional[bool] = None,
34-
nsrts: Optional[Set[NSRT]] = None,
35-
option_model: Optional[_OptionModelBase] = None,
36-
initial_concept_predicates: Set[ConceptPredicate] = set(),
37-
) -> None:
25+
def __init__(
26+
self,
27+
initial_predicates: Set[Predicate],
28+
initial_options: Set[ParameterizedOption],
29+
types: Set[Type],
30+
action_space: Box,
31+
train_tasks: List[Task],
32+
task_planning_heuristic: str = "default",
33+
max_skeletons_optimized: int = -1,
34+
bilevel_plan_without_sim: Optional[bool] = None,
35+
nsrts: Optional[Set[NSRT]] = None,
36+
option_model: Optional[_OptionModelBase] = None,
37+
initial_concept_predicates: Set[ConceptPredicate] = set(),
38+
) -> None:
3839
super().__init__(initial_predicates,
3940
initial_options,
4041
types,

0 commit comments

Comments
 (0)