From b490174d923642f298dc47bc9cf7fe602e6d91e7 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 6 Jul 2025 16:00:32 +0200 Subject: [PATCH 01/19] Created empty structure for custom policies --- sb3_contrib/common/hybrid/__init__.py | 0 sb3_contrib/common/hybrid/policies.py | 27 +++++++++++ sb3_contrib/ppo_hybrid/__init__.py | 4 ++ sb3_contrib/ppo_hybrid/policies.py | 9 ++++ sb3_contrib/ppo_hybrid/ppo_hybrid.py | 69 +++++++++++++++++++++++++++ 5 files changed, 109 insertions(+) create mode 100644 sb3_contrib/common/hybrid/__init__.py create mode 100644 sb3_contrib/common/hybrid/policies.py create mode 100644 sb3_contrib/ppo_hybrid/__init__.py create mode 100644 sb3_contrib/ppo_hybrid/policies.py create mode 100644 sb3_contrib/ppo_hybrid/ppo_hybrid.py diff --git a/sb3_contrib/common/hybrid/__init__.py b/sb3_contrib/common/hybrid/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py new file mode 100644 index 00000000..91304c7c --- /dev/null +++ b/sb3_contrib/common/hybrid/policies.py @@ -0,0 +1,27 @@ +from typing import Any, Optional, Union +from stable_baselines3.common.policies import BasePolicy +from gymnasium import spaces +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule +import torch as th +from torch import nn +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) + + +class HybridActorCriticPolicy(BasePolicy): + pass + + +# TODO: check superclass +class HybridActorCriticCnnPolicy(HybridActorCriticPolicy): + pass + + +# TODO: check superclass +class HybridMultiInputActorCriticPolicy(HybridActorCriticPolicy): + pass \ No newline at end of file diff --git a/sb3_contrib/ppo_hybrid/__init__.py b/sb3_contrib/ppo_hybrid/__init__.py new file mode 100644 index 00000000..448c6768 --- /dev/null +++ b/sb3_contrib/ppo_hybrid/__init__.py @@ -0,0 +1,4 @@ +from sb3_contrib.ppo_hybrid.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_contrib.ppo_hybrid.ppo_hybrid import HybridPPO + +__all__ = ["CnnPolicy", "HybridPPO", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/ppo_hybrid/policies.py b/sb3_contrib/ppo_hybrid/policies.py new file mode 100644 index 00000000..15f60e58 --- /dev/null +++ b/sb3_contrib/ppo_hybrid/policies.py @@ -0,0 +1,9 @@ +from sb3_contrib.common.hybrid.policies import ( + HybridActorCriticPolicy, + HybridActorCriticCnnPolicy, + HybridMultiInputActorCriticPolicy, +) + +MlpPolicy = HybridActorCriticPolicy +CnnPolicy = HybridActorCriticCnnPolicy +MultiInputPolicy = HybridMultiInputActorCriticPolicy diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py new file mode 100644 index 00000000..940c12f8 --- /dev/null +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -0,0 +1,69 @@ +from typing import ClassVar +from stable_baselines3.ppo import PPO +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy + + +class HybrudPPO(PPO): + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + + def __init__( + self, + policy, + env, + learning_rate=0.0003, + n_steps=2048, + batch_size=64, + n_epochs=10, + gamma=0.99, + gae_lambda=0.95, + clip_range=0.2, + clip_range_vf=None, + normalize_advantage=True, + ent_coef=0, + vf_coef=0.5, + max_grad_norm=0.5, + use_sde=False, + sde_sample_freq=-1, + rollout_buffer_class=None, + rollout_buffer_kwargs=None, + target_kl=None, + stats_window_size=100, + tensorboard_log=None, + policy_kwargs=None, + verbose=0, + seed=None, + device="auto", + _init_setup_model=True, + ): + super().__init__( + policy, + env, + learning_rate, + n_steps, + batch_size, + n_epochs, + gamma, + gae_lambda, + clip_range, + clip_range_vf, + normalize_advantage, + ent_coef, + vf_coef, + max_grad_norm, + use_sde, + sde_sample_freq, + rollout_buffer_class, + rollout_buffer_kwargs, + target_kl, + stats_window_size, + tensorboard_log, + policy_kwargs, + verbose, + seed, + device, + _init_setup_model, + ) From 71291eba65f72feff8d9d3dce5a6b25034a606e5 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 6 Jul 2025 17:11:00 +0200 Subject: [PATCH 02/19] Started populating HybridActorCriticPolicy --- sb3_contrib/common/hybrid/policies.py | 56 ++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 91304c7c..20d61b54 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -14,7 +14,61 @@ class HybridActorCriticPolicy(BasePolicy): - pass + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Tuple Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Tuple, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + ): + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, + squash_output=False, + ) + + assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space." # TODO: check superclass From 33605703b8074b7b0104f009f244b9baeb9dedac Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 26 Jul 2025 18:47:55 +0200 Subject: [PATCH 03/19] Create HybridDistribution and HybridDistributionNet (tbc) --- sb3_contrib/common/hybrid/distributions.py | 130 +++++++++++++++++++++ sb3_contrib/common/hybrid/policies.py | 54 ++++++++- 2 files changed, 181 insertions(+), 3 deletions(-) create mode 100644 sb3_contrib/common/hybrid/distributions.py diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py new file mode 100644 index 00000000..2d850dd7 --- /dev/null +++ b/sb3_contrib/common/hybrid/distributions.py @@ -0,0 +1,130 @@ +import numpy as np +import torch as th +from torch import nn +from typing import Any, Optional, TypeVar, Union +from stable_baselines3.common.distributions import Distribution +from gymnasium import spaces + + +SelfHybridDistribution = TypeVar("SelfHybridDistribution", bound="HybridDistribution") + + +class HybridDistributionNet(nn.Module): + """ + Base class for hybrid distributions that handle both discrete and continuous actions. + This class should be extended to implement specific hybrid distributions. + """ + + def __init__(self, latent_dim: int, categorical_dimensions: np.ndarray, n_continuous: int): + super().__init__() + # For discrete action space + self.categorical_nets = nn.ModuleList([nn.Linear(latent_dim, out_dim) for out_dim in categorical_dimensions]) + # For continuous action space + self.gaussian_net = nn.Linear(latent_dim, n_continuous) + + def forward(self, latent: th.Tensor) -> tuple[list[th.Tensor], th.Tensor]: + """ + Forward pass through all categorical nets and the gaussian net. + + :param latent: Latent tensor input + :return: Tuple (list of categorical outputs, gaussian output) + """ + categorical_outputs = [net(latent) for net in self.categorical_nets] + gaussian_output = self.gaussian_net(latent) + return categorical_outputs, gaussian_output + + +class HybridDistribution(Distribution): + def __init__(self, categorical_dimensions: np.ndarray, n_continuous: int): + super().__init__() + self.categorical_dimensions = categorical_dimensions + self.n_continuous = n_continuous + self.categorical_dists = None + self.gaussian_dist = None + + def proba_distribution_net(self, latent_dim: int) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]: + """Create the layers and parameters that represent the distribution. + + Subclasses must define this, but the arguments and return type vary between + concrete classes.""" + action_net = HybridDistributionNet(latent_dim, self.categorical_dimensions) + return action_net + + def proba_distribution(self: SelfHybridDistribution, *args, **kwargs) -> SelfHybridDistribution: + """Set parameters of the distribution. + + :return: self + """ + + def log_prob(self, x: th.Tensor) -> th.Tensor: + """ + Returns the log likelihood + + :param x: the taken action + :return: The log likelihood of the distribution + """ + + def entropy(self) -> Optional[th.Tensor]: + """ + Returns Shannon's entropy of the probability + + :return: the entropy, or None if no analytical form is known + """ + + def sample(self) -> th.Tensor: + """ + Returns a sample from the probability distribution + + :return: the stochastic action + """ + + def mode(self) -> th.Tensor: + """ + Returns the most likely action (deterministic output) + from the probability distribution + + :return: the stochastic action + """ + + # TODO: this is not abstract in superclass, you can also not re-implement it --> check + def get_actions(self, deterministic: bool = False) -> th.Tensor: + """ + Return actions according to the probability distribution. + + :param deterministic: + :return: + """ + if deterministic: + return self.mode() + return self.sample() + + def actions_from_params(self, *args, **kwargs) -> th.Tensor: + """ + Returns samples from the probability distribution + given its parameters. + + :return: actions + """ + + def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]: + """ + Returns samples and the associated log probabilities + from the probability distribution given its parameters. + + :return: actions and log prob + """ + + +def make_hybrid_proba_distribution(action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box]) -> HybridDistribution: + """ + Create a hybrid probability distribution for the given action space. + + :param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space. + :return: A HybridDistribution object that handles the hybrid action space. + """ + assert len(action_space[1].shape) == 1, "Continuous action space must have a monodimensional shape (e.g., (n,))" + return HybridDistribution( + categorical_dimensions=len(action_space[0].nvec), + n_continuous=action_space[1].shape[0] + ) + diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 20d61b54..97de61d6 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -1,4 +1,5 @@ from typing import Any, Optional, Union +import warnings from stable_baselines3.common.policies import BasePolicy from gymnasium import spaces from stable_baselines3.common.type_aliases import PyTorchObs, Schedule @@ -12,6 +13,8 @@ NatureCNN, ) +from sb3_contrib.common.hybrid.distributions import make_hybrid_proba_distribution + class HybridActorCriticPolicy(BasePolicy): """ @@ -19,11 +22,12 @@ class HybridActorCriticPolicy(BasePolicy): Used by A2C, PPO and the likes. :param observation_space: Observation space - :param action_space: Tuple Action space + :param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space. :param lr_schedule: Learning rate schedule (could be constant) :param net_arch: The specification of the policy and value networks. :param activation_fn: Activation function :param ortho_init: Whether to use or not orthogonal initialization + :param log_std_init: Initial value for the log standard deviation :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. @@ -39,11 +43,12 @@ class HybridActorCriticPolicy(BasePolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Tuple, + action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box], lr_schedule: Schedule, net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, activation_fn: type[nn.Module] = nn.Tanh, ortho_init: bool = True, + log_std_init: float = 0.0, features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = True, @@ -68,7 +73,50 @@ def __init__( squash_output=False, ) - assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space." + # assert that the action space is compatible with its type hint + assert isinstance(action_space, spaces.Tuple), "Action space must be a gymnasium.spaces.Tuple" + assert len(action_space.spaces) == 2, "Action space Tuple must contain exactly two spaces" + assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First element of action space Tuple must be MultiDiscrete" + assert isinstance(action_space.spaces[1], spaces.Box), "Second element of action space Tuple must be Box" + + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = dict(pi=[64, 64], vf=[64, 64]) + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + # features extractor + self.share_features_extractor = share_features_extractor + self.features_extractor = self.make_features_extractor() + self.features_dim = self.features_extractor.features_dim + if self.share_features_extractor: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.features_extractor + else: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.make_features_extractor() + + self.log_std_init = log_std_init + + # Action distribution + self.action_dist = make_hybrid_proba_distribution(action_space) + + # TODO: self._build() # TODO: check superclass From caf1a449d8b95173cc306dce6859787b15058e58 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 13 Sep 2025 14:54:51 +0200 Subject: [PATCH 04/19] Add _build method to HybridActorCriticPolicy --- sb3_contrib/common/hybrid/distributions.py | 6 +++ sb3_contrib/common/hybrid/policies.py | 57 ++++++++++++++++++++-- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py index 2d850dd7..32c26773 100644 --- a/sb3_contrib/common/hybrid/distributions.py +++ b/sb3_contrib/common/hybrid/distributions.py @@ -36,6 +36,12 @@ def forward(self, latent: th.Tensor) -> tuple[list[th.Tensor], th.Tensor]: class HybridDistribution(Distribution): def __init__(self, categorical_dimensions: np.ndarray, n_continuous: int): + """ + Initialize the hybrid distribution with categorical and continuous components. + + :param categorical_dimensions: An array specifying the dimensions of the categorical actions. + :param n_continuous: The number of continuous actions. + """ super().__init__() self.categorical_dimensions = categorical_dimensions self.n_continuous = n_continuous diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 97de61d6..e6eef51c 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -1,5 +1,7 @@ +from functools import partial from typing import Any, Optional, Union import warnings +import numpy as np from stable_baselines3.common.policies import BasePolicy from gymnasium import spaces from stable_baselines3.common.type_aliases import PyTorchObs, Schedule @@ -13,7 +15,7 @@ NatureCNN, ) -from sb3_contrib.common.hybrid.distributions import make_hybrid_proba_distribution +from sb3_contrib.common.hybrid.distributions import HybridDistribution, make_hybrid_proba_distribution class HybridActorCriticPolicy(BasePolicy): @@ -114,9 +116,58 @@ def __init__( self.log_std_init = log_std_init # Action distribution - self.action_dist = make_hybrid_proba_distribution(action_space) + self.action_dist: HybridDistribution = make_hybrid_proba_distribution(action_space) + + self._build(lr_schedule) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + """ + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + self._build_mlp_extractor() + + # Create action net and valule net + self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi) + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + if not self.share_features_extractor: + # Note(antonin): this is to keep SB3 results + # consistent, see GH#1148 + del module_gains[self.features_extractor] + module_gains[self.pi_features_extractor] = np.sqrt(2) + module_gains[self.vf_features_extractor] = np.sqrt(2) + + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class( + self.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) - # TODO: self._build() # TODO: check superclass From 9ac17e08039e1648597e9f0e1d92c046f4f8e160 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 13 Sep 2025 17:27:39 +0200 Subject: [PATCH 05/19] Started forward method (tbc) Created Hybrid distr, populated HybridDirstribution and forward method (tbc) Co-authored-by: simrey --- sb3_contrib/common/hybrid/distributions.py | 54 ++++++++++++++++++++-- sb3_contrib/common/hybrid/policies.py | 32 +++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py index 32c26773..d4db39f0 100644 --- a/sb3_contrib/common/hybrid/distributions.py +++ b/sb3_contrib/common/hybrid/distributions.py @@ -3,6 +3,7 @@ from torch import nn from typing import Any, Optional, TypeVar, Union from stable_baselines3.common.distributions import Distribution +from torch.distributions import Categorical, Normal from gymnasium import spaces @@ -34,6 +35,43 @@ def forward(self, latent: th.Tensor) -> tuple[list[th.Tensor], th.Tensor]: return categorical_outputs, gaussian_output +class Hybrid(th.distributions.Distribution): + """ + A hybrid distribution that combines multiple categorical distributions for discrete actions + and a Gaussian distribution for continuous actions. + """ + + def __init__(self, + probs: Optional[tuple[list[th.Tensor], th.Tensor]] = None, + logits: Optional[tuple[list[th.Tensor], th.Tensor]] = None, + validate_args: Optional[bool] = None, + ): + super().__init__() + categorical_logits: list[th.Tensor] = logits[0] + gaussian_means: th.Tensor = logits[1] + self.categorical_dists = [Categorical(logits=logit) for logit in categorical_logits] + self.gaussian_dist = Normal(loc=gaussian_means, scale=th.ones_like(gaussian_means)) + + def sample(self) -> tuple[th.Tensor, th.Tensor]: + categorical_samples = [dist.sample() for dist in self.categorical_dists] + gaussian_samples = self.gaussian_dist.sample() + return th.stack(categorical_samples, dim=-1), gaussian_samples + + def log_prob(self): + """ + Returns the log probability of the given actions, both discrete and continuous. + """ + gaussian_log_prob = self.gaussian_dist.log_prob(self.gaussian_dist.sample()) + categorical_log_probs = [dist.log_prob(dist.sample()) for dist in self.categorical_dists] + + # TODO: check dimensions + return th.sum(th.stack(categorical_log_probs, dim=-1), dim=-1), th.sum(gaussian_log_prob, dim=-1) + + # TODO: implement + def entropy(self): + raise NotImplementedError() + + class HybridDistribution(Distribution): def __init__(self, categorical_dimensions: np.ndarray, n_continuous: int): """ @@ -56,19 +94,24 @@ def proba_distribution_net(self, latent_dim: int) -> Union[nn.Module, tuple[nn.M action_net = HybridDistributionNet(latent_dim, self.categorical_dimensions) return action_net - def proba_distribution(self: SelfHybridDistribution, *args, **kwargs) -> SelfHybridDistribution: + def proba_distribution(self: SelfHybridDistribution, action_logits: tuple[list[th.Tensor], th.Tensor]) -> SelfHybridDistribution: """Set parameters of the distribution. :return: self """ + self.distribution = Hybrid(logits=action_logits) + return self - def log_prob(self, x: th.Tensor) -> th.Tensor: + # TODO: check return type hint + def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> th.Tensor: """ Returns the log likelihood :param x: the taken action :return: The log likelihood of the distribution """ + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.log_prob(continuous_actions, discrete_actions) def entropy(self) -> Optional[th.Tensor]: """ @@ -77,12 +120,14 @@ def entropy(self) -> Optional[th.Tensor]: :return: the entropy, or None if no analytical form is known """ - def sample(self) -> th.Tensor: + def sample(self) -> tuple[th.Tensor, th.Tensor]: """ Returns a sample from the probability distribution :return: the stochastic action """ + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.sample() def mode(self) -> th.Tensor: """ @@ -92,8 +137,7 @@ def mode(self) -> th.Tensor: :return: the stochastic action """ - # TODO: this is not abstract in superclass, you can also not re-implement it --> check - def get_actions(self, deterministic: bool = False) -> th.Tensor: + def get_actions(self, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor]: """ Return actions according to the probability distribution. diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index e6eef51c..276ea632 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -167,7 +167,39 @@ def _build(self, lr_schedule: Schedule) -> None: lr=lr_schedule(1), # type: ignore[call-arg] **self.optimizer_kwargs, ) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + distrete_actions, continuous_actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> HybridDistribution: + """ + Retrieve action distribution given the latent codes. + + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + action_logits: tuple[list[th.Tensor], th.Tensor] = self.action_net(latent_pi) + return self.action_dist.proba_distribution(action_logits=action_logits) # TODO: check superclass From 94c0c2a1b90bf39a644d7e4e3e1ec67a09bfdd75 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 21 Sep 2025 16:53:42 +0200 Subject: [PATCH 06/19] Completed forward and implemented collect_rollouts collect_rollouts overrides the one of PPO (in turn inherited by OnPolicyAlgorithm). It requires to implement new environment and rollout buffer, as they need to work with multiple actions (discrete and continuous). Co-authored-by: simrey --- sb3_contrib/common/hybrid/policies.py | 7 +- sb3_contrib/ppo_hybrid/ppo_hybrid.py | 117 ++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 3 deletions(-) diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 276ea632..4ab27564 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -187,9 +187,10 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tenso # Evaluate the values for the given observations values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) - distrete_actions, continuous_actions = distribution.get_actions(deterministic=deterministic) - log_prob = distribution.log_prob(actions) - return actions, values, log_prob + actions_d, actions_c = distribution.get_actions(deterministic=deterministic) + log_prob_d = distribution.log_prob(actions_d) + log_prob_c = distribution.log_prob(actions_c) + return actions_d, actions_c, values, log_prob_d, log_prob_c def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> HybridDistribution: """ diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index 940c12f8..50550e5d 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -1,6 +1,12 @@ from typing import ClassVar +import torch as th +import numpy as np from stable_baselines3.ppo import PPO from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy +from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.utils import obs_as_tensor class HybrudPPO(PPO): @@ -67,3 +73,114 @@ def __init__( device, _init_setup_model, ) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_rollout_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + actions_d, actions_c, values, log_probs_d, log_probs_c = self.policy(obs_tensor) + actions_d = actions_d.cpu().numpy() + actions_c = actions_c.cpu().numpy() + + # Rescale and perform action + clipped_actions_c = actions_c + + # action_space is spaces.Tuple[spaces.MultiDiscrete, spaces.Box] + if self.policy.squash_output: + # Unscale the actions to match env bounds + # if they were previously squashed (scaled in [-1, 1]) + clipped_actions_c = self.policy.unscale_action(clipped_actions_c) + else: + # Otherwise, clip the actions to avoid out of bound error + # as we are sampling from an unbounded Gaussian distribution + clipped_actions_c = np.clip(actions_c, self.action_space[1].low, self.action_space[1].high) + + new_obs, rewards, dones, infos = env.step(actions_d, clipped_actions_c) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if not callback.on_step(): + return False + + self._update_info_buffer(infos, dones) + n_steps += 1 + + # Reshape in case of discrete action + # actions_d = actions_d.reshape(-1, 1) # TODO: MultiDiscrete (not simple Discrete) --> check + actions_d = actions_d.reshape(-1, self.action_space.nvec.shape[0]) # TODO: check + + # Handle timeout by bootstrapping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions_d, + clipped_actions_c, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs_d, + log_probs_c + ) + self._last_obs = new_obs # type: ignore[assignment] + self._last_episode_starts = dones + + with th.no_grad(): + # Compute value for the last timestep + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type] + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.update_locals(locals()) + + callback.on_rollout_end() + + return True From 219ab37f8bda476a4d0b4f8745e536abed8ed963 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 21 Sep 2025 17:10:52 +0200 Subject: [PATCH 07/19] Added learn method Calls super-method as done in PPO Co-authored-by: simrey --- sb3_contrib/ppo_hybrid/ppo_hybrid.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index 50550e5d..fbd48c92 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -1,4 +1,4 @@ -from typing import ClassVar +from typing import ClassVar, TypeVar import torch as th import numpy as np from stable_baselines3.ppo import PPO @@ -6,10 +6,13 @@ from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.type_aliases import MaybeCallback from stable_baselines3.common.utils import obs_as_tensor +SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") -class HybrudPPO(PPO): + +class HybridPPO(PPO): policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": ActorCriticPolicy, "CnnPolicy": ActorCriticCnnPolicy, @@ -184,3 +187,21 @@ def collect_rollouts( callback.on_rollout_end() return True + + def learn( + self: SelfHybridPPO, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "Hybrid PPO", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfHybridPPO: + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + ) From dee6ef0bdc15fe095aaf38edce5017ae71b03b99 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 21 Sep 2025 17:19:42 +0200 Subject: [PATCH 08/19] Fixed HybridPPO __init__ arguments Co-authored-by: simrey --- sb3_contrib/ppo_hybrid/ppo_hybrid.py | 56 ++++++++++++++-------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index fbd48c92..c0adf5a3 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -1,4 +1,4 @@ -from typing import ClassVar, TypeVar +from typing import Any, ClassVar, Optional, TypeVar, Union import torch as th import numpy as np from stable_baselines3.ppo import PPO @@ -6,7 +6,7 @@ from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.buffers import RolloutBuffer -from stable_baselines3.common.type_aliases import MaybeCallback +from stable_baselines3.common.type_aliases import MaybeCallback, GymEnv, Schedule from stable_baselines3.common.utils import obs_as_tensor SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") @@ -21,32 +21,32 @@ class HybridPPO(PPO): def __init__( self, - policy, - env, - learning_rate=0.0003, - n_steps=2048, - batch_size=64, - n_epochs=10, - gamma=0.99, - gae_lambda=0.95, - clip_range=0.2, - clip_range_vf=None, - normalize_advantage=True, - ent_coef=0, - vf_coef=0.5, - max_grad_norm=0.5, - use_sde=False, - sde_sample_freq=-1, - rollout_buffer_class=None, - rollout_buffer_kwargs=None, - target_kl=None, - stats_window_size=100, - tensorboard_log=None, - policy_kwargs=None, - verbose=0, - seed=None, - device="auto", - _init_setup_model=True, + policy: Union[str, type[ActorCriticPolicy]], + env: Union[GymEnv, str], # TODO: check if custom env needed to accept multiple actions + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: int = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + rollout_buffer_class: Optional[type[RolloutBuffer]] = None, # TODO: check if custom class needed to accept multiple actions + rollout_buffer_kwargs: Optional[dict[str, Any]] = None, + target_kl: Optional[float] = None, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, ): super().__init__( policy, From a3686132cc5653c6cd08f9e578e561470c801362 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 21 Sep 2025 17:52:41 +0200 Subject: [PATCH 09/19] Started train method (tbc) - Added evaluate_actions in HybridActorCriticPolicy - Completed log_prob and entropy of HybridDistribution Co-authored-by: simrey --- sb3_contrib/common/hybrid/distributions.py | 28 +++++++---- sb3_contrib/common/hybrid/policies.py | 26 ++++++++++ sb3_contrib/ppo_hybrid/ppo_hybrid.py | 56 +++++++++++++++++++--- 3 files changed, 94 insertions(+), 16 deletions(-) diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py index d4db39f0..e5c3d779 100644 --- a/sb3_contrib/common/hybrid/distributions.py +++ b/sb3_contrib/common/hybrid/distributions.py @@ -57,7 +57,7 @@ def sample(self) -> tuple[th.Tensor, th.Tensor]: gaussian_samples = self.gaussian_dist.sample() return th.stack(categorical_samples, dim=-1), gaussian_samples - def log_prob(self): + def log_prob(self) -> tuple[th.Tensor, th.Tensor]: """ Returns the log probability of the given actions, both discrete and continuous. """ @@ -67,9 +67,18 @@ def log_prob(self): # TODO: check dimensions return th.sum(th.stack(categorical_log_probs, dim=-1), dim=-1), th.sum(gaussian_log_prob, dim=-1) - # TODO: implement - def entropy(self): - raise NotImplementedError() + def entropy(self) -> tuple[th.Tensor, th.Tensor]: + """ + Returns the entropy of the hybrid distribution, which is the sum of the entropies + of the categorical and gaussian components. + + :return: Tuple of (categorical entropy, gaussian entropy) + """ + categorical_entropies = [dist.entropy() for dist in self.categorical_dists] + # Sum entropies for all categorical distributions + categorical_entropy = th.sum(th.stack(categorical_entropies, dim=-1), dim=-1) + gaussian_entropy = self.gaussian_dist.entropy().sum(dim=-1) + return categorical_entropy, gaussian_entropy class HybridDistribution(Distribution): @@ -102,23 +111,24 @@ def proba_distribution(self: SelfHybridDistribution, action_logits: tuple[list[t self.distribution = Hybrid(logits=action_logits) return self - # TODO: check return type hint - def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> th.Tensor: + def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> tuple[th.Tensor, th.Tensor]: """ Returns the log likelihood :param x: the taken action - :return: The log likelihood of the distribution + :return: The log likelihood of the distribution for discrete and continuous distributions """ assert self.distribution is not None, "Must set distribution parameters" return self.distribution.log_prob(continuous_actions, discrete_actions) - def entropy(self) -> Optional[th.Tensor]: + def entropy(self) -> tuple[th.Tensor, th.Tensor]: """ Returns Shannon's entropy of the probability - :return: the entropy, or None if no analytical form is known + :return: the entropy of discrete and continuous distributions """ + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.entropy() def sample(self) -> tuple[th.Tensor, th.Tensor]: """ diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 4ab27564..590e9a3d 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -201,6 +201,32 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> HybridDistributi """ action_logits: tuple[list[th.Tensor], th.Tensor] = self.action_net(latent_pi) return self.action_dist.proba_distribution(action_logits=action_logits) + + def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation + :param actions: Actions + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + distribution = self._get_action_dist_from_latent(latent_pi) + # log prob of discrete and continuous actions + log_prob: tuple[th.Tensor, th.Tensor] = distribution.log_prob(actions) + # entropy of discrete and continuous actions + entropy: tuple[th.Tensor, th.Tensor] = distribution.entropy() + values = self.value_net(latent_vf) + return values, log_prob, entropy # TODO: check superclass diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index c0adf5a3..42d82e22 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -1,8 +1,10 @@ from typing import Any, ClassVar, Optional, TypeVar, Union import torch as th import numpy as np +from gymnasium import spaces +from common.hybrid.policies import HybridActorCriticPolicy, HybridActorCriticCnnPolicy, HybridMultiInputActorCriticPolicy from stable_baselines3.ppo import PPO -from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy +from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.buffers import RolloutBuffer @@ -14,14 +16,14 @@ class HybridPPO(PPO): policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { - "MlpPolicy": ActorCriticPolicy, - "CnnPolicy": ActorCriticCnnPolicy, - "MultiInputPolicy": MultiInputActorCriticPolicy, + "MlpPolicy": HybridActorCriticPolicy, + "CnnPolicy": HybridActorCriticCnnPolicy, + "MultiInputPolicy": HybridMultiInputActorCriticPolicy, } def __init__( self, - policy: Union[str, type[ActorCriticPolicy]], + policy: Union[str, type[HybridActorCriticPolicy]], env: Union[GymEnv, str], # TODO: check if custom env needed to accept multiple actions learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, @@ -147,8 +149,11 @@ def collect_rollouts( n_steps += 1 # Reshape in case of discrete action - # actions_d = actions_d.reshape(-1, 1) # TODO: MultiDiscrete (not simple Discrete) --> check - actions_d = actions_d.reshape(-1, self.action_space.nvec.shape[0]) # TODO: check + if isinstance(self.action_space[0], spaces.Discrete): + # Reshape in case of discrete action + actions_d = actions_d.reshape(-1, 1) + elif isinstance(self.action_space[0], spaces.MultiDiscrete): + actions_d = actions_d.reshape(-1, self.action_space[0].nvec.shape[0]) # Handle timeout by bootstrapping with value function # see GitHub issue #633 @@ -188,6 +193,43 @@ def collect_rollouts( return True + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator] + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions_d = rollout_data.actions_d + actions_c = rollout_data.actions_c + if isinstance(self.action_space[0], spaces.Discrete): + # Reshape in case of discrete action + actions_d = actions_d.reshape(-1, 1) + elif isinstance(self.action_space[0], spaces.MultiDiscrete): + actions_d = actions_d.reshape(-1, self.action_space[0].nvec.shape[0]) + + values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions_d, actions_c) + log_prob_d, log_prob_c = log_prob + entropy_d, entropy_c = entropy + + def learn( self: SelfHybridPPO, total_timesteps: int, From 37e1dcd85ce5686de7e5dab4389cecea5ac5ab6c Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 4 Oct 2025 16:55:51 +0200 Subject: [PATCH 10/19] Created HybridActionsRolloutBuffer RolloutBuffer for hybrid actions. HybridPPO.train() not adapted yet Co-authored-by: simrey --- sb3_contrib/ppo_hybrid/__init__.py | 1 + sb3_contrib/ppo_hybrid/buffers.py | 187 +++++++++++++++++++++++++++ sb3_contrib/ppo_hybrid/ppo_hybrid.py | 10 +- 3 files changed, 194 insertions(+), 4 deletions(-) create mode 100644 sb3_contrib/ppo_hybrid/buffers.py diff --git a/sb3_contrib/ppo_hybrid/__init__.py b/sb3_contrib/ppo_hybrid/__init__.py index 448c6768..515da75a 100644 --- a/sb3_contrib/ppo_hybrid/__init__.py +++ b/sb3_contrib/ppo_hybrid/__init__.py @@ -1,4 +1,5 @@ from sb3_contrib.ppo_hybrid.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from sb3_contrib.ppo_hybrid.ppo_hybrid import HybridPPO +from sb3_contrib.ppo_hybrid.buffers import HybridActionsRolloutBuffer __all__ = ["CnnPolicy", "HybridPPO", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/ppo_hybrid/buffers.py b/sb3_contrib/ppo_hybrid/buffers.py new file mode 100644 index 00000000..d761f40d --- /dev/null +++ b/sb3_contrib/ppo_hybrid/buffers.py @@ -0,0 +1,187 @@ +from collections.abc import Generator +import numpy as np +from stable_baselines3.common.buffers import RolloutBuffer +from gymnasium import spaces +from typing import Optional, Union, NamedTuple +import torch as th +from stable_baselines3.common.preprocessing import get_obs_shape +from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env import VecNormalize + + +class HybridActionsRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions_d: th.Tensor + actions_c: th.Tensor + old_values: th.Tensor + old_log_prob_d: th.Tensor + old_log_prob_c: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + +def get_action_dim(action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box]) -> tuple[int, int]: + """ + Get the dimension of the action space, + assumed to be the one of HybridPPO (spaces.Tuple[spaces.MultiDiscrete, spaces.Box]). + + :param action_space: action_space + :return: (dim_d, dim_c) where dim_d is the discrete action dimension and dim_c the continuous action dimension. + """ + return ( + len(action_space.nvec), # discrete action dimension + int(np.prod(action_space.shape)) # continuous action dimension + ) + + +class HybridActionsRolloutBuffer(RolloutBuffer): + """ + Rollout buffer for hybrid action spaces (discrete + continuous). + Stores separate actions and log probabilities for discrete and continuous parts. + """ + + actions: dict[str, np.ndarray] + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + # NOTE: it would be nice to use RolloutBuffer.__init__(), but BaseBuffer calls + # get_action_dim which is not compatible with Tuple action spaces. + + # BaseBuffer's constructor code (excluding call to ABC.__init__()) + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment] + + self.action_dim = get_action_dim(action_space) + self.pos = 0 + self.full = False + self.device = get_device(device) + self.n_envs = n_envs + + # RolloutBuffer's constructor code + self.gae_lambda = gae_lambda + self.gamma = gamma + self.generator_ready = False + self.reset() + + def reset(self) -> None: + super().reset() + # override actions and log_probs to handle hybrid actions + self.actions = { + 'd': np.zeros((self.buffer_size, self.n_envs, self.action_dim[0]), dtype=np.float32), + 'c': np.zeros((self.buffer_size, self.n_envs, self.action_dim[1]), dtype=np.float32) + } + self.log_probs = { + 'd': np.zeros((self.buffer_size, self.n_envs), dtype=np.float32), + 'c': np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + } + + def add( + self, + obs: np.ndarray, + action_d: np.ndarray, + action_c: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob_d: th.Tensor, + log_prob_c: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action_d: Discrete action + :param action_c: Continuous action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob_d: log probabilities of the discrete action following the current policy. + :param log_prob_c: log probabilities of the continuous action following the current policy. + """ + # Reshape 0-d tensor to avoid error + if len(log_prob_d.shape) == 0: + log_prob_d = log_prob_d.reshape(-1, 1) + if len(log_prob_c.shape) == 0: + log_prob_c = log_prob_c.reshape(-1, 1) + + # copied from RolloutBuffer: + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs, *self.obs_shape)) + + # Adapted from RolloutBuffer: + # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 + action_d = action_d.reshape((self.n_envs, self.action_dim)) + action_c = action_c.reshape((self.n_envs, self.action_dim)) + + self.observations[self.pos] = np.array(obs) + self.actions['d'][self.pos] = np.array(action_d) + self.actions['c'][self.pos] = np.array(action_c) + self.rewards[self.pos] = np.array(reward) + self.episode_starts[self.pos] = np.array(episode_start) + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs['d'][self.pos] = log_prob_d.clone().cpu().numpy() + self.log_probs['c'][self.pos] = log_prob_c.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[HybridActionsRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + _tensor_names = [ + "observations", + "actions_d", + "actions_c", + "values", + "log_probs_d", + "log_probs_c", + "advantages", + "returns", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None # TODO: check type hint + ) -> HybridActionsRolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions['d'][batch_inds], + self.actions['c'][batch_inds], + self.values[batch_inds].flatten(), + self.log_probs['d'][batch_inds].flatten(), + self.log_probs['c'][batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return HybridActionsRolloutBufferSamples(*tuple(map(self.to_torch, data))) + + +# TODO: implement +# class HybridActionsRolloutBuffer(HybridActionsRolloutBuffer): diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index 42d82e22..b24ca684 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -7,9 +7,9 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.type_aliases import MaybeCallback, GymEnv, Schedule from stable_baselines3.common.utils import obs_as_tensor +from buffers import HybridActionsRolloutBuffer SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") @@ -21,6 +21,8 @@ class HybridPPO(PPO): "MultiInputPolicy": HybridMultiInputActorCriticPolicy, } + rollout_buffer: HybridActionsRolloutBuffer + def __init__( self, policy: Union[str, type[HybridActorCriticPolicy]], @@ -39,7 +41,7 @@ def __init__( max_grad_norm: float = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, - rollout_buffer_class: Optional[type[RolloutBuffer]] = None, # TODO: check if custom class needed to accept multiple actions + rollout_buffer_class: Optional[type[HybridActionsRolloutBuffer]] = None, # TODO: check if custom class needed to accept multiple actions rollout_buffer_kwargs: Optional[dict[str, Any]] = None, target_kl: Optional[float] = None, stats_window_size: int = 100, @@ -83,11 +85,11 @@ def collect_rollouts( self, env: VecEnv, callback: BaseCallback, - rollout_buffer: RolloutBuffer, + rollout_buffer: HybridActionsRolloutBuffer, n_rollout_steps: int, ) -> bool: """ - Collect experiences using the current policy and fill a ``RolloutBuffer``. + Collect experiences using the current policy and fill a ``HybridActionsRolloutBuffer``. The term rollout here refers to the model-free notion and should not be used with the concept of rollout used in model-based RL or planning. From 6e343567c5478772ed2f3ad64d75ce756f4c1f5a Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sat, 4 Oct 2025 17:56:09 +0200 Subject: [PATCH 11/19] Completed train method Co-authored-by: simrey --- sb3_contrib/common/hybrid/distributions.py | 5 +- sb3_contrib/common/hybrid/policies.py | 12 ++- sb3_contrib/ppo_hybrid/ppo_hybrid.py | 116 ++++++++++++++++++++- 3 files changed, 125 insertions(+), 8 deletions(-) diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py index e5c3d779..a14a1501 100644 --- a/sb3_contrib/common/hybrid/distributions.py +++ b/sb3_contrib/common/hybrid/distributions.py @@ -65,7 +65,10 @@ def log_prob(self) -> tuple[th.Tensor, th.Tensor]: categorical_log_probs = [dist.log_prob(dist.sample()) for dist in self.categorical_dists] # TODO: check dimensions - return th.sum(th.stack(categorical_log_probs, dim=-1), dim=-1), th.sum(gaussian_log_prob, dim=-1) + return ( + th.sum(th.stack(categorical_log_probs, dim=-1), dim=-1), + th.sum(gaussian_log_prob, dim=-1) + ) def entropy(self) -> tuple[th.Tensor, th.Tensor]: """ diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 590e9a3d..4c244212 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -202,13 +202,19 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> HybridDistributi action_logits: tuple[list[th.Tensor], th.Tensor] = self.action_net(latent_pi) return self.action_dist.proba_distribution(action_logits=action_logits) - def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + def evaluate_actions( + self, + obs: PyTorchObs, + actions_d: th.Tensor, + actions_c: th.Tensor + ) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: """ Evaluate actions according to the current policy, given the observations. :param obs: Observation - :param actions: Actions + :param actions_d: Discrete actions + :param actions_c: Continuous actions :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ @@ -222,7 +228,7 @@ def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tens latent_vf = self.mlp_extractor.forward_critic(vf_features) distribution = self._get_action_dist_from_latent(latent_pi) # log prob of discrete and continuous actions - log_prob: tuple[th.Tensor, th.Tensor] = distribution.log_prob(actions) + log_prob: tuple[th.Tensor, th.Tensor] = distribution.log_prob(actions_d, actions_c) # entropy of discrete and continuous actions entropy: tuple[th.Tensor, th.Tensor] = distribution.entropy() values = self.value_net(latent_vf) diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index b24ca684..3cc74b20 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -1,5 +1,6 @@ from typing import Any, ClassVar, Optional, TypeVar, Union import torch as th +from torch.nn import functional as F import numpy as np from gymnasium import spaces from common.hybrid.policies import HybridActorCriticPolicy, HybridActorCriticCnnPolicy, HybridMultiInputActorCriticPolicy @@ -10,6 +11,7 @@ from stable_baselines3.common.type_aliases import MaybeCallback, GymEnv, Schedule from stable_baselines3.common.utils import obs_as_tensor from buffers import HybridActionsRolloutBuffer +from stable_baselines3.common.utils import explained_variance SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") @@ -22,6 +24,7 @@ class HybridPPO(PPO): } rollout_buffer: HybridActionsRolloutBuffer + policy: HybridActorCriticPolicy def __init__( self, @@ -210,8 +213,8 @@ def train(self) -> None: clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] entropy_losses = [] - pg_losses, value_losses = [], [] - clip_fractions = [] + pg_losses_d, pg_losses_c, value_losses = [], [], [] + clip_fractions_d, clip_fractions_c = [], [] continue_training = True # train for n_epochs epochs @@ -227,10 +230,115 @@ def train(self) -> None: elif isinstance(self.action_space[0], spaces.MultiDiscrete): actions_d = actions_d.reshape(-1, self.action_space[0].nvec.shape[0]) - values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions_d, actions_c) - log_prob_d, log_prob_c = log_prob + values, log_probs, entropy = self.policy.evaluate_actions(rollout_data.observations, actions_d, actions_c) + log_prob_d, log_prob_c = log_probs entropy_d, entropy_c = entropy + values = values.flatten() + + # Normalize advantage + advantages = rollout_data.advantages + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if self.normalize_advantage and len(advantages) > 1: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + # ratio between old and new policy, for discrete and continuous actions. + # Should be one at the first iteration + ratio_d = th.exp(log_prob_d - rollout_data.old_log_prob_d) + ratio_c = th.exp(log_prob_c - rollout_data.old_log_prob_c) + + # clipped surrogate loss for discrete actions + policy_loss_d_1 = advantages * ratio_d + policy_loss_d_2 = advantages * th.clamp(ratio_d, 1 - clip_range, 1 + clip_range) + policy_loss_d = -th.min(policy_loss_d_1, policy_loss_d_2).mean() + + # clipped surrogate loss for continuous actions + policy_loss_c_1 = advantages * ratio_c + policy_loss_c_2 = advantages * th.clamp(ratio_c, 1 - clip_range, 1 + clip_range) + policy_loss_c = -th.min(policy_loss_c_1, policy_loss_c_2).mean() + + # Logging + pg_losses_d.append(policy_loss_d.item()) + clip_fraction_d = th.mean((th.abs(ratio_d - 1) > clip_range).float()).item() + clip_fractions_d.append(clip_fraction_d) + pg_losses_c.append(policy_loss_c.item()) + clip_fraction_c = th.mean((th.abs(ratio_c - 1) > clip_range).float()).item() + clip_fractions_c.append(clip_fraction_c) + + # Value loss + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the difference between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss_d = -th.mean(-log_prob_d) + entropy_loss_c = -th.mean(-log_prob_c) + entropy_loss = entropy_loss_d + entropy_loss_c + else: + entropy_loss = -th.mean(entropy_d) - th.mean(entropy_c) + entropy_losses.append(entropy_loss.item()) + + # total loss function + loss = 0.5 * (policy_loss_d + policy_loss_c) + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + # Note: using the max KL divergence between discrete and continuous actions to stay conservative + with th.no_grad(): + log_ratio_d = log_prob_d - rollout_data.old_log_prob_d + log_ratio_c = log_prob_c - rollout_data.old_log_prob_c + approx_kl_div_d = th.mean((th.exp(log_ratio_d) - 1) - log_ratio_d).cpu().numpy() + approx_kl_div_c = th.mean((th.exp(log_ratio_c) - 1) - log_ratio_c).cpu().numpy() + approx_kl_div = max(approx_kl_div_d, approx_kl_div_c) + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + self._n_updates += 1 + if not continue_training: + break + + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_discrete_loss", np.mean(pg_losses_d)) + self.logger.record("train/policy_gradient_continuous_loss", np.mean(pg_losses_c)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction_discrete", np.mean(clip_fractions_d)) + self.logger.record("train/clip_fraction_continuous", np.mean(clip_fractions_c)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) def learn( self: SelfHybridPPO, From 6994f2d57adf8e40b24bf7beca70a945f76e0ec5 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 5 Oct 2025 11:23:12 +0200 Subject: [PATCH 12/19] Update sb3_contrib/__init__.py --- sb3_contrib/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 2aa7a19b..569ba731 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -4,6 +4,7 @@ from sb3_contrib.crossq import CrossQ from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.ppo_recurrent import RecurrentPPO +from sb3_contrib.ppo_hybrid import HybridPPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC from sb3_contrib.trpo import TRPO @@ -21,4 +22,5 @@ "CrossQ", "MaskablePPO", "RecurrentPPO", + "HybridPPO", ] From 0cd46f41922dad9dc9b72b6f8e325883d40306dd Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 26 Oct 2025 16:58:38 +0100 Subject: [PATCH 13/19] Created env 'catching point' (to be tested) Co-authored-by: simrey --- sb3_contrib/common/envs/__init__.py | 5 +- sb3_contrib/common/envs/hybrid_actions_env.py | 104 ++++++++++++++++++ 2 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 sb3_contrib/common/envs/hybrid_actions_env.py diff --git a/sb3_contrib/common/envs/__init__.py b/sb3_contrib/common/envs/__init__.py index e9f740b9..68bb6be8 100644 --- a/sb3_contrib/common/envs/__init__.py +++ b/sb3_contrib/common/envs/__init__.py @@ -3,5 +3,8 @@ InvalidActionEnvMultiBinary, InvalidActionEnvMultiDiscrete, ) +from sb3_contrib.common.envs.hybrid_actions_env import ( + CatchingPointEnv +) -__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete"] +__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete", "CatchingPointEnv"] diff --git a/sb3_contrib/common/envs/hybrid_actions_env.py b/sb3_contrib/common/envs/hybrid_actions_env.py new file mode 100644 index 00000000..18d223f0 --- /dev/null +++ b/sb3_contrib/common/envs/hybrid_actions_env.py @@ -0,0 +1,104 @@ +import gymnasium as gym +from gymnasium import spaces +import numpy as np + + +class CatchingPointEnv(gym.Env): + """ + Enviornment for Hybrid PPO for the 'Catching Point' task of the paper + 'Hybrid Actor-Critic Reinforcement Learning in Parameterized Action Space', Fan et al. + (https://arxiv.org/pdf/1903.01344) + """ + + def __init__( + self, + arena_size: float = 1.0, + move_dist=0.05, + catch_radius=0.05, + max_catches=10, + max_steps=200 + ): + super().__init__() + self.max_steps = max_steps + self.max_catches = max_catches + self.arena_size = arena_size + self.move_dist = move_dist + self.catch_radius = catch_radius + + # action space + self.action_space = spaces.Tuple( + spaces=( + spaces.MultiDiscrete([2]), # MOVE=0, CATCH=1 + spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32) # direction + ) + ) + + # observation: [agent_x, agent_y, target_x, target_y, catches_left, step_norm] + obs_low = np.array([-arena_size, -arena_size, -arena_size, -arena_size, 0.0, 0.0], dtype=np.float32) + obs_high= np.array([ arena_size, arena_size, arena_size, arena_size, float(max_catches), 1.0], dtype=np.float32) + self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32) + + def reset(self) -> np.ndarray: + """ + Reset the environment to an initial state and return the initial observation. + """ + self.agent_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32) + self.target_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32) + self.catches_used = 0 + self.step_count = 0 + return self._get_obs() + + def step(self, action: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, float, bool, dict]: + """ + Take a step in the environment using the provided action. + + :param action: A tuple containing the discrete action and continuous parameters. + :return: observation, reward, done, info + """ + action_d = int(action[0][0]) + dir_vec = action[1] + reward = 0.0 + done = False + + # step penalty + reward = -0.01 + + # MOVE + if action_d == 0: + norm = np.linalg.norm(dir_vec) + dir_u = dir_vec / norm + self.agent_pos = (self.agent_pos + dir_u * self.move_dist).astype(np.float32) + # clamp to arena + self.agent_pos = np.clip(self.agent_pos, -self.arena_size, self.arena_size) + + # CATCH + else: + self.catches_used += 1 + dist = np.linalg.norm(self.agent_pos - self.target_pos) + if dist <= self.catch_radius: + reward = 1.0 # caught the target + done = True + else: + if self.catches_used >= self.max_catches: + done = True + + self.step_count += 1 + if self.step_count >= self.max_steps: + done = True + + obs = self._get_obs() + info = {"caught": (reward > 0)} + return obs, float(reward), bool(done), info + + def _get_obs(self) -> np.ndarray: + """ + Get the current observation. + """ + step_norm = self.step_count / self.max_steps + catches_left = self.max_catches - self.catches_used + obs = np.concatenate(( + self.agent_pos, + self.target_pos, + np.array([catches_left, step_norm], dtype=np.float32) + )) + return obs From fafae0891c8a274de0028e9df7c9c33a24e5c1ba Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 26 Oct 2025 17:38:06 +0100 Subject: [PATCH 14/19] spaces.Tuple is not a generic class Plus fix some imports Co-authored-by: simrey --- sb3_contrib/common/envs/hybrid_actions_env.py | 8 ++++---- sb3_contrib/common/hybrid/distributions.py | 6 +++++- sb3_contrib/common/hybrid/policies.py | 2 +- sb3_contrib/ppo_hybrid/buffers.py | 12 ++++++++---- sb3_contrib/ppo_hybrid/ppo_hybrid.py | 4 ++-- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sb3_contrib/common/envs/hybrid_actions_env.py b/sb3_contrib/common/envs/hybrid_actions_env.py index 18d223f0..6cb061dd 100644 --- a/sb3_contrib/common/envs/hybrid_actions_env.py +++ b/sb3_contrib/common/envs/hybrid_actions_env.py @@ -13,10 +13,10 @@ class CatchingPointEnv(gym.Env): def __init__( self, arena_size: float = 1.0, - move_dist=0.05, - catch_radius=0.05, - max_catches=10, - max_steps=200 + move_dist: float = 0.05, + catch_radius: float = 0.05, + max_catches: float = 10, + max_steps: float = 200 ): super().__init__() self.max_steps = max_steps diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py index a14a1501..ea17a8da 100644 --- a/sb3_contrib/common/hybrid/distributions.py +++ b/sb3_contrib/common/hybrid/distributions.py @@ -178,13 +178,17 @@ def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]: """ -def make_hybrid_proba_distribution(action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box]) -> HybridDistribution: +def make_hybrid_proba_distribution(action_space: spaces.Tuple) -> HybridDistribution: """ Create a hybrid probability distribution for the given action space. :param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space. :return: A HybridDistribution object that handles the hybrid action space. """ + assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space" + assert len(action_space.spaces) == 2, "Action space must contain exactly 2 subspaces" + assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete" + assert isinstance(action_space.spaces[1], spaces.Box), "Second subspace must be Box" assert len(action_space[1].shape) == 1, "Continuous action space must have a monodimensional shape (e.g., (n,))" return HybridDistribution( categorical_dimensions=len(action_space[0].nvec), diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py index 4c244212..cdd21eeb 100644 --- a/sb3_contrib/common/hybrid/policies.py +++ b/sb3_contrib/common/hybrid/policies.py @@ -45,7 +45,7 @@ class HybridActorCriticPolicy(BasePolicy): def __init__( self, observation_space: spaces.Space, - action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box], + action_space: spaces.Tuple, # Type[spaces.MultiDiscrete, spaces.Box] lr_schedule: Schedule, net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, activation_fn: type[nn.Module] = nn.Tanh, diff --git a/sb3_contrib/ppo_hybrid/buffers.py b/sb3_contrib/ppo_hybrid/buffers.py index d761f40d..469b3f2c 100644 --- a/sb3_contrib/ppo_hybrid/buffers.py +++ b/sb3_contrib/ppo_hybrid/buffers.py @@ -20,14 +20,18 @@ class HybridActionsRolloutBufferSamples(NamedTuple): returns: th.Tensor -def get_action_dim(action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box]) -> tuple[int, int]: +def get_action_dim(action_space: spaces.Tuple) -> tuple[int, int]: """ Get the dimension of the action space, - assumed to be the one of HybridPPO (spaces.Tuple[spaces.MultiDiscrete, spaces.Box]). + assumed to be the one of HybridPPO (Tuple[MultiDiscrete, Box]). - :param action_space: action_space + :param action_space: Tuple action space containing MultiDiscrete and Box spaces :return: (dim_d, dim_c) where dim_d is the discrete action dimension and dim_c the continuous action dimension. """ + assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space" + assert len(action_space.spaces) == 2, "Action space must contain exactly 2 subspaces" + assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete" + assert isinstance(action_space.spaces[1], spaces.Box), "Second subspace must be Box" return ( len(action_space.nvec), # discrete action dimension int(np.prod(action_space.shape)) # continuous action dimension @@ -46,7 +50,7 @@ def __init__( self, buffer_size: int, observation_space: spaces.Space, - action_space: spaces.Tuple[spaces.MultiDiscrete, spaces.Box], + action_space: spaces.Tuple, # Type[spaces.MultiDiscrete, spaces.Box] device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index 3cc74b20..4249405a 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -3,14 +3,14 @@ from torch.nn import functional as F import numpy as np from gymnasium import spaces -from common.hybrid.policies import HybridActorCriticPolicy, HybridActorCriticCnnPolicy, HybridMultiInputActorCriticPolicy +from sb3_contrib.common.hybrid.policies import HybridActorCriticPolicy, HybridActorCriticCnnPolicy, HybridMultiInputActorCriticPolicy from stable_baselines3.ppo import PPO from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.type_aliases import MaybeCallback, GymEnv, Schedule from stable_baselines3.common.utils import obs_as_tensor -from buffers import HybridActionsRolloutBuffer +from sb3_contrib.ppo_hybrid.buffers import HybridActionsRolloutBuffer from stable_baselines3.common.utils import explained_variance SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") From 365bdb90d63ab436da696beb5a5f0cb786453c4d Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 26 Oct 2025 18:13:46 +0100 Subject: [PATCH 15/19] Created HybridToBoxWrapper to handle hybrid actions but still integrate with the library PPO and the base algorithm do some validation on the action space. If we want our HybridPPO to be subclass of PPO, we need this wrapper for the integration with the library. Co-authored-by: simrey --- sb3_contrib/common/wrappers/__init__.py | 3 +- .../common/wrappers/hybrid_to_box_wrapper.py | 68 +++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py diff --git a/sb3_contrib/common/wrappers/__init__.py b/sb3_contrib/common/wrappers/__init__.py index c7dc0b04..737da832 100644 --- a/sb3_contrib/common/wrappers/__init__.py +++ b/sb3_contrib/common/wrappers/__init__.py @@ -1,4 +1,5 @@ from sb3_contrib.common.wrappers.action_masker import ActionMasker from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper +from sb3_contrib.common.wrappers.hybrid_to_box_wrapper import HybridToBoxWrapper -__all__ = ["ActionMasker", "TimeFeatureWrapper"] +__all__ = ["ActionMasker", "TimeFeatureWrapper", "HybridToBoxWrapper"] diff --git a/sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py b/sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py new file mode 100644 index 00000000..cca5fba3 --- /dev/null +++ b/sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py @@ -0,0 +1,68 @@ +from typing import Any, Dict, Tuple +import numpy as np +from gymnasium import spaces, Wrapper + + +class HybridToBoxWrapper(Wrapper): + """ + Wrapper that converts a hybrid action space (Tuple of MultiDiscrete and Box) + into a single Box action space, enabling compatibility with standard algorithms. + The wrapper handles the conversion between the flattened Box action space and + the hybrid action space internally. + """ + + def __init__(self, env): + """ + Initialize the wrapper. + + :param env: The environment to wrap + """ + super().__init__(env) + assert isinstance(env.action_space, spaces.Tuple), "Environment must have a Tuple action space" + assert len(env.action_space.spaces) == 2, "Action space must contain exactly 2 subspaces" + assert isinstance(env.action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete" + assert isinstance(env.action_space.spaces[1], spaces.Box), "Second subspace must be Box" + + # Store original action space + self.hybrid_action_space = env.action_space + + # Calculate total dimensions needed for Box space + self.discrete_dims = sum(env.action_space.spaces[0].nvec) # One-hot encoding for each discrete action + self.continuous_dims = env.action_space.spaces[1].shape[0] + + # Create new Box action space + # First part: one-hot encoding for discrete actions + # Second part: continuous actions + total_dims = self.discrete_dims + self.continuous_dims + self.action_space = spaces.Box( + low=np.concatenate([np.zeros(self.discrete_dims), env.action_space.spaces[1].low]), + high=np.concatenate([np.ones(self.discrete_dims), env.action_space.spaces[1].high]), + dtype=np.float32 + ) + + def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: + """ + Convert the Box action back to hybrid format and step the environment. + + :param action: Action from the Box action space + :return: Next observation, reward, done flag, and info dictionary + """ + # Split action into discrete and continuous parts + discrete_part = action[:self.discrete_dims] + continuous_part = action[self.discrete_dims:] + + # Convert one-hot encodings back to MultiDiscrete format + discrete_actions = [] + start_idx = 0 + for n in self.hybrid_action_space.spaces[0].nvec: + one_hot = discrete_part[start_idx:start_idx + n] + discrete_actions.append(np.argmax(one_hot)) + start_idx += n + + # Create hybrid action tuple + hybrid_action = ( + np.array(discrete_actions, dtype=np.int64), + continuous_part + ) + + return self.env.step(hybrid_action) \ No newline at end of file From 392e60f7db36a9a3416ccab1cbcc7625acf586b0 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 26 Oct 2025 18:14:08 +0100 Subject: [PATCH 16/19] done --> terminated & truncated Co-authored-by: simrey --- sb3_contrib/common/envs/hybrid_actions_env.py | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/sb3_contrib/common/envs/hybrid_actions_env.py b/sb3_contrib/common/envs/hybrid_actions_env.py index 6cb061dd..e84c12e8 100644 --- a/sb3_contrib/common/envs/hybrid_actions_env.py +++ b/sb3_contrib/common/envs/hybrid_actions_env.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple, Dict + import gymnasium as gym from gymnasium import spaces import numpy as np @@ -16,7 +18,7 @@ def __init__( move_dist: float = 0.05, catch_radius: float = 0.05, max_catches: float = 10, - max_steps: float = 200 + max_steps: float = 500 ): super().__init__() self.max_steps = max_steps @@ -38,27 +40,44 @@ def __init__( obs_high= np.array([ arena_size, arena_size, arena_size, arena_size, float(max_catches), 1.0], dtype=np.float32) self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32) - def reset(self) -> np.ndarray: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]: """ Reset the environment to an initial state and return the initial observation. + + :param seed: The seed for random number generation + :param options: Additional options for environment reset, e.g., {'difficulty': 'hard'} + :return: Tuple of (observation, info) """ + super().reset(seed=seed) + + # Handle options (none used in this environment currently, but following the gymnasium API) + if options is not None: + # Example of how to handle options if needed: + # if 'difficulty' in options: + # self.move_dist = 0.03 if options['difficulty'] == 'hard' else 0.05 + pass + self.agent_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32) self.target_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32) self.catches_used = 0 self.step_count = 0 - return self._get_obs() + + obs = self._get_obs() + info = {} # Additional info dict, empty for now but could include initialization info + return obs, info - def step(self, action: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, float, bool, dict]: + def step(self, action: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, float, bool, bool, Dict[str, bool]]: """ Take a step in the environment using the provided action. :param action: A tuple containing the discrete action and continuous parameters. - :return: observation, reward, done, info + :return: Tuple of (observation, reward, terminated, truncated, info) """ action_d = int(action[0][0]) dir_vec = action[1] reward = 0.0 - done = False + terminated = False + truncated = False # step penalty reward = -0.01 @@ -77,18 +96,18 @@ def step(self, action: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, float dist = np.linalg.norm(self.agent_pos - self.target_pos) if dist <= self.catch_radius: reward = 1.0 # caught the target - done = True + terminated = True # Natural termination else: if self.catches_used >= self.max_catches: - done = True + terminated = True # Natural termination self.step_count += 1 if self.step_count >= self.max_steps: - done = True + truncated = True # Episode truncated due to time limit obs = self._get_obs() info = {"caught": (reward > 0)} - return obs, float(reward), bool(done), info + return obs, float(reward), terminated, truncated, info def _get_obs(self) -> np.ndarray: """ From 45120b46154a60767ef2628e02781557a103f194 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Wed, 5 Nov 2025 10:13:29 +0100 Subject: [PATCH 17/19] Revert "done --> terminated & truncated" This reverts commit 392e60f7db36a9a3416ccab1cbcc7625acf586b0. --- sb3_contrib/common/envs/hybrid_actions_env.py | 39 +++++-------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/sb3_contrib/common/envs/hybrid_actions_env.py b/sb3_contrib/common/envs/hybrid_actions_env.py index e84c12e8..6cb061dd 100644 --- a/sb3_contrib/common/envs/hybrid_actions_env.py +++ b/sb3_contrib/common/envs/hybrid_actions_env.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple, Dict - import gymnasium as gym from gymnasium import spaces import numpy as np @@ -18,7 +16,7 @@ def __init__( move_dist: float = 0.05, catch_radius: float = 0.05, max_catches: float = 10, - max_steps: float = 500 + max_steps: float = 200 ): super().__init__() self.max_steps = max_steps @@ -40,44 +38,27 @@ def __init__( obs_high= np.array([ arena_size, arena_size, arena_size, arena_size, float(max_catches), 1.0], dtype=np.float32) self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]: + def reset(self) -> np.ndarray: """ Reset the environment to an initial state and return the initial observation. - - :param seed: The seed for random number generation - :param options: Additional options for environment reset, e.g., {'difficulty': 'hard'} - :return: Tuple of (observation, info) """ - super().reset(seed=seed) - - # Handle options (none used in this environment currently, but following the gymnasium API) - if options is not None: - # Example of how to handle options if needed: - # if 'difficulty' in options: - # self.move_dist = 0.03 if options['difficulty'] == 'hard' else 0.05 - pass - self.agent_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32) self.target_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32) self.catches_used = 0 self.step_count = 0 - - obs = self._get_obs() - info = {} # Additional info dict, empty for now but could include initialization info - return obs, info + return self._get_obs() - def step(self, action: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, float, bool, bool, Dict[str, bool]]: + def step(self, action: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, float, bool, dict]: """ Take a step in the environment using the provided action. :param action: A tuple containing the discrete action and continuous parameters. - :return: Tuple of (observation, reward, terminated, truncated, info) + :return: observation, reward, done, info """ action_d = int(action[0][0]) dir_vec = action[1] reward = 0.0 - terminated = False - truncated = False + done = False # step penalty reward = -0.01 @@ -96,18 +77,18 @@ def step(self, action: Tuple[np.ndarray, np.ndarray]) -> Tuple[np.ndarray, float dist = np.linalg.norm(self.agent_pos - self.target_pos) if dist <= self.catch_radius: reward = 1.0 # caught the target - terminated = True # Natural termination + done = True else: if self.catches_used >= self.max_catches: - terminated = True # Natural termination + done = True self.step_count += 1 if self.step_count >= self.max_steps: - truncated = True # Episode truncated due to time limit + done = True obs = self._get_obs() info = {"caught": (reward > 0)} - return obs, float(reward), terminated, truncated, info + return obs, float(reward), bool(done), info def _get_obs(self) -> np.ndarray: """ From b4191428e0bc2e4f83ce88ecdf00a36bb2dec5c8 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Wed, 5 Nov 2025 10:20:41 +0100 Subject: [PATCH 18/19] Revert "Created HybridToBoxWrapper to handle hybrid actions but still integrate with the library" This reverts commit 365bdb90d63ab436da696beb5a5f0cb786453c4d. --- sb3_contrib/common/wrappers/__init__.py | 3 +- .../common/wrappers/hybrid_to_box_wrapper.py | 68 ------------------- 2 files changed, 1 insertion(+), 70 deletions(-) delete mode 100644 sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py diff --git a/sb3_contrib/common/wrappers/__init__.py b/sb3_contrib/common/wrappers/__init__.py index 737da832..c7dc0b04 100644 --- a/sb3_contrib/common/wrappers/__init__.py +++ b/sb3_contrib/common/wrappers/__init__.py @@ -1,5 +1,4 @@ from sb3_contrib.common.wrappers.action_masker import ActionMasker from sb3_contrib.common.wrappers.time_feature import TimeFeatureWrapper -from sb3_contrib.common.wrappers.hybrid_to_box_wrapper import HybridToBoxWrapper -__all__ = ["ActionMasker", "TimeFeatureWrapper", "HybridToBoxWrapper"] +__all__ = ["ActionMasker", "TimeFeatureWrapper"] diff --git a/sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py b/sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py deleted file mode 100644 index cca5fba3..00000000 --- a/sb3_contrib/common/wrappers/hybrid_to_box_wrapper.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Any, Dict, Tuple -import numpy as np -from gymnasium import spaces, Wrapper - - -class HybridToBoxWrapper(Wrapper): - """ - Wrapper that converts a hybrid action space (Tuple of MultiDiscrete and Box) - into a single Box action space, enabling compatibility with standard algorithms. - The wrapper handles the conversion between the flattened Box action space and - the hybrid action space internally. - """ - - def __init__(self, env): - """ - Initialize the wrapper. - - :param env: The environment to wrap - """ - super().__init__(env) - assert isinstance(env.action_space, spaces.Tuple), "Environment must have a Tuple action space" - assert len(env.action_space.spaces) == 2, "Action space must contain exactly 2 subspaces" - assert isinstance(env.action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete" - assert isinstance(env.action_space.spaces[1], spaces.Box), "Second subspace must be Box" - - # Store original action space - self.hybrid_action_space = env.action_space - - # Calculate total dimensions needed for Box space - self.discrete_dims = sum(env.action_space.spaces[0].nvec) # One-hot encoding for each discrete action - self.continuous_dims = env.action_space.spaces[1].shape[0] - - # Create new Box action space - # First part: one-hot encoding for discrete actions - # Second part: continuous actions - total_dims = self.discrete_dims + self.continuous_dims - self.action_space = spaces.Box( - low=np.concatenate([np.zeros(self.discrete_dims), env.action_space.spaces[1].low]), - high=np.concatenate([np.ones(self.discrete_dims), env.action_space.spaces[1].high]), - dtype=np.float32 - ) - - def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: - """ - Convert the Box action back to hybrid format and step the environment. - - :param action: Action from the Box action space - :return: Next observation, reward, done flag, and info dictionary - """ - # Split action into discrete and continuous parts - discrete_part = action[:self.discrete_dims] - continuous_part = action[self.discrete_dims:] - - # Convert one-hot encodings back to MultiDiscrete format - discrete_actions = [] - start_idx = 0 - for n in self.hybrid_action_space.spaces[0].nvec: - one_hot = discrete_part[start_idx:start_idx + n] - discrete_actions.append(np.argmax(one_hot)) - start_idx += n - - # Create hybrid action tuple - hybrid_action = ( - np.array(discrete_actions, dtype=np.int64), - continuous_part - ) - - return self.env.step(hybrid_action) \ No newline at end of file From 364769709846b42e4329b18c80edcac1f58d7aeb Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Sun, 9 Nov 2025 12:50:40 +0100 Subject: [PATCH 19/19] HybridPPO inherits from OnPolicyAlgorithm Instead of PPO --- sb3_contrib/ppo_hybrid/ppo_hybrid.py | 127 +++++++++++++++++++++------ 1 file changed, 99 insertions(+), 28 deletions(-) diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py index 4249405a..e428f804 100644 --- a/sb3_contrib/ppo_hybrid/ppo_hybrid.py +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -1,10 +1,11 @@ from typing import Any, ClassVar, Optional, TypeVar, Union +import warnings import torch as th from torch.nn import functional as F import numpy as np from gymnasium import spaces from sb3_contrib.common.hybrid.policies import HybridActorCriticPolicy, HybridActorCriticCnnPolicy, HybridMultiInputActorCriticPolicy -from stable_baselines3.ppo import PPO +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback @@ -12,11 +13,12 @@ from stable_baselines3.common.utils import obs_as_tensor from sb3_contrib.ppo_hybrid.buffers import HybridActionsRolloutBuffer from stable_baselines3.common.utils import explained_variance +from stable_baselines3.common.utils import FloatSchedule SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") -class HybridPPO(PPO): +class HybridPPO(OnPolicyAlgorithm): policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": HybridActorCriticPolicy, "CnnPolicy": HybridActorCriticCnnPolicy, @@ -56,34 +58,103 @@ def __init__( _init_setup_model: bool = True, ): super().__init__( - policy, - env, - learning_rate, - n_steps, - batch_size, - n_epochs, - gamma, - gae_lambda, - clip_range, - clip_range_vf, - normalize_advantage, - ent_coef, - vf_coef, - max_grad_norm, - use_sde, - sde_sample_freq, - rollout_buffer_class, - rollout_buffer_kwargs, - target_kl, - stats_window_size, - tensorboard_log, - policy_kwargs, - verbose, - seed, - device, - _init_setup_model, + policy=policy, + env=env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=_init_setup_model, + supported_action_spaces=(spaces.Tuple,), ) + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert buffer_size > 1 or ( + not normalize_advantage + ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + if self.rollout_buffer_class is None: + # TODO: mauybe extend if buffers for Dict obs is implemented + self.rollout_buffer_class = HybridActionsRolloutBuffer + + self.rollout_buffer = self.rollout_buffer_class( + self.n_steps, + self.observation_space, + self.action_space, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + **self.rollout_buffer_kwargs, + ) + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, + ) + self.policy = self.policy.to(self.device) + + if not isinstance(self.policy, HybridActorCriticPolicy): + raise ValueError("Policy must subclass HybridActorCriticPolicy") + + # Initialize schedules for policy/value clipping + self.clip_range = FloatSchedule(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + + self.clip_range_vf = FloatSchedule(self.clip_range_vf) + def collect_rollouts( self, env: VecEnv,