From 54f50ef490e4d4aada72d2bc2718afafc3bc73a8 Mon Sep 17 00:00:00 2001
From: Willie McClinton
Date: Mon, 27 Feb 2023 17:59:10 -0500
Subject: [PATCH] lots of changes, almost there

---
 .../online_nsrt_learning_approach.py          | 27 ++++++++----
 predicators/datasets/demo_only.py             |  4 +-
 predicators/explorers/base_explorer.py        | 10 +++++
 .../explorers/random_actions_explorer.py      |  2 +-
 .../explorers/random_options_explorer.py      | 10 +++++
 predicators/main.py                           | 44 ++++++++++++-------
 predicators/planning.py                       | 19 +++++---
 predicators/structs.py                        |  1 +
 predicators/utils.py                          | 22 ++++++++++
 tests/test_planning.py                        | 14 +++---
 10 files changed, 113 insertions(+), 40 deletions(-)

diff --git a/predicators/approaches/online_nsrt_learning_approach.py b/predicators/approaches/online_nsrt_learning_approach.py
index b964b701d6..6e19fbfda3 100644
--- a/predicators/approaches/online_nsrt_learning_approach.py
+++ b/predicators/approaches/online_nsrt_learning_approach.py
@@ -69,14 +69,25 @@ def get_interaction_requests(self) -> List[InteractionRequest]:
         for _ in range(CFG.online_nsrt_learning_requests_per_cycle):
             # Select a random task (with replacement).
             task_idx = self._rng.choice(len(self._train_tasks))
-            # Set up the explorer policy and termination function.
-            policy, termination_function = explorer.get_exploration_strategy(
-                task_idx, CFG.timeout)
-            # Create the interaction request.
-            req = InteractionRequest(train_task_idx=task_idx,
-                                     act_policy=policy,
-                                     query_policy=lambda s: None,
-                                     termination_function=termination_function)
+            if CFG.env == "behavior":
+                # Set up the explorer plan for option-model-based interaction.
+                act_plan, _ = explorer.get_exploration_plan_strategy(
+                    task_idx, CFG.timeout)
+                # Create the interaction request.
+                req = InteractionRequest(train_task_idx=task_idx,
+                                         act_policy=None,
+                                         query_policy=lambda s: None,
+                                         termination_function=None,
+                                         act_plan=act_plan)
+            else:
+                # Set up the explorer policy and termination function.
+                policy, termination_function = explorer.get_exploration_strategy(
+                    task_idx, CFG.timeout)
+                # Create the interaction request.
+                req = InteractionRequest(train_task_idx=task_idx,
+                                         act_policy=policy,
+                                         query_policy=lambda s: None,
+                                         termination_function=termination_function)
             requests.append(req)
         return requests
 
diff --git a/predicators/datasets/demo_only.py b/predicators/datasets/demo_only.py
index 57d1229fc8..2e5ed626de 100644
--- a/predicators/datasets/demo_only.py
+++ b/predicators/datasets/demo_only.py
@@ -245,8 +245,8 @@ def _generate_demonstrations(
             # states to create our low-level trajectories.
             last_traj = oracle_approach.get_last_traj()
             traj, success = _run_plan_with_option_model(
-                task, idx, oracle_approach.get_option_model(),
-                last_plan, last_traj)
+                idx, oracle_approach.get_option_model(),
+                last_plan, task=task, last_traj=last_traj)
             # Is successful if we found a low-level plan that achieves
             # our goal using option models.
            if not success:
diff --git a/predicators/explorers/base_explorer.py b/predicators/explorers/base_explorer.py
index 73d1499c24..80fb816c59 100644
--- a/predicators/explorers/base_explorer.py
+++ b/predicators/explorers/base_explorer.py
@@ -44,6 +44,16 @@ def get_exploration_strategy(
         tuple of a policy and a termination function."""
         raise NotImplementedError("Override me!")
 
+    @abc.abstractmethod
+    def get_exploration_plan_strategy(
+        self,
+        train_task_idx: int,
+        timeout: int,
+    ) -> ExplorationStrategy:
+        """Given a train task idx, create a tuple of a plan function and a
+        termination function."""
+        raise NotImplementedError("Override me!")
+
     def _set_seed(self, seed: int) -> None:
         """Reset seed and rng."""
         self._seed = seed
diff --git a/predicators/explorers/random_actions_explorer.py b/predicators/explorers/random_actions_explorer.py
index 0b63a3758b..037e9b93c5 100644
--- a/predicators/explorers/random_actions_explorer.py
+++ b/predicators/explorers/random_actions_explorer.py
@@ -17,4 +17,4 @@ def get_exploration_strategy(self, train_task_idx: int,
         policy = lambda _: Action(self._action_space.sample())
         # Never terminate (until the interaction budget is exceeded).
         termination_function = lambda _: False
-        return policy, termination_function
+        return policy, termination_function
\ No newline at end of file
diff --git a/predicators/explorers/random_options_explorer.py b/predicators/explorers/random_options_explorer.py
index 16f16c9711..777db61964 100644
--- a/predicators/explorers/random_options_explorer.py
+++ b/predicators/explorers/random_options_explorer.py
@@ -30,3 +30,13 @@ def fallback_policy(state: State) -> Action:
         # Never terminate (until the interaction budget is exceeded).
         termination_function = lambda _: False
         return policy, termination_function
+
+    def get_exploration_plan_strategy(self, train_task_idx: int,
+                                      timeout: int) -> ExplorationStrategy:
+
+        # Build a plan consisting of a single random option.
+        plan = utils.create_random_option_plan(self._options, self._rng)
+        # Never terminate (until the interaction budget is exceeded).
+        termination_function = lambda _: False
+        return plan, termination_function
+
diff --git a/predicators/main.py b/predicators/main.py
index a828b9f3bf..6a62146d70 100644
--- a/predicators/main.py
+++ b/predicators/main.py
@@ -175,7 +175,7 @@ def _run_pipeline(env: BaseEnv,
                         "terminating")
             break  # agent doesn't want to learn anything more; terminate
         interaction_results, query_cost = _generate_interaction_results(
-            env, teacher, interaction_requests, i)
+            env, teacher, interaction_requests, i, approach=approach)
         num_online_transitions += sum(
             len(result.actions) for result in interaction_results)
         total_query_cost += query_cost
@@ -209,7 +209,8 @@ def _generate_interaction_results(
         env: BaseEnv,
         teacher: Teacher,
         requests: Sequence[InteractionRequest],
-        cycle_num: Optional[int] = None
+        cycle_num: Optional[int] = None,
+        approach: Optional[BaseApproach] = None
 ) -> Tuple[List[InteractionResult], float]:
     """Given a sequence of InteractionRequest objects, handle the requests
     and return a list of InteractionResult objects."""
@@ -225,18 +226,29 @@
                 "if allow_interaction_in_demo_tasks is False.")
         monitor = TeacherInteractionMonitorWithVideo(env.render, request,
                                                      teacher)
-        traj, _ = utils.run_policy(
-            request.act_policy,
-            env,
-            "train",
-            request.train_task_idx,
-            request.termination_function,
-            max_num_steps=CFG.max_num_steps_interaction_request,
-            exceptions_to_break_on={
-                utils.EnvironmentFailure, utils.OptionExecutionFailure,
-                utils.RequestActPolicyFailure
-            },
-            monitor=monitor)
+        if CFG.behavior_option_model_eval:
+            # TODO needs to run random option_model instead of policy.
+            assert approach is not None
+            # TODO needs to be fixed
+            init_state = env.current_ig_state_to_state(
+                use_test_scene=env.task_instance_id >= 10)
+            plan = request.act_plan(init_state)
+            traj, _ = _run_plan_with_option_model(
+                request.train_task_idx, approach.get_option_model(),
+                plan, init_state=init_state)
+        else:
+            traj, _ = utils.run_policy(
+                request.act_policy,
+                env,
+                "train",
+                request.train_task_idx,
+                request.termination_function,
+                max_num_steps=CFG.max_num_steps_interaction_request,
+                exceptions_to_break_on={
+                    utils.EnvironmentFailure, utils.OptionExecutionFailure,
+                    utils.RequestActPolicyFailure
+                },
+                monitor=monitor)
         request_responses = monitor.get_responses()
         query_cost += monitor.get_query_cost()
         result = InteractionResult(traj.states, traj.actions,
@@ -337,8 +349,8 @@ def _run_testing(env: BaseEnv, approach: BaseApproach) -> Metrics:
             last_traj = approach.get_last_traj()
             option_model_start_time = time.time()
             traj, solved = _run_plan_with_option_model(
-                task, test_task_idx, approach.get_option_model(),
-                last_plan, last_traj)
+                test_task_idx, approach.get_option_model(),
+                last_plan, task=task, last_traj=last_traj)
             execution_metrics = {
                 "policy_call_time": option_model_start_time - time.time()
             }
diff --git a/predicators/planning.py b/predicators/planning.py
index 0fb6dfd47c..c6da3fcd30 100644
--- a/predicators/planning.py
+++ b/predicators/planning.py
@@ -636,9 +636,11 @@ def run_low_level_search(
 
 
 def _run_plan_with_option_model(
-        task: Task, task_idx: int, option_model: _OptionModelBase,
-        plan: List[_Option],
-        last_traj: List[State]) -> Tuple[LowLevelTrajectory, bool]:
+        task_idx: int, option_model: _OptionModelBase,
+        plan: List[_Option],
+        task: Optional[Task] = None,
+        last_traj: Optional[List[State]] = None,
+        init_state: Optional[State] = None) -> Tuple[LowLevelTrajectory, bool]:
     """Runs a plan on an option model to generate a low-level trajectory.
 
     Returns a LowLevelTrajectory and a boolean. If the boolean is True,
@@ -647,19 +649,24 @@
     and False. Since option models return only states, we will add dummy
     actions to the states to create our low level trajectories.
     """
-    traj: List[State] = [task.init] + [DefaultState for _ in plan]
+    if init_state is None:
+        init_state = task.init
+    if task is not None:
+        assert init_state == task.init
+    traj: List[State] = [init_state] + [DefaultState for _ in plan]
     actions: List[Action] = [Action(np.array([0.0])) for _ in plan]
     for idx in range(len(plan)):
         state = traj[idx]
         option = plan[idx]
         if not option.initiable(state):
             # The option is not initiable.
-            return LowLevelTrajectory(_states=[task.init],
+            return LowLevelTrajectory(_states=[init_state],
                                       _actions=[],
                                       _is_demo=False,
                                       _train_task_idx=task_idx), False
         if CFG.plan_only_eval:  # pragma: no cover
             assert isinstance(option_model, _BehaviorOptionModel)
+            assert last_traj is not None
             # We need to load state into option model so predicate classifiers
             # work when we run task.goal_holds(traj[-1]), otherwise
             # classifiers will be ran on non-updated prior state.
@@ -681,12 +688,12 @@
             actions[idx].set_option(action_option)
     # Since we're not checking the expected_atoms, we need to
     # explicitly check if the goal is achieved.
-    if task.goal_holds(traj[-1]):
+    if task is None or task.goal_holds(traj[-1]):
         return LowLevelTrajectory(_states=traj,
                                   _actions=actions,
                                   _is_demo=True,
                                   _train_task_idx=task_idx), True  # success!
-    return LowLevelTrajectory(_states=[task.init],
+    return LowLevelTrajectory(_states=[init_state],
                               _actions=[],
                               _is_demo=False,
                               _train_task_idx=task_idx), False
diff --git a/predicators/structs.py b/predicators/structs.py
index 88b1fa9422..dfc340d8a3 100644
--- a/predicators/structs.py
+++ b/predicators/structs.py
@@ -1281,6 +1281,7 @@ class InteractionRequest:
     act_policy: Callable[[State], Action]
     query_policy: Callable[[State], Optional[Query]]  # query can be None
     termination_function: Callable[[State], bool]
+    act_plan: Optional[Callable[[State], List[_Option]]] = None
 
 
 @dataclass(frozen=True, eq=False, repr=False)
diff --git a/predicators/utils.py b/predicators/utils.py
index e62855b853..7db4570649 100644
--- a/predicators/utils.py
+++ b/predicators/utils.py
@@ -1076,6 +1076,28 @@ def _policy(state: State) -> Action:
     return _policy
 
 
+def create_random_option_plan(
+        options: Collection[ParameterizedOption],
+        rng: np.random.Generator) -> Callable[[State], List[_Option]]:
+    """Create a plan function that returns a one-step plan.
+
+    The plan consists of a single randomly ground option.
+ """ + sorted_options = sorted(options, key=lambda o: o.name) + cur_option = DummyOption + + def _plan(state: State) -> Action: + nonlocal cur_option + param_opt = sorted_options[rng.choice(len(sorted_options))] + objs = get_random_object_combination(list(state), + param_opt.types, rng) + assert objs is not None + params = param_opt.params_space.sample() + opt = param_opt.ground(objs, params) + plan = [opt] + return plan + + return _plan + def action_arrs_to_policy( action_arrs: Sequence[Array]) -> Callable[[State], Action]: diff --git a/tests/test_planning.py b/tests/test_planning.py index 2697fd718b..9daa63e65d 100644 --- a/tests/test_planning.py +++ b/tests/test_planning.py @@ -58,24 +58,24 @@ def test_sesame_plan(sesame_check_expected_atoms, sesame_grounder, ) # Test our run_plan_with_option_model function # Case 1: plan is empty - traj, success = _run_plan_with_option_model(task, 0, option_model, [], - last_traj) + traj, success = _run_plan_with_option_model(0, option_model, [], + task=task, last_traj=last_traj) assert not success and len(traj.states) == 1 and len(traj.actions) == 0 # Case 2: plan does not achieve goal - traj, success = _run_plan_with_option_model(task, 0, option_model, + traj, success = _run_plan_with_option_model(0, option_model, [plan[0]], last_traj) assert not success and len(traj.states) == 1 and len(traj.actions) == 0 # Case 3: plan does achieve goal - traj, success = _run_plan_with_option_model(task, 0, option_model, - plan, last_traj) + traj, success = _run_plan_with_option_model(0, option_model, + plan, task=task, last_traj=last_traj) assert success and len(traj.states) > 1 and len( traj.states) == len(traj.actions) + 1 # Case 4: plan has option that is non initiable non_initiable_option = plan[0] non_initiable_option.initiable = lambda s: False - traj, success = _run_plan_with_option_model(task, 0, option_model, + traj, success = _run_plan_with_option_model(0, option_model, [non_initiable_option], - last_traj) + task=task, last_traj=last_traj) assert not success and len(traj.states) == 1 and len(traj.actions) == 0 if e is None: assert len(plan) == 3