PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- Added `StopTrainingOnNoModelImprovement` to callback collection (@caburu)
- Makes the length of keys and values in `HumanOutputFormat` configurable, depending on desired maximum width of output
- Fixed a bug in `VecMonitor`. The monitor did not consider the `info_keywords` during stepping (@ScheiklP)
- Fixed a bug in `HumanOutputFormat`. Distinct keys truncated to the same prefix would overwrite each other's value, resulting in only one being output. This now raises an error (this should only affect a small fraction of use cases with very long keys)
- Routed all the `nn.Module` calls through implicit rather than explicit forward as per PyTorch guidelines (@manuel-delverme)
- Fixed a bug in `VecNormalize` where an error occurs when `norm_obs` is set to False for an environment with dictionary observations (@buoyancy99)
- Set default `env` argument to `None` in `HerReplayBuffer.sample` (@qgallouedec)
- Fixed `batch_size` typing in `DQN` (@qgallouedec)
- Fixed sample normalization in `DictReplayBuffer` (@qgallouedec)
- Removed `remove_time_limit_termination` in off-policy algorithms since it was dead code (@Gregwar)
- Fixed the example of "Directly Accessing The Summary Writer" in the tensorboard integration docs (@xy9485)

Full Changelog: https://github.com/DLR-RM/stable-baselines3/compare/v1.4.0...v1.5.0
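The `HumanOutputFormat` fix above (distinct keys truncated to the same prefix now raise an error instead of silently overwriting each other) can be sketched as follows. This is a minimal illustration with a hypothetical `truncate_keys` helper and an arbitrary cut-off, not SB3's actual implementation:

```python
# Hypothetical sketch of the truncation-collision check described above,
# not SB3's actual code: long keys are shortened for display, and two
# distinct keys that collapse to the same truncated form raise an error.
def truncate_keys(keys, max_length=12):
    seen = {}
    for key in keys:
        short = key if len(key) <= max_length else key[: max_length - 3] + "..."
        if short in seen and seen[short] != key:
            raise ValueError(f"Distinct keys truncated to the same prefix: {short!r}")
        seen[short] = key
    return list(seen)

assert truncate_keys(["time/fps"]) == ["time/fps"]
assert truncate_keys(["rollout/ep_len_mean"]) == ["rollout/e..."]
```

With the previous behaviour, `rollout/ep_len_mean` and `rollout/ep_rew_mean` would both display as `rollout/e...` and one value would be lost; the error makes the collision visible instead.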
SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
- Renamed the `mask` argument of the `predict()` method to `episode_start` (used with RNN policies only)
- Local variables `action`, `done` and `reward` were renamed to their plural form for off-policy algorithms (`actions`, `dones`, `rewards`); this may affect custom callbacks
- Removed the `episode_reward` field from the `RolloutReturn()` type

Warning: An update to the `HER` algorithm is planned to support multi-env training and remove the max episode length constraint (see PR #704). This will be a backward incompatible change (models trained with a previous version of `HER` won't work with the new version).
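Because of the plural renaming above, a custom callback that has to run on both older and newer versions can look up the plural key first and fall back to the singular one. The `get_local` helper below is hypothetical, not part of SB3; it only illustrates the compatibility pattern on plain dicts:

```python
# Hypothetical compatibility helper (not part of SB3): off-policy algorithms
# now expose `actions`, `dones` and `rewards` in the callback's locals dict
# instead of the old singular names.
def get_local(locals_dict, plural, singular):
    if plural in locals_dict:
        return locals_dict[plural]
    return locals_dict[singular]

new_style = {"actions": [1], "dones": [False], "rewards": [0.5]}
old_style = {"action": [1], "done": [False], "reward": [0.5]}
assert get_local(new_style, "rewards", "reward") == [0.5]
assert get_local(old_style, "rewards", "reward") == [0.5]
```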
- Added `norm_obs_keys` param for the `VecNormalize` wrapper to configure which observation keys to normalize (@kachayev)
- Added experimental support to train off-policy algorithms with multiple envs (note: `HerReplayBuffer` currently not supported)
- Handle timeout termination properly for on-policy algorithms (when using `TimeLimit`)
- Added `skip` option to `VecTransposeImage` to skip transforming the channel order when the heuristic is wrong
- Added `copy()` and `combine()` methods to `RunningMeanStd`
- Fixed a bug where `set_env()` with `VecNormalize` would result in an error with off-policy algorithms (thanks @cleversonahum)
- FPS calculation is now performed based on the number of steps performed during the last `learn` call, even when `reset_num_timesteps` is set to `False` (@kachayev)
- Fixed a bug in `VecFrameStack` with channel-first image envs, where the terminal observation would be wrongly created
- Use `np.float32` for continuous actions
- Set `newline="\n"` when opening CSV monitor files so that each line ends with `\r\n` instead of `\r\r\n` on Windows, while Linux environments are not affected (@hsuehch)
- Fixed `device` argument inconsistency (@qgallouedec)
- Fixed the `BaseAlgorithm.load` docstring (@Demetrio92)
- Clarified `load` behavior in the examples (@Demetrio92)

WARNING: This version will be the last one supporting Python 3.6 (end of life in Dec 2021). We highly recommend upgrading to Python >= 3.7.
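The CSV monitor fix above (`newline="\n"`) can be illustrated with the standard library alone; this is an illustration of the underlying newline mechanics, the `Monitor` wrapper itself is not involved:

```python
import csv
import io

# csv.writer already terminates each row with "\r\n". If the underlying file
# additionally translates "\n" into "\r\n" (Windows text-mode default),
# every row ends up with "\r\r\n".
buf = io.StringIO(newline="")  # no newline translation, like open(path, "w", newline="")
csv.writer(buf).writerow(["r", "l", "t"])
assert buf.getvalue() == "r,l,t\r\n"

# Simulate the Windows text-mode translation that the fix avoids:
broken = buf.getvalue().replace("\n", "\r\n")
assert broken == "r,l,t\r\r\n"
```

Passing an explicit `newline` argument when opening the file disables that second translation, so rows end with a single `\r\n` on every platform.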
SB3-Contrib changelog: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/releases/tag/v1.3.0
- The `sde_net_arch` argument in policies is deprecated and will be removed in a future version
- `_get_latent` (`ActorCriticPolicy`) was removed
- All logging keys now use underscores instead of spaces (@timokau). Concretely this changes:
  - `time/total timesteps` to `time/total_timesteps` for off-policy algorithms and the eval callback (on-policy algorithms such as PPO and A2C already used the underscored version),
  - `rollout/exploration rate` to `rollout/exploration_rate` and
  - `rollout/success rate` to `rollout/success_rate`.
- Added methods `get_distribution` and `predict_values` for `ActorCriticPolicy` for A2C/PPO/TRPO (@cyprienc)
- Added methods `forward_actor` and `forward_critic` for `MlpExtractor`
- Added `sb3.get_system_info()` helper function to gather version information relevant to SB3 (e.g., Python and PyTorch version)
- Saved models now store system information from where the agent was trained, and load functions have a `print_system_info` parameter to help debug load issues
- Fixed `dtype` of observations for `SimpleMultiObsEnv`
- Allow `VecNormalize` to wrap discrete-observation environments to normalize reward when observation normalization is disabled
- Fixed a bug where `DQN` would throw an error when using `Discrete` observation and stochastic actions
- Added `force_reset` argument to `load()` and `set_env()` in order to be able to call `learn(reset_num_timesteps=False)` with a new environment
- Fixed a bug when using `EvalCallback` with two envs not wrapped the same way
- Fixed `setup.py` (`docutils` issue)
- The `VecNormalize` `ret` attribute was renamed to `returns`
- Fixed a bug in `VecNormalize` where the observation filter was not updated at reset (thanks @vwxyzjn)
- Fixed model predictions when using batch normalization and dropout layers by calling `train()` and `eval()` (@davidblom603)
- Passing `gradient_steps=0` to an off-policy algorithm will result in no gradient steps being taken (vs as many gradient steps as steps done in the environment during the rollout in previous versions)
- Refactored `predict()` by moving the preprocessing to the `obs_to_tensor()` method
- `VecEnvWrapper`
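The new `gradient_steps=0` semantics can be sketched with a hypothetical helper (not SB3's internal code), assuming the convention stated above: a non-negative value is taken literally, and a negative value keeps the old behaviour of one gradient step per environment step collected during the rollout:

```python
# Hypothetical sketch of the `gradient_steps` semantics described above
# (not SB3's implementation): 0 now explicitly means "no gradient steps".
def n_gradient_steps(gradient_steps, env_steps_collected):
    if gradient_steps >= 0:
        return gradient_steps
    return env_steps_collected  # negative: as many steps as collected

assert n_gradient_steps(0, 64) == 0    # new: explicit "do nothing"
assert n_gradient_steps(-1, 64) == 64  # old-style: match rollout length
assert n_gradient_steps(8, 64) == 8
```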
- All custom environments (e.g. the `BitFlippingEnv` or `IdentityEnv`) were moved to the `stable_baselines3.common.envs` folder
- Refactored `HER`, which is now the `HerReplayBuffer` class that can be passed to any off-policy algorithm
- Handle timeout termination properly for off-policy algorithms (when using `TimeLimit`)
- Renamed `_last_dones` and `dones` to `_last_episode_starts` and `episode_starts` in `RolloutBuffer`
- Removed `ObsDictWrapper` as `Dict` observation spaces are now supported

```python
her_kwargs = dict(n_sampled_goal=2, goal_selection_strategy="future", online_sampling=True)
# SB3 < 1.1.0
# model = HER("MlpPolicy", env, model_class=SAC, **her_kwargs)
# SB3 >= 1.1.0:
model = SAC("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=her_kwargs)
```
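For context, the `goal_selection_strategy="future"` setting in the migration example relabels stored transitions with goals that were actually achieved later in the same episode. The sketch below is an illustration of that idea only, not the `HerReplayBuffer` implementation:

```python
import random

# Illustration of HER's "future" goal-selection strategy (not the
# HerReplayBuffer implementation): a transition at step t is relabelled
# with an achieved goal observed at some step >= t in the same episode.
def relabel_future(episode_achieved_goals, t, rng):
    future_t = rng.randint(t, len(episode_achieved_goals) - 1)  # inclusive
    return episode_achieved_goals[future_t]

rng = random.Random(0)
achieved = [(0, 0), (1, 0), (1, 1), (2, 1)]
goal = relabel_future(achieved, t=2, rng=rng)
assert goal in achieved[2:]
```

Relabelling this way turns failed episodes into useful training signal, since the relabelled goal was reached by construction.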
- Removed `channels_last` from `is_image_space` as it can be inferred
- The logger object is now an attribute `model.logger` that can be set by the user using `model.set_logger()`
- Changed the signature of `logger.configure` and `utils.configure_logger`; they now return a `Logger` object
- Removed `Logger.CURRENT` and `Logger.DEFAULT`
- Moved the `warn(), debug(), log(), info(), dump()` methods to the `Logger` class
- `.learn()` now throws an import error when the user tries to log to tensorboard but the package is not installed
- Added support for single-level `Dict` observation space (@JadenTravnik)
- Added `DictRolloutBuffer` and `DictReplayBuffer` to support dictionary observations (@JadenTravnik)
- Added `StackedObservations` and `StackedDictObservations` that are used within `VecFrameStack`
- `HerReplayBuffer` now supports `VecNormalize` when `online_sampling=False`
- To use `HER`, pass the new `replay_buffer_class` and `replay_buffer_kwargs` arguments to off-policy algorithms
- Added `kl_divergence` helper for `Distribution` classes (@09tangriro)
- Added support for environments with `num_envs > 1` (@benblack769)
- Added `wrapper_kwargs` argument to `make_vec_env` (@amy12xx)
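What a `kl_divergence` helper computes can be shown for the simplest case, two univariate Gaussians in closed form. This is an illustration only; SB3's helper operates on its own `Distribution` classes rather than raw parameters:

```python
import math

# Closed-form KL divergence KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)),
# shown as an illustration of what a kl_divergence helper computes.
def kl_gaussians(mu1, sigma1, mu2, sigma2):
    return (math.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)
            - 0.5)

assert kl_gaussians(0.0, 1.0, 0.0, 1.0) == 0.0                 # identical
assert abs(kl_gaussians(0.0, 1.0, 1.0, 1.0) - 0.5) < 1e-12     # shifted mean
```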
- Fixed loading of `ent_coef` for `SAC` and `TQC`; it was not optimized anymore (thanks @Atlis)
- Fixed saving of `A2C` and `PPO` policy when using gSDE (thanks @liusida)
- Fixed a bug where no output would be shown even if `verbose>=1` after passing `verbose=0` once
- Added `flake8-bugbear` to tests dependencies to find likely bugs
- Updated `env_checker` to reflect support of dict observation spaces
- Added sanity check `batch_size > 1` in PPO to avoid NaN in advantage normalization
- Added example on `ProcgenEnv`
- Pinned `docutils==0.16` to avoid issue with the rtd theme
- Clarified `save_freq` definition
- Fixed typo in `A2C` docs (@bstee615)

First Major Version
Blog post: https://araffin.github.io/post/sb3/
100+ pre-trained models in the zoo: https://github.com/DLR-RM/rl-baselines3-zoo
- Removed `stable_baselines3.common.cmd_util` (already deprecated); please use `env_util` instead

Warning: A refactoring of the `HER` algorithm is planned together with support for dictionary observations (see PR #243 and #351). This will be a backward incompatible change (models trained with a previous version of `HER` won't work with the new version).

- Added support for `custom_objects` when loading models
- Fixed a bug with the `DQN` predict method when using `deterministic=False` with image space

Second release candidate
- `evaluate_policy` now returns rewards/episode lengths from a `Monitor` wrapper if one is present; this allows returning the unnormalized reward, in the case of Atari games for instance
- Renamed `common.vec_env.is_wrapped` to `common.vec_env.is_vecenv_wrapped` to avoid confusion with the new `is_wrapped()` helper
- Renamed `_get_data()` to `_get_constructor_parameters()` for policies (this affects independent saving/loading of policies)
- Removed `n_episodes_rollout` and merged it with `train_freq`, which now accepts a tuple `(frequency, unit)`
- `replay_buffer` in `collect_rollout` is no longer optional

```python
# SB3 < 0.11.0
# model = SAC("MlpPolicy", env, n_episodes_rollout=1, train_freq=-1)
# SB3 >= 0.11.0:
model = SAC("MlpPolicy", env, train_freq=(1, "episode"))
```
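One way the `(frequency, unit)` form might be normalised internally can be sketched as follows. The `parse_train_freq` helper is hypothetical, not SB3's actual code; it assumes the two units ("step" and "episode") implied by the example above:

```python
# Hypothetical sketch (not SB3's code) of normalising a train_freq spec:
# a bare int means "every n steps", a tuple makes the unit explicit.
def parse_train_freq(train_freq):
    if isinstance(train_freq, int):
        train_freq = (train_freq, "step")
    frequency, unit = train_freq
    if unit not in ("step", "episode"):
        raise ValueError(f"Unknown train_freq unit: {unit}")
    return frequency, unit

assert parse_train_freq(4) == (4, "step")
assert parse_train_freq((1, "episode")) == (1, "episode")
```

The design point of the merged argument is that the old pair of overlapping knobs (`train_freq` vs `n_episodes_rollout`) collapses into a single unambiguous spec.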
- Added support for `VecFrameStack` to stack on the first or last observation dimension, along with an automatic check for image spaces
- `VecFrameStack` now has a `channels_order` argument to tell if observations should be stacked on the first or last observation dimension (originally always stacked on last)
- Added `common.env_util.is_wrapped` and `common.env_util.unwrap_wrapper` functions for checking/unwrapping an environment for a specific wrapper
- Added `env_is_wrapped()` method for `VecEnv` to check if its environments are wrapped with given Gym wrappers
- Added `monitor_kwargs` parameter to `make_vec_env` and `make_atari_env`
- Wrap the environments automatically with a `Monitor` wrapper when possible
- `EvalCallback` now logs the success rate when available (`is_success` must be present in the info dict)
- Added support for text records to the `Logger`. (@lorenz-h)
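The success-rate logging mentioned above can be sketched with plain dicts. This is an illustration only, not `EvalCallback`'s actual code; it just shows how `is_success` entries from episode info dicts turn into a rate:

```python
# Illustration (not EvalCallback's implementation): `is_success` is read
# from each evaluation episode's final info dict when present.
def success_rate(episode_infos):
    successes = [info["is_success"] for info in episode_infos if "is_success" in info]
    if not successes:
        return None  # key absent everywhere: nothing to log
    return sum(successes) / len(successes)

assert success_rate([{"is_success": True}, {"is_success": False}]) == 0.5
assert success_rate([{}]) is None
```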
- Fixed `DQN` predict method when using a single `gym.Env` with `deterministic=False`
- Fixed a bug where the computation of `explained_variance()` in `ppo.py` and `a2c.py` was not correct (@thisray)
- Fixed a bug where a full `HerReplayBuffer` leads to an index error (@megan-klaiber)
- Fixed `PPO` construction error in an edge-case scenario where `n_steps * n_envs = 1` (size of rollout buffer), which otherwise causes downstream breaking errors in training (@decodyng)
- Fixed a bug that occurred when `train_freq=1`
- Replaced deprecated `np.bool` with `bool`
- Fixed a bug where `VecNormalize` was not normalizing the terminal observation
- Fixed a bug where `VecTranspose` was not transposing the terminal observation
- Fixed a bug where `action_noise` was not used when using `HER` (thanks @ShangqunYu)
- Fixed a bug where `train_freq` was not properly converted when loading a saved model
- `NatureCNN`
- Updated the `train()` method of `SAC`, `TD3` and `DQN` to match SB3-Contrib
- Added a warning to `PPO` when `n_steps * n_envs` is not a multiple of `batch_size` (last mini-batch truncated) (@decodyng)
- Clarified `A2C` docs (epsilon parameter)
- Fixed `clip_range` docstring
- Fixed `EvalCallback` docstring (thanks @tfederico)
- Renamed `common.cmd_util` to `common.env_util` for clarity (affects the `make_vec_env` and `make_atari_env` functions)
- Allow custom actor/critic network architectures using `net_arch=dict(qf=[400, 300], pi=[64, 64])` for off-policy algorithms (SAC, TD3, DDPG)
- Added Hindsight Experience Replay (`HER`). (@megan-klaiber)
- `VecNormalize` now supports `gym.spaces.Dict` observation spaces
- Added `share_features_extractor` argument to `SAC` and `TD3` policies
- `make_vec_env` now supports the `env_kwargs` argument when using an env ID str (@ManifoldFR)
- Fixed model creation initializing CUDA even when `device="cpu"` is provided
- Fixed `check_env` not checking if the env has a `Dict` action space before calling `_check_nan` (@wmmc88)
- Fixed a bug in `SAC`, `DDPG` and `TD3` when using `CnnPolicy` (or a custom feature extractor)
- Fixed a bug where, when using `CnnPolicy`, the passed env was not wrapped properly (the bug was introduced when implementing `HER`, so it should not be present in previous versions)
- Added `.vscode` to the gitignore
- `CnnPolicies`
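One way to read a `net_arch` dict like the one above into per-network layer sizes can be sketched as follows. The `actor_critic_arch` helper, the output convention, and the Q-function input of `obs_dim + action_dim` are illustrative assumptions, not SB3's actual code:

```python
# Hypothetical sketch (not SB3's implementation): turn a
# net_arch=dict(qf=[...], pi=[...]) spec into full layer-size lists.
def actor_critic_arch(net_arch, obs_dim, action_dim):
    # Actor maps observations to actions through the "pi" hidden layers.
    pi_sizes = [obs_dim] + net_arch["pi"] + [action_dim]
    # Critic scores (observation, action) pairs through the "qf" layers.
    qf_sizes = [obs_dim + action_dim] + net_arch["qf"] + [1]
    return pi_sizes, qf_sizes

pi, qf = actor_critic_arch(dict(qf=[400, 300], pi=[64, 64]), obs_dim=8, action_dim=2)
assert pi == [8, 64, 64, 2]
assert qf == [10, 400, 300, 1]
```

The point of the dict form is that actor and critic no longer have to share one architecture: the `pi` and `qf` lists are independent.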