PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
## v2.3.0

- SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
- Stable-Baselines Jax (SBX): https://github.com/araffin/sbx
To upgrade:

```
pip install stable_baselines3 sb3_contrib --upgrade
```

or simply (RL Zoo depends on SB3 and SB3 Contrib):

```
pip install rl_zoo3 --upgrade
```
The default hyperparameters of TD3 and DDPG have been changed to be more consistent with SAC:

```python
# SB3 < 2.3.0 default hyperparameters:
# model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100)
# SB3 >= 2.3.0:
model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256)
```
> [!NOTE]
> Two inconsistencies remain: the default network architecture for TD3/DDPG is `[400, 300]` instead of `[256, 256]` as for SAC (for backward-compatibility reasons; see the report on the influence of the network size), and the default learning rate is `1e-3` instead of `3e-4` as for SAC (for performance reasons; see the W&B report on the influence of the learning rate).
The default `learning_starts` parameter of DQN has been changed to be consistent with the other off-policy algorithms:

```python
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults:
# model = DQN("MlpPolicy", env, learning_starts=50_000)
# SB3 >= 2.3.0:
model = DQN("MlpPolicy", env, learning_starts=100)
```
- For safety, `torch.load()` is now called with `weights_only=True` when loading torch tensors; policy `load()` still uses `weights_only=False`, as gymnasium imports are required for it to work.
- When using `huggingface_sb3`, you will now need to set `TRUST_REMOTE_CODE=True` when downloading models from the hub, as `pickle.load` is not safe (see the sketch below).
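A minimal sketch of the opt-in flow, assuming `huggingface_sb3` reads the flag from the environment; the repo id and filename are placeholders:

```python
import os

# Opt in explicitly: pickle.load can execute arbitrary code.
os.environ["TRUST_REMOTE_CODE"] = "True"

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Placeholder repo id and filename, for illustration only.
checkpoint = load_from_hub(repo_id="sb3/ppo-CartPole-v1", filename="ppo-CartPole-v1.zip")
model = PPO.load(checkpoint)
```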
- Log `rollout/success_rate` when available for on-policy algorithms (@corentinlger)
- Fixed `monitor_wrapper` argument that was not passed to the parent class, and `dones` argument that wasn't passed to `_update_into_buffer` (@corentinlger)
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to MaskablePPO (SB3-Contrib)
- Fixed `train_freq` type annotation for TQC and QRDQN (@Armandpl)
- Fixed `sb3_contrib/common/maskable/*.py` type annotations
- Fixed `sb3_contrib/ppo_mask/ppo_mask.py` type annotations
- Fixed `sb3_contrib/common/vec_env/async_eval.py` type annotations
- Added notes about MaskablePPO (evaluation and multi-process) (@icheered)
- Added test dependencies to `setup.py` (@power-edge)
- Simplified `requirements.txt` (removed duplicates from `setup.py`)
- Added support for `MultiDiscrete` and `MultiBinary` action spaces to PPO (SBX)
- Fixed the `train()` signature and updated type hints
- Added CrossQ (SBX)
- Added `render_mode="human"` in the README example (@marekm4)
- Updated the `log_interval` description in the base class (@rushitnshah)

Full Changelog: https://github.com/DLR-RM/stable-baselines3/compare/v2.2.1...v2.3.0
## v2.2.1

- SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
- Stable-Baselines Jax (SBX): https://github.com/araffin/sbx

To upgrade:

```
pip install stable_baselines3 sb3_contrib --upgrade
```

or simply (RL Zoo depends on SB3 and SB3 Contrib):

```
pip install rl_zoo3 --upgrade
```
> [!NOTE]
> Stable-Baselines3 (SB3) v2.2.0 was yanked after a breaking change was found in GH#1751. Please use SB3 v2.2.1 and not v2.2.0.
- Switched to `ruff` for sorting imports (isort is no longer needed); black and ruff now require a minimum version
- Dropped `x is False` in favor of `not x`, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
- Improved the `env_checker` error message for an env wrongly detected as a GoalEnv (`compute_reward()` is defined)
- Added support for setting `options` at reset with VecEnv via the `set_options()` method. As with the seeding logic, options are reset at the end of an episode (@ReHoss). See the sketch below.
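A minimal sketch of the new reset-options API; the option key is hypothetical (CartPole ignores reset options):

```python
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
# Options are passed to the wrapped envs at the next reset and,
# like seeds, are cleared again at the end of an episode.
vec_env.set_options({"difficulty": 1})  # hypothetical option key
obs = vec_env.reset()
```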
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to on-policy algorithms (A2C and PPO)
- Fixed `_setup_learn()` in OffPolicyAlgorithm (@PatrickHelm)
- Calls `callback.update_locals()` before `callback.on_rollout_end()` in OnPolicyAlgorithm (@PatrickHelm)
- Fixed `render_mode`, which was not properly loaded when using `VecNormalize.load()`
- Fixed success reward dtype in `SimpleMultiObsEnv` (@NixGD)
- Added `set_options` for `AsyncEval` (SB3-Contrib)
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to TRPO (SB3-Contrib)
- Removed the `gym` dependency; the package is still required for some pretrained agents.
- Added `--eval-env-kwargs` to `train.py` (@Quentin18)
- Added `ppo_lstm` to `hyperparams_opt.py` (@technocrat13)
- Upgraded to `pybullet_envs_gymnasium>=0.4.0`
- Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`
- Added DDPG and TD3 algorithms (SBX)
- Fixed `stable_baselines3/common/callbacks.py` type hints
- Fixed `stable_baselines3/common/utils.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_transpose.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_video_recorder.py` type hints
- Fixed `stable_baselines3/common/save_util.py` type hints
- Fixed `stable_baselines3/common/buffers.py` type hints
- Fixed `stable_baselines3/her/her_replay_buffer.py` type hints
- Added missing `.copy()` when storing new transitions
- Updated the `ActorCriticPolicy.extract_features()` signature by adding an optional `features_extractor` argument
- Updated the docs (`sphinx_autodoc_typehints`)
- Fixed `stable_baselines3/common/off_policy_algorithm.py` type hints
- Fixed `stable_baselines3/common/distributions.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_normalize.py` type hints
- Fixed `stable_baselines3/common/vec_env/__init__.py` type hints
- Fixed `stable_baselines3/common/policies.py` type hints
- Switched to `mypy` only for checking types

Full changelog: https://github.com/DLR-RM/stable-baselines3/compare/v2.1.0...v2.2.1
## v2.1.0

- SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
- Stable-Baselines Jax (SBX): https://github.com/araffin/sbx

To upgrade:

```
pip install stable_baselines3 sb3_contrib --upgrade
```

or simply (RL Zoo depends on SB3 and SB3 Contrib):

```
pip install rl_zoo3 --upgrade
```
- Fixed MaskablePPO ignoring the `stats_window_size` argument (SB3-Contrib)
- Fixed `env_checker.py` warning messages for out-of-bounds values in complex observation spaces (@Gabo-Tor)
- Updated `test_spaces.py` tests

Full Changelog: https://github.com/DLR-RM/stable-baselines3/compare/v2.0.0...v2.1.0
> [!WARNING]
> Stable-Baselines3 (SB3) v2.0 will be the last version supporting Python 3.7 (end of life in June 2023). We highly recommend you upgrade to Python >= 3.8.
## v2.0.0

- SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
- Stable-Baselines Jax (SBX): https://github.com/araffin/sbx

To upgrade:

```
pip install stable_baselines3 sb3_contrib rl_zoo3 --upgrade
```

or simply (RL Zoo depends on SB3 and SB3 Contrib):

```
pip install rl_zoo3 --upgrade
```
- Switched to Gymnasium as the primary backend; Gym 0.21 and 0.26 are still supported via the `shimmy` package (@carlosluis, @arjun-kg, @tlpss)
- The `online_sampling` argument of `HerReplayBuffer` was removed
- Removed the `stack_observation_space` method of `StackedObservations`
- Renamed environment output observations in `evaluate_policy` to prevent shadowing the input observations during callbacks (@npit)
- Updated the `HumanOutputFormat` file check: it now verifies whether the object is an instance of `io.TextIOBase` instead of only checking for the presence of a `write` method.
- `vec_env.seed(seed=seed)` will only be effective after the `env.reset()` call (see the sketch below).
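A minimal sketch of the new seeding behavior (seeds are stored and only take effect at the next reset):

```python
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
vec_env.seed(seed=42)  # stored, not applied yet
obs = vec_env.reset()  # the seed takes effect here
```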
- Gym envs are still supported via the `shimmy` package
- Renamed `CarRacing-v1` to `CarRacing-v2` in hyperparameters
- Added `--n-timesteps` argument to adjust the length of the video
- Fixed `record_video` steps (before, it was stepping in a closed env)
- Fixed `VecExtractDictObs` not handling terminal observations (@WeberSamuel)
- Requires NumPy `>=1.20` due to the use of `numpy.typing` (@troiganto)
- Fixed `target_update_interval` handling in DQN (@tobirohrer)
- Fixed checks on the output of `step()` when checking for `Inf` and `NaN` (@lutogniew)
- Fixed HER `truncate_last_trajectory()` (@lbergmann1)
- Fixed `stable_baselines3/a2c/*.py` type hints
- Fixed `stable_baselines3/ppo/*.py` type hints
- Fixed `stable_baselines3/sac/*.py` type hints
- Fixed `stable_baselines3/td3/*.py` type hints
- Fixed `stable_baselines3/common/base_class.py` type hints
- Fixed `stable_baselines3/common/logger.py` type hints
- Fixed `stable_baselines3/common/envs/*.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_monitor|vec_extract_dict_obs|util.py` type hints
- Fixed `stable_baselines3/common/vec_env/base_vec_env.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_frame_stack.py` type hints
- Fixed `stable_baselines3/common/vec_env/dummy_vec_env.py` type hints
- Fixed `stable_baselines3/common/vec_env/subproc_vec_env.py` type hints
- Changed the `VecEnv` and `VecEnvWrapper` `seed()` method return type from `List` to `Sequence`
- Documented the `VecEnv` API vs the Gym API
- Clarified `VecEnv` vs Gym env behavior
- Fixed the `EvalCallback` example (@sidney-tio)
- Added `pink-noise-rl` to the projects page
- Fixed a bug where `ortho_init` was ignored

Full Changelog: https://github.com/DLR-RM/stable-baselines3/compare/v1.8.0...v2.0.0
> [!WARNING]
> Stable-Baselines3 (SB3) v1.8.0 will be the last version to use Gym as a backend. Starting with v2.0.0, Gymnasium will be the default backend (though SB3 will have compatibility layers for Gym envs). You can find a migration guide here. If you want to try the SB3 v2.0 alpha version, you can take a look at PR #1327.
## v1.8.0

- SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo

To upgrade:

```
pip install stable_baselines3 sb3_contrib rl_zoo3 --upgrade
```

or simply (RL Zoo depends on SB3 and SB3 Contrib):

```
pip install rl_zoo3 --upgrade
```
- Removed shared layers in `mlp_extractor` (@AlexPasqua)
- Refactored `StackedObservations` (it now handles dict obs; `StackedDictObservations` was removed)
- You must now explicitly pass a `features_extractor` parameter when calling `extract_features()`
- Dropped offline sampling for `HerReplayBuffer`
- `HerReplayBuffer` was refactored to support multiprocessing; previous replay buffers are incompatible with this new version
- `HerReplayBuffer` doesn't require a `max_episode_length` anymore
- Added `repeat_action_probability` argument in `AtariWrapper`
- Only use `NoopResetEnv` and `MaxAndSkipEnv` when needed in `AtariWrapper`
- Updated `VecCheckNan`: the check is now active in the `env_checker()` (@DavyMorgan)
- Added multiprocessing support for `HerReplayBuffer`
- `HerReplayBuffer` now supports all datatypes supported by `ReplayBuffer`
- Provide more helpful failure messages when validating the `observation_space` of custom gym environments using `check_env` (@FieteO)
- Added `stats_window_size` argument to control smoothing in rollout logging (@jonasreiher); see the sketch below
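A minimal sketch of the new smoothing knob; episode statistics such as `rollout/ep_rew_mean` are averaged over this many episodes:

```python
from stable_baselines3 import PPO

# Average logged episode stats over the last 20 episodes
# instead of the default 100:
model = PPO("MlpPolicy", "CartPole-v1", stats_window_size=20)
model.learn(total_timesteps=10_000)
```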
- Mentioned `check_env` in the `MaskablePPO` docs (@AlexPasqua)
- Fixed `sb3_contrib/qrdqn/*.py` type hints
- Removed shared layers in `mlp_extractor` (SB3-Contrib) (@AlexPasqua)
- Added `dtype` (default to `float32`) to the noise for consistency with gym actions (@sidney-tio)
- Fixed `DictRolloutBuffer.add` with multidimensional action space (@younik)
- Fixed `tests/test_tensorboard.py` type hint
- Fixed `tests/test_vec_normalize.py` type hint
- Fixed `stable_baselines3/common/monitor.py` type hints
- Switched from the `setup.cfg` to the `pyproject.toml` configuration file
- Switched from `flake8` to `ruff`
- Fixed `stable_baselines3/dqn/*.py` type hints
- Added `extra_no_roms` option for package installation without Atari ROMs
- Renamed `load_parameters` to `set_parameters` in the docs (@DavyMorgan)
- Fixed the `A2C` docstring (@AlexPasqua)
- Clarified the `log_interval` description (@theSquaredError)

## v1.7.0

- SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
To upgrade:
pip install stable_baselines3 sb3_contrib rl_zoo3 --upgrade
or simply (rl zoo depends on SB3 and SB3 contrib):
pip install rl_zoo3 --upgrade
> [!WARNING]
> Shared layers in MLP policies (`mlp_extractor`) are now deprecated for PPO, A2C and TRPO. This feature will be removed in SB3 v1.8.0, and the behavior of `net_arch=[64, 64]` will then create separate networks with the same architecture, to be consistent with the off-policy algorithms (see the sketch below).
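A minimal sketch of the non-shared layout that `net_arch=[64, 64]` maps to from v1.8.0 on; the dict form makes the separate policy and value networks explicit:

```python
from stable_baselines3 import PPO

model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    # Separate 64-64 networks for the policy (pi) and value function (vf):
    policy_kwargs=dict(net_arch=dict(pi=[64, 64], vf=[64, 64])),
)
```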
> [!NOTE]
> A2C and PPO models saved with SB3 < 1.7.0 will show a warning about missing keys in the state dict when loaded with SB3 >= 1.7.0. To suppress the warning, simply save the model again. You can find more info in issue #1233.
- Removed the deprecated `create_eval_env`, `eval_env`, `eval_log_path`, `n_eval_episodes` and `eval_freq` parameters; please use an `EvalCallback` instead (see the sketch after this list)
- Removed the deprecated `sde_net_arch` parameter
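A minimal sketch of the replacement, written with SB3 v2-style `gymnasium` imports (use `gym` on v1.x):

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback

# Periodic evaluation now lives in a callback instead of learn() parameters:
eval_callback = EvalCallback(gym.make("CartPole-v1"), eval_freq=5_000, n_eval_episodes=5)
model = PPO("MlpPolicy", "CartPole-v1")
model.learn(total_timesteps=50_000, callback=eval_callback)
```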
- Removed the `ret` attribute in `VecNormalize`; please use `returns` instead
- `VecNormalize` now updates the observation space when normalizing images
- Added `with_bias` argument to `create_mlp`
- Added support for multidimensional `spaces.MultiBinary` observations
- Features extractors now properly support unnormalized image-like observations (3D tensor) when passing `normalize_images=False`
- Added `normalized_image` parameter to `NatureCNN` and `CombinedExtractor` (see the sketch below)
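A minimal sketch of the new parameter; the env id is a placeholder for any image-based environment whose observations are already normalized (i.e., not uint8 in [0, 255]):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import NatureCNN

policy_kwargs = dict(
    features_extractor_class=NatureCNN,
    # Skip the uint8/[0, 255] sanity checks for already-normalized images:
    features_extractor_kwargs=dict(normalized_image=True),
)
model = PPO("CnnPolicy", "CarRacing-v2", policy_kwargs=policy_kwargs)
```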
- Fixed a bug in `RecurrentPPO` where the LSTM states were incorrectly reshaped for `n_lstm_layers > 1` (thanks @kolbytn)
- Fixed `RuntimeError: rnn: hx is not contiguous` while predicting terminal values for `RecurrentPPO` when `n_lstm_layers > 1`
- Added `monitor_kwargs` parameter
- Fixed `ProgressBarCallback` under-reporting (@dominicgkerr)
- Updated `evaluate_actions` in `ActorCriticPolicy` to reflect that entropy is an optional tensor (@Rocamonde)
- Fixed the type annotation of `policy` in `BaseAlgorithm` and `OffPolicyAlgorithm`
- Removed the `custom_objects` workaround
- Fixed the type annotation of `model` in `evaluate_policy`
- Added `Self` return type using `TypeVar`
- Fixed `normalize_images`, which was not passed to the parent class in some cases
- Fixed `load_from_vector`, which was broken with newer PyTorch versions when passing a PyTorch tensor
- You should now explicitly pass a `features_extractor` parameter when calling `extract_features()`
- Deprecated shared layers in `MlpExtractor` (@AlexPasqua)
- GoalEnvs are now detected based on the presence of the `compute_reward` method, rather than by their inheritance from `gym.GoalEnv`
- Replaced `CartPole-v0` by `CartPole-v1` in tests
- Fixed `tests/test_distributions.py` type hints
- Fixed `stable_baselines3/common/type_aliases.py` type hints
- Fixed `stable_baselines3/common/torch_layers.py` type hints
- Fixed `stable_baselines3/common/env_util.py` type hints
- Fixed `stable_baselines3/common/preprocessing.py` type hints
- Fixed `stable_baselines3/common/atari_wrappers.py` type hints
- Fixed `stable_baselines3/common/vec_env/vec_check_nan.py` type hints
- Updated `__init__.py` with the `__all__` attribute (@ZikangXiong)
- Set `np.bool = bool` so gym 0.21 is compatible with NumPy 1.24+
- Switched to `from gym import spaces`
- Updated the `get_system_info` output to avoid issues linked to copy-pasting in GitHub issues
- Renamed `env` to `vec_env` when the environment is vectorized
- Clarified `mlp_extractor`'s dimensions (@AlexPasqua)

## v1.6.2

- SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- RL Zoo3: https://github.com/DLR-RM/rl-baselines3-zoo

- RL Zoo is now a python package (`pip install rl_zoo3`)
- Added `progress_bar` argument in the `learn()` method, displayed using the TQDM and rich packages (see the sketch below)
)self.num_timesteps
was initialized properly only after the first call to on_step()
for callbacks~=4.13
to be compatible with gym=0.21
eval_env
, eval_freq
or create_eval_env
are used (see #925) (@tobirohrer)env_id
parameter in make_vec_env
and make_atari_env
(@AlexPasqua)wrapper_class
parameter in make_vec_env
(@AlexPasqua)SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- Added checkpoints for replay buffer and `VecNormalize` statistics (@anand-bala)
- Added option for `Monitor` to append to an existing file instead of overriding it (@sidney-tio); see the sketch below
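A minimal sketch, assuming the option is exposed as the `override_existing` flag of `Monitor`:

```python
import gymnasium as gym
from stable_baselines3.common.monitor import Monitor

# Append to the existing monitor file instead of overwriting it:
env = Monitor(gym.make("CartPole-v1"), filename="monitor", override_existing=False)
```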
- Fixed the issue of wrongly passing policy arguments when using `CnnLstmPolicy` or `MultiInputLstmPolicy` with `RecurrentPPO` (@mlodel)
- Fixed an issue where `PPO` gives NaN if the rollout buffer provides a batch of size 1 (@hughperkins)
- Fixed the issue that `predict` does not always return actions as `np.ndarray` (@qgallouedec)
- Fixed missing verbose parameter passing in the `EvalCallback` constructor (@burakdmb)
- Fixed the issue that, when updating the target network, the `running_mean` and `running_var` properties of batch-norm layers are not updated (@honglu2875)
- Fixed wrong type annotation in the `common.OffPolicyAlgorithm` initializer, where an instance instead of a class was required (@Rocamonde)
- Moved the `forward()` abstract method declaration from `common.policies.BaseModel` (already defined in `torch.nn.Module`) to fix type errors in subclasses (@Rocamonde)
- Fixed the return type of the `.load()` and `.learn()` methods in `BaseAlgorithm` so that they now use `TypeVar` (@Rocamonde)
- Fixed an error raised by duplicated keys in `common.logger.HumanOutputFormat` (@Rocamonde and @AdamGleave)
- Fixed `DictReplayBuffer.next_observations` typing (@qgallouedec)
- Added support for `device="auto"` in buffers and made it the default (@qgallouedec)
- Updated `ResultsWriter` (used internally by the `Monitor` wrapper) to automatically create missing directories when `filename` is a path (@dominicgkerr)

## v1.6.0

- SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- Removed the `register_policy` helper and the `policy_base` parameter, using `policy_aliases` static attributes instead (@Gregwar)
- When using `CnnPolicy` or `MultiInputPolicy` with SAC or DDPG/TD3, `share_features_extractor` is now set to `False` by default and `net_arch=[256, 256]` (instead of `net_arch=[]` as before); see the sketch below
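A minimal sketch making the new defaults explicit; the env id is a placeholder for any image-based continuous-control task:

```python
from stable_baselines3 import SAC

model = SAC(
    "CnnPolicy",
    "CarRacing-v2",
    # New defaults since v1.6.0, written out explicitly:
    policy_kwargs=dict(share_features_extractor=False, net_arch=[256, 256]),
)
```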
- Fixed a bug in `DummyVecEnv`'s and `SubprocVecEnv`'s seeding function; a `None` value was unchecked (@ScheiklP)
- Fixed a bug where `EvalCallback` would crash when trying to synchronize `VecNormalize` stats when observation normalization was disabled
- Fixed a bug in the `kl_divergence` check that would fail when using numpy arrays with a MultiCategorical distribution
- Upgraded the code syntax using `pyupgrade`
- Refactored `BaseAlgorithm._wrap_env` (@TibiGG)

## v1.5.0

- SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- Added `StopTrainingOnNoModelImprovement` to the callback collection (@caburu); see the sketch below
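A minimal sketch combining it with `EvalCallback` via `callback_after_eval`:

```python
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    EvalCallback,
    StopTrainingOnNoModelImprovement,
)

# Stop if the best eval reward has not improved for 3 consecutive evaluations
# (but run at least 5 evaluations first):
stop_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5)
eval_callback = EvalCallback(gym.make("Pendulum-v1"), eval_freq=1_000,
                             callback_after_eval=stop_callback)
model = PPO("MlpPolicy", "Pendulum-v1")
model.learn(total_timesteps=100_000, callback=eval_callback)
```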
- Made `HumanOutputFormat` configurable, depending on the desired maximum width of the output
- Fixed a bug in `VecMonitor`: the monitor did not consider the `info_keywords` during stepping (@ScheiklP)
- Fixed a bug in `HumanOutputFormat`: distinct keys truncated to the same prefix would overwrite each other's value, resulting in only one being output. This now raises an error (it should only affect a small fraction of use cases with very long keys).
- Routed all `nn.Module` calls through implicit rather than explicit `forward`, as per PyTorch guidelines (@manuel-delverme)
- Fixed a bug in `VecNormalize` where an error occurred when `norm_obs` was set to `False` for environments with dictionary observations (@buoyancy99)
- Set the default `env` argument to `None` in `HerReplayBuffer.sample` (@qgallouedec)
- Fixed `batch_size` typing in `DQN` (@qgallouedec)
- Fixed sample normalization in `DictReplayBuffer` (@qgallouedec)
- Removed `remove_time_limit_termination` in off-policy algorithms since it was dead code (@Gregwar)
- Added a section about "Directly Accessing The Summary Writer" in the tensorboard integration docs (@xy9485)

Full Changelog: https://github.com/DLR-RM/stable-baselines3/compare/v1.4.0...v1.5.0