Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
Breaking Changes:
- The `learning_starts` parameter of `QRDQN` has been changed to be consistent with the other off-policy algorithms:

```python
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters
# model = QRDQN("MlpPolicy", env, learning_starts=50_000)
# SB3 >= 2.3.0:
model = QRDQN("MlpPolicy", env, learning_starts=100)
```
New Features:
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to `MaskablePPO` (see the sketch below)
- Log `rollout/success_rate` when available for on-policy algorithms

Others:
- Fixed `train_freq` type annotation for TQC and QRDQN (@Armandpl)
- Fixed `sb3_contrib/common/maskable/*.py` type annotations
- Fixed `sb3_contrib/ppo_mask/ppo_mask.py` type annotations
- Fixed `sb3_contrib/common/vec_env/async_eval.py` type annotations

Documentation:
- Added some additional notes about `MaskablePPO` (evaluation and multi-process) (@icheered)

Full Changelog: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/compare/v2.2.1...v2.3.0
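As a rough sketch of the new buffer arguments (`CustomMaskableBuffer` is a hypothetical subclass written for illustration; `MaskableRolloutBuffer` and the `InvalidActionEnvDiscrete` toy env come from sb3_contrib):

```python
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer


class CustomMaskableBuffer(MaskableRolloutBuffer):
    """Hypothetical subclass; override storage or sampling here."""


env = InvalidActionEnvDiscrete(dim=80, n_invalid_actions=60)
model = MaskablePPO(
    "MlpPolicy",
    env,
    rollout_buffer_class=CustomMaskableBuffer,
    rollout_buffer_kwargs={},  # extra kwargs forwarded to the buffer constructor
)
model.learn(1_000)
```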
SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
Stable-Baselines Jax (SBX): https://github.com/araffin/sbx
Breaking Changes:
- Switched to `ruff` for sorting imports (isort is no longer needed); black and ruff now require a minimum version
- Dropped `x is False` in favor of `not x`, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle) (see the sketch below)

New Features:
- Added `set_options` for `AsyncEval`
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to `TRPO`
Others:
- Updated `ActorCriticPolicy.extract_features()` signature by adding an optional `features_extractor` argument
- Removed dependency on `sphinx_autodoc_typehints`
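A minimal sketch of what the callback change implies (`MyCallback` is a made-up name): `_on_step()` must now return an actual boolean, since a `None` return is no longer silently treated as "continue training":

```python
from stable_baselines3.common.callbacks import BaseCallback


class MyCallback(BaseCallback):
    def _on_step(self) -> bool:
        # ... custom logic ...
        # Returning None here (e.g. forgetting the return statement)
        # is now falsy and will stop training; return True to continue.
        return True
```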
Bug Fixes:
- Fixed `MaskablePPO` ignoring the `stats_window_size` argument (see the sketch below)

Full Changelog: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/compare/v2.0.0...v2.1.0
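For illustration, a sketch of the fixed argument (using the `InvalidActionEnvDiscrete` toy env; values are placeholders): `stats_window_size` controls how many episodes the `rollout/ep_rew_mean` and `rollout/ep_len_mean` statistics are averaged over (100 by default):

```python
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
# Average episode statistics over the last 10 episodes instead of 100
model = MaskablePPO("MlpPolicy", env, stats_window_size=10, verbose=1)
model.learn(2_000)
```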
Warning: Stable-Baselines3 (SB3) v2.0 will be the last one supporting Python 3.7 (end of life in June 2023). We highly recommend you upgrade to Python >= 3.8.
To upgrade:

```bash
pip install stable_baselines3 sb3_contrib rl_zoo3 --upgrade
```

or simply (rl zoo depends on SB3 and SB3 contrib):

```bash
pip install rl_zoo3 --upgrade
```
Breaking Changes:
- Switched to Gymnasium as primary backend; Gym 0.21 and 0.26 are still supported via the `shimmy` package (@carlosluis, @arjun-kg, @tlpss) (see the sketch below)

Others:
- Fixed `sb3_contrib/tqc/*.py` type hints
- Fixed `sb3_contrib/trpo/*.py` type hints
- Fixed `sb3_contrib/common/envs/invalid_actions_env.py` type hints

Full Changelog: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/compare/v1.8.0...v2.0.0
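A quick sketch of what the Gymnasium switch means in practice (`TQC` on `Pendulum-v1` is an arbitrary example):

```python
import gymnasium as gym

from sb3_contrib import TQC

# Gymnasium envs now work directly; old Gym envs go through shimmy
env = gym.make("Pendulum-v1")
model = TQC("MlpPolicy", env, verbose=1)
model.learn(1_000)
```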
Warning: Stable-Baselines3 (SB3) v1.8.0 will be the last one to use Gym as a backend. Starting with v2.0.0, Gymnasium will be the default backend (though SB3 will have compatibility layers for Gym envs). You can find a migration guide here. If you want to try the SB3 v2.0 alpha version, you can take a look at PR #1327.
Breaking Changes:
- Removed shared layers in `mlp_extractor` (@AlexPasqua)

New Features:
- Added `stats_window_size` argument to control smoothing in rollout logging (@jonasreiher)

Others:
- Fixed `sb3_contrib/qrdqn/*.py` type hints
- Switched from `flake8` to `ruff`

Documentation:
- Added warning about potential crashes caused by `check_env` in the `MaskablePPO` docs (@AlexPasqua)

Warning: Shared layers in MLP policy (`mlp_extractor`) are now deprecated for PPO, A2C and TRPO. This feature will be removed in SB3 v1.8.0, and the behavior of `net_arch=[64, 64]` will create separate networks with the same architecture, to be consistent with the off-policy algorithms. See the sketch below for the new separated-network configuration.
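A sketch of the replacement for shared layers (`TRPO` on `Pendulum-v1` chosen arbitrarily): with the new behavior, a plain list creates two separate networks, and the dict form makes the split explicit:

```python
from sb3_contrib import TRPO

# net_arch=[64, 64] now builds separate policy/value networks; the dict
# form below is equivalent and makes the separation explicit.
model = TRPO(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=dict(pi=[64, 64], vf=[64, 64])),
)
```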
Note: TRPO models saved with SB3 < 1.7.0 will show a warning about missing keys in the state dict when loaded with SB3 >= 1.7.0. To suppress the warning, simply save the model again. You can find more info in issue #1233.
Breaking Changes:
- Removed deprecated `create_eval_env`, `eval_env`, `eval_log_path`, `n_eval_episodes` and `eval_freq` parameters; please use an `EvalCallback` instead (see the sketch after this list)
- Removed deprecated `sde_net_arch` parameter

New Features:
- Added `with_bias` parameter to `ARSPolicy`
- Added support for unnormalized image-like observations (3D tensor) when passing `normalize_images=False`

Bug Fixes:
- Fixed a bug in `RecurrentPPO` where the lstm states were incorrectly reshaped for `n_lstm_layers > 1` (thanks @kolbytn)
- Fixed `RuntimeError: rnn: hx is not contiguous` while predicting terminal values for `RecurrentPPO` when `n_lstm_layers > 1`

Deprecations:
- You should now explicitly pass a `features_extractor` parameter when calling `extract_features()`
- Deprecated shared layers in `MlpExtractor` (@AlexPasqua)

Others:
- Fixed `sb3_contrib/common/utils.py` type hint
- Fixed `sb3_contrib/common/recurrent/type_aliases.py` type hint
- Fixed `sb3_contrib/ars/policies.py` type hint
- Exposed modules in `__init__.py` with `__all__` attribute (@ZikangXiong)
- Standardized the use of `from gym import spaces`
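As referenced in the first bullet above, a minimal sketch of the `EvalCallback` replacement (env choice and hyperparameter values are placeholders):

```python
import gym

from sb3_contrib import TRPO
from stable_baselines3.common.callbacks import EvalCallback

eval_env = gym.make("Pendulum-v1")
# Evaluate for 5 episodes every 1000 steps, replacing the removed
# eval_env/eval_freq/n_eval_episodes constructor parameters
eval_callback = EvalCallback(eval_env, eval_freq=1_000, n_eval_episodes=5)

model = TRPO("MlpPolicy", "Pendulum-v1")
model.learn(10_000, callback=eval_callback)
```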
- Added `progress_bar` argument in the `learn()` method, displayed using TQDM and rich packages
- Deprecated parameters `eval_env`, `eval_freq` and `create_eval_env`
- Updated the `.load()` methods so that they now use `TypeVar`
- Fixed the issue that `predict` does not always return action as `np.ndarray` (@qgallouedec)
- Fixed an issue with `RecurrentPPO` (@mlodel)
- Fixed an issue in the `MaskableEvalCallback` constructor (@burakdmb)
- Fixed the issue that the `running_mean` and `running_var` properties of batch norm layers are not updated (@honglu2875)
- Changed the default device from `"cpu"` to `"auto"`
- Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former `register_policy` helper and `policy_base` parameter, and using `policy_aliases` static attributes instead (@Gregwar)
- Renamed `rollout/exploration rate` key to `rollout/exploration_rate` for QRDQN (to be consistent with SB3 DQN)
- Upgraded to Python 3.7+ syntax using `pyupgrade`
- When using `CnnPolicy` or `MultiInputPolicy` with TQC, `share_features_extractor` is now set to False by default and `net_arch=[256, 256]` (instead of `net_arch=[]` as before)
- Added `RecurrentPPO` (aka PPO LSTM) (see the sketch after this list)
- Fixed a bug in `RecurrentPPO` when calculating the masked loss functions (@rnederstigt)
- Fixed a bug in `TRPO` where the KL divergence was not implemented for `MultiDiscrete` spaces
- Use the `forward()` method as per pytorch guidelines
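To close, a sketch of basic `RecurrentPPO` usage (mirroring the usual quickstart; env and step counts are illustrative). Note that at prediction time the LSTM states and episode starts are passed explicitly:

```python
import numpy as np

from sb3_contrib import RecurrentPPO

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5_000)

env = model.get_env()
obs = env.reset()
lstm_states = None
# Episode start signals tell the policy when to reset its hidden state
episode_starts = np.ones((env.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, rewards, dones, infos = env.step(action)
    episode_starts = dones
```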