27 changes: 19 additions & 8 deletions predicators/approaches/online_nsrt_learning_approach.py
@@ -69,14 +69,25 @@ def get_interaction_requests(self) -> List[InteractionRequest]:
for _ in range(CFG.online_nsrt_learning_requests_per_cycle):
# Select a random task (with replacement).
task_idx = self._rng.choice(len(self._train_tasks))
# Set up the explorer policy and termination function.
policy, termination_function = explorer.get_exploration_strategy(
task_idx, CFG.timeout)
# Create the interaction request.
req = InteractionRequest(train_task_idx=task_idx,
act_policy=policy,
query_policy=lambda s: None,
termination_function=termination_function)
if CFG.env == "behavior":
# Set up the explorer plan strategy.
act_plan = explorer.get_exploration_plan_strategy(
task_idx, CFG.timeout)
# Create the interaction request.
req = InteractionRequest(train_task_idx=task_idx,
act_policy=None,
query_policy=lambda s: None,
termination_function=None,
act_plan=act_plan)
else:
# Set up the explorer policy and termination function.
policy, termination_function = explorer.get_exploration_strategy(
task_idx, CFG.timeout)
# Create the interaction request.
req = InteractionRequest(train_task_idx=task_idx,
act_policy=policy,
query_policy=lambda s: None,
termination_function=termination_function)
requests.append(req)
return requests

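Note: a minimal sketch (not part of the diff) of how a plan-based request could be built if the returned strategy tuple were unpacked first; this assumes the tuple's first element is the plan-generating function.

plan_fn, termination_fn = explorer.get_exploration_plan_strategy(
    task_idx, CFG.timeout)
req = InteractionRequest(train_task_idx=task_idx,
                         act_policy=None,
                         query_policy=lambda s: None,
                         termination_function=termination_fn,
                         act_plan=plan_fn)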
4 changes: 2 additions & 2 deletions predicators/datasets/demo_only.py
@@ -245,8 +245,8 @@ def _generate_demonstrations(
# states to create our low-level trajectories.
last_traj = oracle_approach.get_last_traj()
traj, success = _run_plan_with_option_model(
task, idx, oracle_approach.get_option_model(),
last_plan, last_traj)
idx, oracle_approach.get_option_model(),
last_plan, task=task, last_traj=last_traj)
# Is successful if we found a low-level plan that achieves
# our goal using option models.
if not success:
10 changes: 10 additions & 0 deletions predicators/explorers/base_explorer.py
@@ -44,6 +44,16 @@ def get_exploration_strategy(
tuple of a policy and a termination function."""
raise NotImplementedError("Override me!")

@abc.abstractmethod
def get_exploration_plan_strategy(
self,
train_task_idx: int,
timeout: int,
) -> ExplorationStrategy:
"""Given a train task idx, create an ExplorationStrategy, which is a
tuple of a policy and a termination function."""
raise NotImplementedError("Override me!")

def _set_seed(self, seed: int) -> None:
"""Reset seed and rng."""
self._seed = seed
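For context, a hypothetical subclass could implement the new abstract method as sketched below (illustrative only; self._plan_for_task is an assumed helper, not part of the codebase).

def get_exploration_plan_strategy(self, train_task_idx: int,
                                  timeout: int) -> ExplorationStrategy:
    # Plan once for the task, then always hand back that fixed option plan.
    option_plan = self._plan_for_task(train_task_idx, timeout)  # assumed helper
    plan_fn = lambda _state: list(option_plan)
    # Rely on the interaction budget for termination.
    termination_fn = lambda _state: False
    return plan_fn, termination_fn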
2 changes: 1 addition & 1 deletion predicators/explorers/random_actions_explorer.py
@@ -17,4 +17,4 @@ def get_exploration_strategy(self, train_task_idx: int,
policy = lambda _: Action(self._action_space.sample())
# Never terminate (until the interaction budget is exceeded).
termination_function = lambda _: False
return policy, termination_function
return policy, termination_function
10 changes: 10 additions & 0 deletions predicators/explorers/random_options_explorer.py
@@ -30,3 +30,13 @@ def fallback_policy(state: State) -> Action:
# Never terminate (until the interaction budget is exceeded).
termination_function = lambda _: False
return policy, termination_function

def get_exploration_plan_strategy(self, train_task_idx: int,
timeout: int) -> ExplorationStrategy:
# Build a plan function that samples a single random ground option.
plan = utils.create_random_option_plan(self._options, self._rng)
# Never terminate (until the interaction budget is exceeded).
termination_function = lambda _: False
return plan, termination_function
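A usage sketch, assuming an already-constructed RandomOptionsExplorer instance named explorer and some State named state.

plan_fn, terminate = explorer.get_exploration_plan_strategy(
    train_task_idx=0, timeout=CFG.timeout)
option_plan = plan_fn(state)  # a list holding one randomly grounded option
assert len(option_plan) == 1
assert not terminate(state)  # never terminates on its own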

44 changes: 28 additions & 16 deletions predicators/main.py
@@ -175,7 +175,7 @@ def _run_pipeline(env: BaseEnv,
"terminating")
break # agent doesn't want to learn anything more; terminate
interaction_results, query_cost = _generate_interaction_results(
env, teacher, interaction_requests, i)
env, teacher, interaction_requests, i, approach=approach)
num_online_transitions += sum(
len(result.actions) for result in interaction_results)
total_query_cost += query_cost
@@ -209,7 +209,8 @@ def _generate_interaction_results(
env: BaseEnv,
teacher: Teacher,
requests: Sequence[InteractionRequest],
cycle_num: Optional[int] = None
cycle_num: Optional[int] = None,
approach: Optional[BaseApproach] = None
) -> Tuple[List[InteractionResult], float]:
"""Given a sequence of InteractionRequest objects, handle the requests and
return a list of InteractionResult objects."""
@@ -225,18 +225,29 @@ def _generate_interaction_results(
"if allow_interaction_in_demo_tasks is False.")
monitor = TeacherInteractionMonitorWithVideo(env.render, request,
teacher)
traj, _ = utils.run_policy(
request.act_policy,
env,
"train",
request.train_task_idx,
request.termination_function,
max_num_steps=CFG.max_num_steps_interaction_request,
exceptions_to_break_on={
utils.EnvironmentFailure, utils.OptionExecutionFailure,
utils.RequestActPolicyFailure
},
monitor=monitor)
if CFG.behavior_option_model_eval:
# TODO: run the random option model instead of the policy.
assert approach is not None
# TODO: needs to be fixed.
init_state = env.current_ig_state_to_state(
use_test_scene=env.task_instance_id >= 10)
plan = request.act_plan
traj, _ = _run_plan_with_option_model(
request.train_task_idx, approach.get_option_model(),
plan, init_state=init_state)
else:
traj, _ = utils.run_policy(
request.act_policy,
env,
"train",
request.train_task_idx,
request.termination_function,
max_num_steps=CFG.max_num_steps_interaction_request,
exceptions_to_break_on={
utils.EnvironmentFailure, utils.OptionExecutionFailure,
utils.RequestActPolicyFailure
},
monitor=monitor)
request_responses = monitor.get_responses()
query_cost += monitor.get_query_cost()
result = InteractionResult(traj.states, traj.actions,
@@ -337,8 +349,8 @@ def _run_testing(env: BaseEnv, approach: BaseApproach) -> Metrics:
last_traj = approach.get_last_traj()
option_model_start_time = time.time()
traj, solved = _run_plan_with_option_model(
task, test_task_idx, approach.get_option_model(),
last_plan, last_traj)
test_task_idx, approach.get_option_model(),
last_plan, task=task, last_traj=last_traj)
execution_metrics = {
"policy_call_time": option_model_start_time - time.time()
}
19 changes: 13 additions & 6 deletions predicators/planning.py
@@ -636,9 +636,11 @@ def run_low_level_search(


def _run_plan_with_option_model(
task: Task, task_idx: int, option_model: _OptionModelBase,
task_idx: int, option_model: _OptionModelBase,
plan: List[_Option],
last_traj: List[State]) -> Tuple[LowLevelTrajectory, bool]:
task: Optional[Task] = None,
last_traj: Optional[List[State]] = None,
init_state: Optional[State] = None) -> Tuple[LowLevelTrajectory, bool]:
"""Runs a plan on an option model to generate a low-level trajectory.

Returns a LowLevelTrajectory and a boolean. If the boolean is True,
@@ -647,19 +649,24 @@ def _run_plan_with_option_model(
and False. Since option models return only states, we will add dummy
actions to the states to create our low level trajectories.
"""
traj: List[State] = [task.init] + [DefaultState for _ in plan]
if init_state is None:
init_state = task.init
if task is not None:
assert init_state == task.init
traj: List[State] = [init_state] + [DefaultState for _ in plan]
actions: List[Action] = [Action(np.array([0.0])) for _ in plan]
for idx in range(len(plan)):
state = traj[idx]
option = plan[idx]
if not option.initiable(state):
# The option is not initiable.
return LowLevelTrajectory(_states=[task.init],
return LowLevelTrajectory(_states=[init_state],
_actions=[],
_is_demo=False,
_train_task_idx=task_idx), False
if CFG.plan_only_eval: # pragma: no cover
assert isinstance(option_model, _BehaviorOptionModel)
assert last_traj is not None
# We need to load the state into the option model so that predicate
# classifiers work when we run task.goal_holds(traj[-1]); otherwise
# the classifiers would be run on a stale prior state.
@@ -681,12 +688,12 @@ def _run_plan_with_option_model(
actions[idx].set_option(action_option)
# Since we're not checking the expected_atoms, we need to
# explicitly check if the goal is achieved.
if task.goal_holds(traj[-1]):
if task is None or task.goal_holds(traj[-1]):
return LowLevelTrajectory(_states=traj,
_actions=actions,
_is_demo=True,
_train_task_idx=task_idx), True # success!
return LowLevelTrajectory(_states=[task.init],
return LowLevelTrajectory(_states=[init_state],
_actions=[],
_is_demo=False,
_train_task_idx=task_idx), False
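A usage sketch of the two call patterns the reworked signature supports (my_plan, my_task, states, and some_init_state are placeholders).

# Training/testing path: goal checking against a full Task.
traj, success = _run_plan_with_option_model(
    task_idx, option_model, my_plan, task=my_task, last_traj=states)

# Behavior option-model evaluation path: only an initial state is available,
# so the goal check is skipped and success hinges on every option being
# initiable.
traj, success = _run_plan_with_option_model(
    task_idx, option_model, my_plan, init_state=some_init_state)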
1 change: 1 addition & 0 deletions predicators/structs.py
@@ -1281,6 +1281,7 @@ class InteractionRequest:
act_policy: Callable[[State], Action]
query_policy: Callable[[State], Optional[Query]] # query can be None
termination_function: Callable[[State], bool]
act_plan: Optional[Callable[[State], List[_Option]]] = None


@dataclass(frozen=True, eq=False, repr=False)
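A small sketch of how the new field is typed to be consumed (illustrative only): act_plan maps a state to a full option plan, whereas act_policy maps a state to a single low-level action.

if req.act_plan is not None:
    option_plan: List[_Option] = req.act_plan(current_state)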
22 changes: 22 additions & 0 deletions predicators/utils.py
@@ -1076,6 +1076,28 @@ def _policy(state: State) -> Action:

return _policy

def create_random_option_plan(
options: Collection[ParameterizedOption],
rng: np.random.Generator) -> Callable[[State], List[_Option]]:
"""Create a plan function that returns a one-option plan built from a
randomly sampled, randomly grounded option."""
sorted_options = sorted(options, key=lambda o: o.name)

def _plan(state: State) -> List[_Option]:
# Sample a parameterized option and ground it with randomly chosen
# objects and randomly sampled continuous parameters.
param_opt = sorted_options[rng.choice(len(sorted_options))]
objs = get_random_object_combination(list(state),
param_opt.types, rng)
assert objs is not None
params = param_opt.params_space.sample()
opt = param_opt.ground(objs, params)
return [opt]

return _plan


def action_arrs_to_policy(
action_arrs: Sequence[Array]) -> Callable[[State], Action]:
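A quick usage sketch, assuming options is a collection of ParameterizedOption and state contains objects compatible with at least one option's types.

rng = np.random.default_rng(0)
plan_fn = create_random_option_plan(options, rng)
option_plan = plan_fn(state)  # a list holding one randomly grounded option
assert len(option_plan) == 1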
14 changes: 7 additions & 7 deletions tests/test_planning.py
@@ -58,24 +58,24 @@ def test_sesame_plan(sesame_check_expected_atoms, sesame_grounder,
)
# Test our run_plan_with_option_model function
# Case 1: plan is empty
traj, success = _run_plan_with_option_model(task, 0, option_model, [],
last_traj)
traj, success = _run_plan_with_option_model(0, option_model, [],
task=task, last_traj=last_traj)
assert not success and len(traj.states) == 1 and len(traj.actions) == 0
# Case 2: plan does not achieve goal
traj, success = _run_plan_with_option_model(task, 0, option_model,
traj, success = _run_plan_with_option_model(0, option_model,
[plan[0]], task=task, last_traj=last_traj)
assert not success and len(traj.states) == 1 and len(traj.actions) == 0
# Case 3: plan does achieve goal
traj, success = _run_plan_with_option_model(task, 0, option_model,
plan, last_traj)
traj, success = _run_plan_with_option_model(0, option_model,
plan, task=task, last_traj=last_traj)
assert success and len(traj.states) > 1 and len(
traj.states) == len(traj.actions) + 1
# Case 4: plan has option that is non initiable
non_initiable_option = plan[0]
non_initiable_option.initiable = lambda s: False
traj, success = _run_plan_with_option_model(task, 0, option_model,
traj, success = _run_plan_with_option_model(0, option_model,
[non_initiable_option],
last_traj)
task=task, last_traj=last_traj)
assert not success and len(traj.states) == 1 and len(traj.actions) == 0
if e is None:
assert len(plan) == 3