Draft
Changes from all commits (21 commits)
b490174
Created empty structure for custom policies
AlexPasqua Jul 6, 2025
71291eb
Started populating HybridActorCriticPolicy
AlexPasqua Jul 6, 2025
3360570
Create HybridDistribution and HybridDistributionNet (tbc)
AlexPasqua Jul 26, 2025
642c383
Merge branch 'master' into hybrid_PPO
araffin Jul 30, 2025
caf1a44
Add _build method to HybridActorCriticPolicy
AlexPasqua Sep 13, 2025
9ac17e0
Started forward method (tbc)
AlexPasqua Sep 13, 2025
94c0c2a
Completed forward and implemented collect_rollouts
AlexPasqua Sep 21, 2025
219ab37
Added learn method
AlexPasqua Sep 21, 2025
dee6ef0
Fixed HybridPPO __init__ arguments
AlexPasqua Sep 21, 2025
a368613
Started train method (tbc)
AlexPasqua Sep 21, 2025
37e1dcd
Created HybridActionsRolloutBuffer
AlexPasqua Oct 4, 2025
6e34356
Completed train method
AlexPasqua Oct 4, 2025
6994f2d
Update sb3_contrib/__init__.py
AlexPasqua Oct 5, 2025
0cd46f4
Created env 'catching point' (to be tested)
AlexPasqua Oct 26, 2025
fafae08
spaces.Tuple is not a generic class
AlexPasqua Oct 26, 2025
365bdb9
Created HybridToBoxWrapper to handle hybrid actions but still integra…
AlexPasqua Oct 26, 2025
392e60f
done --> terminated & truncated
AlexPasqua Oct 26, 2025
45120b4
Revert "done --> terminated & truncated"
AlexPasqua Nov 5, 2025
b419142
Revert "Created HybridToBoxWrapper to handle hybrid actions but still…
AlexPasqua Nov 5, 2025
d8a1f77
Merge branch 'master' into hybrid_PPO
araffin Nov 6, 2025
3647697
HybridPPO inherits from OnPolicyAlgorithm
AlexPasqua Nov 9, 2025
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -4,6 +4,7 @@
from sb3_contrib.crossq import CrossQ
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.ppo_hybrid import HybridPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO
@@ -21,4 +22,5 @@
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
"HybridPPO",
]
5 changes: 4 additions & 1 deletion sb3_contrib/common/envs/__init__.py
@@ -3,5 +3,8 @@
InvalidActionEnvMultiBinary,
InvalidActionEnvMultiDiscrete,
)
from sb3_contrib.common.envs.hybrid_actions_env import CatchingPointEnv

__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete"]
__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete", "CatchingPointEnv"]
104 changes: 104 additions & 0 deletions sb3_contrib/common/envs/hybrid_actions_env.py
@@ -0,0 +1,104 @@
import gymnasium as gym
from gymnasium import spaces
import numpy as np


class CatchingPointEnv(gym.Env):
"""
    Environment for Hybrid PPO, implementing the 'Catching Point' task from the paper
    'Hybrid Actor-Critic Reinforcement Learning in Parameterized Action Space', Fan et al.
(https://arxiv.org/pdf/1903.01344)
"""

def __init__(
self,
arena_size: float = 1.0,
move_dist: float = 0.05,
catch_radius: float = 0.05,
        max_catches: int = 10,
        max_steps: int = 200,
):
super().__init__()
self.max_steps = max_steps
self.max_catches = max_catches
self.arena_size = arena_size
self.move_dist = move_dist
self.catch_radius = catch_radius

# action space
self.action_space = spaces.Tuple(
spaces=(
spaces.MultiDiscrete([2]), # MOVE=0, CATCH=1
spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32) # direction
)
)

# observation: [agent_x, agent_y, target_x, target_y, catches_left, step_norm]
obs_low = np.array([-arena_size, -arena_size, -arena_size, -arena_size, 0.0, 0.0], dtype=np.float32)
        obs_high = np.array([arena_size, arena_size, arena_size, arena_size, float(max_catches), 1.0], dtype=np.float32)
self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32)

def reset(self) -> np.ndarray:
"""
Reset the environment to an initial state and return the initial observation.
"""
self.agent_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32)
self.target_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32)
self.catches_used = 0
self.step_count = 0
return self._get_obs()

def step(self, action: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, float, bool, dict]:
"""
Take a step in the environment using the provided action.

:param action: A tuple containing the discrete action and continuous parameters.
:return: observation, reward, done, info
"""
        action_d = int(action[0][0])
        dir_vec = action[1]
        done = False

        # step penalty
        reward = -0.01

        # MOVE
        if action_d == 0:
            norm = np.linalg.norm(dir_vec)
            # Avoid division by zero for a degenerate direction vector
            if norm > 0:
                dir_u = dir_vec / norm
                self.agent_pos = (self.agent_pos + dir_u * self.move_dist).astype(np.float32)
            # clamp to arena
            self.agent_pos = np.clip(self.agent_pos, -self.arena_size, self.arena_size)

# CATCH
else:
self.catches_used += 1
dist = np.linalg.norm(self.agent_pos - self.target_pos)
if dist <= self.catch_radius:
reward = 1.0 # caught the target
done = True
            elif self.catches_used >= self.max_catches:
                done = True

self.step_count += 1
if self.step_count >= self.max_steps:
done = True

obs = self._get_obs()
info = {"caught": (reward > 0)}
return obs, float(reward), bool(done), info

def _get_obs(self) -> np.ndarray:
"""
Get the current observation.
"""
step_norm = self.step_count / self.max_steps
catches_left = self.max_catches - self.catches_used
obs = np.concatenate((
self.agent_pos,
self.target_pos,
np.array([catches_left, step_norm], dtype=np.float32)
))
return obs
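
A minimal sketch of how this environment is driven, useful as a sanity check: actions are (discrete flag, direction) tuples sampled straight from the Tuple space, and the step API is the four-value one kept by the revert above. The rollout below is purely illustrative.

from sb3_contrib.common.envs import CatchingPointEnv

env = CatchingPointEnv()
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    # Hybrid action tuple: (MultiDiscrete MOVE/CATCH flag, Box direction vector)
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    total_reward += reward
print(f"return: {total_reward:.2f}, caught: {info['caught']}")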
Empty file.
197 changes: 197 additions & 0 deletions sb3_contrib/common/hybrid/distributions.py
@@ -0,0 +1,197 @@
from typing import Optional, TypeVar

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.distributions import Distribution
from torch import nn
from torch.distributions import Categorical, Normal


SelfHybridDistribution = TypeVar("SelfHybridDistribution", bound="HybridDistribution")


class HybridDistributionNet(nn.Module):
    """
    Network that outputs the parameters of a hybrid distribution:
    one logits head per discrete (categorical) sub-action and one head
    for the means of the Gaussian over the continuous actions.

    :param latent_dim: Dimension of the latent features
    :param categorical_dimensions: Number of categories of each discrete sub-action
    :param n_continuous: Number of continuous actions
    """

    def __init__(self, latent_dim: int, categorical_dimensions: np.ndarray, n_continuous: int):
        super().__init__()
        # One linear head per discrete sub-action
        self.categorical_nets = nn.ModuleList([nn.Linear(latent_dim, out_dim) for out_dim in categorical_dimensions])
        # Single linear head for the means of the continuous actions
        self.gaussian_net = nn.Linear(latent_dim, n_continuous)

def forward(self, latent: th.Tensor) -> tuple[list[th.Tensor], th.Tensor]:
"""
Forward pass through all categorical nets and the gaussian net.

:param latent: Latent tensor input
:return: Tuple (list of categorical outputs, gaussian output)
"""
categorical_outputs = [net(latent) for net in self.categorical_nets]
gaussian_output = self.gaussian_net(latent)
return categorical_outputs, gaussian_output
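
To make the shapes concrete, here is a minimal sketch (batch size 8 and latent dimension 64 are arbitrary choices for illustration): with two discrete sub-actions of 3 and 2 categories plus 4 continuous parameters, the net returns one logits tensor per categorical head and one tensor of Gaussian means.

import numpy as np
import torch as th

from sb3_contrib.common.hybrid.distributions import HybridDistributionNet

net = HybridDistributionNet(latent_dim=64, categorical_dimensions=np.array([3, 2]), n_continuous=4)
latent = th.randn(8, 64)
categorical_logits, gaussian_means = net(latent)
# categorical_logits[0].shape == (8, 3); categorical_logits[1].shape == (8, 2)
# gaussian_means.shape == (8, 4)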


class Hybrid(th.distributions.Distribution):
"""
A hybrid distribution that combines multiple categorical distributions for discrete actions
and a Gaussian distribution for continuous actions.
"""

    def __init__(
        self,
        logits: tuple[list[th.Tensor], th.Tensor],
        validate_args: Optional[bool] = None,
    ):
        super().__init__(validate_args=validate_args)
        categorical_logits: list[th.Tensor] = logits[0]
        gaussian_means: th.Tensor = logits[1]
        self.categorical_dists = [Categorical(logits=logit) for logit in categorical_logits]
        # Fixed unit standard deviation for the continuous actions for now
        self.gaussian_dist = Normal(loc=gaussian_means, scale=th.ones_like(gaussian_means))

def sample(self) -> tuple[th.Tensor, th.Tensor]:
categorical_samples = [dist.sample() for dist in self.categorical_dists]
gaussian_samples = self.gaussian_dist.sample()
return th.stack(categorical_samples, dim=-1), gaussian_samples

    def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        Returns the log probability of the given actions, both discrete and continuous.

        :param discrete_actions: Discrete actions, one column per categorical sub-action
        :param continuous_actions: Continuous actions
        :return: Tuple (summed categorical log prob, summed Gaussian log prob)
        """
        categorical_log_probs = [
            dist.log_prob(discrete_actions[..., i]) for i, dist in enumerate(self.categorical_dists)
        ]
        gaussian_log_prob = self.gaussian_dist.log_prob(continuous_actions)
        return (
            th.sum(th.stack(categorical_log_probs, dim=-1), dim=-1),
            th.sum(gaussian_log_prob, dim=-1),
        )

def entropy(self) -> tuple[th.Tensor, th.Tensor]:
"""
        Returns the entropy of the hybrid distribution, keeping the categorical and
        Gaussian components separate (each summed over its own sub-actions/dimensions).

:return: Tuple of (categorical entropy, gaussian entropy)
"""
categorical_entropies = [dist.entropy() for dist in self.categorical_dists]
# Sum entropies for all categorical distributions
categorical_entropy = th.sum(th.stack(categorical_entropies, dim=-1), dim=-1)
gaussian_entropy = self.gaussian_dist.entropy().sum(dim=-1)
return categorical_entropy, gaussian_entropy
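
Continuing the sketch above, the joint distribution is built from those logits; since the discrete and continuous parts are independent, their log probabilities add to give the log probability of the full hybrid action.

dist = Hybrid(logits=(categorical_logits, gaussian_means))
discrete_actions, continuous_actions = dist.sample()  # shapes (8, 2) and (8, 4)
cat_log_prob, gauss_log_prob = dist.log_prob(discrete_actions, continuous_actions)
joint_log_prob = cat_log_prob + gauss_log_prob  # shape (8,)
cat_entropy, gauss_entropy = dist.entropy()  # both shape (8,)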


class HybridDistribution(Distribution):
def __init__(self, categorical_dimensions: np.ndarray, n_continuous: int):
"""
Initialize the hybrid distribution with categorical and continuous components.

        :param categorical_dimensions: Array with the number of categories of each discrete sub-action.
:param n_continuous: The number of continuous actions.
"""
super().__init__()
self.categorical_dimensions = categorical_dimensions
self.n_continuous = n_continuous

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Create the network that outputs the parameters of the distribution:
        one logits head per categorical sub-action and one head for the Gaussian means.

        :param latent_dim: Dimension of the last layer of the policy network (before the action layer)
        :return: The action network
        """
        return HybridDistributionNet(latent_dim, self.categorical_dimensions, self.n_continuous)

    def proba_distribution(
        self: SelfHybridDistribution, action_logits: tuple[list[th.Tensor], th.Tensor]
    ) -> SelfHybridDistribution:
        """
        Set the parameters of the distribution.

        :param action_logits: Tuple (categorical logits, Gaussian means) output by the action net
        :return: self
        """
        self.distribution = Hybrid(logits=action_logits)
        return self

    def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """
        Returns the log likelihood of the taken actions.

        :param discrete_actions: the taken discrete actions
        :param continuous_actions: the taken continuous actions
        :return: The log likelihood of the discrete and continuous distributions
        """
        assert self.distribution is not None, "Must set distribution parameters"
        return self.distribution.log_prob(discrete_actions, continuous_actions)

def entropy(self) -> tuple[th.Tensor, th.Tensor]:
"""
Returns Shannon's entropy of the probability

:return: the entropy of discrete and continuous distributions
"""
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.entropy()

def sample(self) -> tuple[th.Tensor, th.Tensor]:
"""
Returns a sample from the probability distribution

:return: the stochastic action
"""
assert self.distribution is not None, "Must set distribution parameters"
return self.distribution.sample()

    def mode(self) -> tuple[th.Tensor, th.Tensor]:
        """
        Returns the most likely action (deterministic output)
        from the probability distribution

        :return: the deterministic action
        """
        assert self.distribution is not None, "Must set distribution parameters"
        # Mode of each categorical is the argmax of its probs; mode of the Gaussian is its mean
        categorical_modes = [th.argmax(dist.probs, dim=-1) for dist in self.distribution.categorical_dists]
        return th.stack(categorical_modes, dim=-1), self.distribution.gaussian_dist.mean

def get_actions(self, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor]:
"""
Return actions according to the probability distribution.

        :param deterministic: Whether to return the mode instead of a sample
        :return: Tuple of (discrete actions, continuous actions)
"""
if deterministic:
return self.mode()
return self.sample()

    def actions_from_params(
        self, action_logits: tuple[list[th.Tensor], th.Tensor], deterministic: bool = False
    ) -> tuple[th.Tensor, th.Tensor]:
        """
        Returns samples from the probability distribution
        given its parameters.

        :param action_logits: Tuple (categorical logits, Gaussian means) output by the action net
        :param deterministic: Whether to return the mode instead of a sample
        :return: actions
        """
        # Update the distribution parameters, then sample (or take the mode)
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
        self, action_logits: tuple[list[th.Tensor], th.Tensor]
    ) -> tuple[tuple[th.Tensor, th.Tensor], tuple[th.Tensor, th.Tensor]]:
        """
        Returns samples and the associated log probabilities
        from the probability distribution given its parameters.

        :param action_logits: Tuple (categorical logits, Gaussian means) output by the action net
        :return: actions and log prob
        """
        discrete_actions, continuous_actions = self.actions_from_params(action_logits)
        log_prob = self.log_prob(discrete_actions, continuous_actions)
        return (discrete_actions, continuous_actions), log_prob


def make_hybrid_proba_distribution(action_space: spaces.Tuple) -> HybridDistribution:
"""
Create a hybrid probability distribution for the given action space.

:param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space.
:return: A HybridDistribution object that handles the hybrid action space.
"""
assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space"
assert len(action_space.spaces) == 2, "Action space must contain exactly 2 subspaces"
assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete"
assert isinstance(action_space.spaces[1], spaces.Box), "Second subspace must be Box"
    assert len(action_space[1].shape) == 1, "Continuous action space must be one-dimensional (e.g., shape (n,))"
    return HybridDistribution(
        categorical_dimensions=action_space[0].nvec,
        n_continuous=action_space[1].shape[0],
    )
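
Putting it together, a minimal end-to-end sketch using an action space shaped like CatchingPointEnv's (latent dimension and batch size are arbitrary illustration choices): the helper validates the Tuple space, the distribution builds its action net, and actions plus log probs come out in the hybrid format.

import numpy as np
import torch as th
from gymnasium import spaces

from sb3_contrib.common.hybrid.distributions import make_hybrid_proba_distribution

action_space = spaces.Tuple((
    spaces.MultiDiscrete([2]),
    spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
))
dist = make_hybrid_proba_distribution(action_space)
action_net = dist.proba_distribution_net(latent_dim=64)
latent = th.randn(16, 64)
dist.proba_distribution(action_net(latent))
discrete_actions, continuous_actions = dist.get_actions()
cat_log_prob, gauss_log_prob = dist.log_prob(discrete_actions, continuous_actions)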
