Stable-Baselines3 is the PyTorch version of Stable Baselines, providing reliable implementations of reinforcement learning algorithms.
- Removed `device` keyword argument of policies; use `policy.to(device)` instead. (@qxcv)
- Renamed `BaseClass.get_torch_variables` -> `BaseClass._get_torch_save_params` and `BaseClass.excluded_save_params` -> `BaseClass._excluded_save_params`
- Renamed saved items `tensors` to `pytorch_variables` for clarity
- `make_atari_env`, `make_vec_env` and `set_random_seed` must be imported from their submodules (and not directly from `stable_baselines3.common`):

  ```python
  from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
  from stable_baselines3.common.utils import set_random_seed
  ```
- Added `unwrap_vec_wrapper()` to `common.vec_env` to extract `VecEnvWrapper` if needed
- Added `StopTrainingOnMaxEpisodes` to callback collection (@xicocaio)
- Added `device` keyword argument to `BaseAlgorithm.load()` (@liorcohen5)
- Added `get_parameters` and `set_parameters` for accessing/setting parameters of the agent
- Fixed a bug where the environment was reset twice when using `evaluate_policy`
- Fixed logging of `clip_fraction` in PPO (@diditforlulz273)
- Fixed a bug where CUDA support was wrongly checked when passing the GPU index, e.g. `device="cuda:0"` (@liorcohen5)
- Improved typing coverage of the `VecEnv`
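For illustration, the wrapper-unwrapping idea behind `unwrap_vec_wrapper()` can be sketched with stub classes: walk the chain of vectorized-env wrappers until one of the requested type is found. This is a simplified stand-in, not the actual SB3 implementation:

```python
class VecEnv:
    """Stub base vectorized environment."""

class VecEnvWrapper(VecEnv):
    """Stub wrapper holding an inner vectorized env in `venv`."""
    def __init__(self, venv):
        self.venv = venv

class VecNormalize(VecEnvWrapper):
    pass

class VecFrameStack(VecEnvWrapper):
    pass

def unwrap_vec_wrapper(env, wrapper_class):
    """Walk the wrapper chain; return the first wrapper of the given type, or None."""
    while isinstance(env, VecEnvWrapper):
        if isinstance(env, wrapper_class):
            return env
        env = env.venv
    return None

# Usage: find the VecNormalize layer inside a stacked wrapper chain.
base = VecEnv()
env = VecFrameStack(VecNormalize(base))
assert unwrap_vec_wrapper(env, VecNormalize) is env.venv
assert unwrap_vec_wrapper(env, VecFrameStack) is env
```

The helper returns `None` when no wrapper of the requested type is present, so callers can test the result before use.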
- Fixed type annotation of `make_vec_env` (@ManifoldFR)
- Removed `AlreadySteppingError` and `NotSteppingError` that were not used
- Reorganized functions for clarity in `BaseClass` (save/load functions close to each other, private functions at top)
- Simplified `save_to_zip_file` function by removing duplicate code
- Added `StopTrainingOnMaxEpisodes` details and example (@xicocaio)
- Re-enabled `sphinx_autodoc_typehints`
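For context, the zip-based save format that `save_to_zip_file` implements stores each component as a separate archive entry. A simplified sketch of the idea, with hypothetical helper names and JSON standing in for the real serialization of PyTorch state dicts:

```python
import io
import json
import zipfile

def save_to_zip_sketch(path_or_file, data, params):
    """Write hyperparameters and parameters as separate zip entries."""
    with zipfile.ZipFile(path_or_file, mode="w") as archive:
        archive.writestr("data.json", json.dumps(data))
        archive.writestr("params.json", json.dumps(params))

def load_from_zip_sketch(path_or_file):
    """Read back the two entries written above."""
    with zipfile.ZipFile(path_or_file, mode="r") as archive:
        data = json.loads(archive.read("data.json"))
        params = json.loads(archive.read("params.json"))
    return data, params

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
save_to_zip_sketch(buffer, {"gamma": 0.99}, {"weights": [0.1, 0.2]})
buffer.seek(0)
data, params = load_from_zip_sketch(buffer)
```

Keeping each component in its own entry lets a loader read, say, only the hyperparameters without deserializing the network weights.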
- `AtariWrapper` and other Atari wrappers were updated to match SB2 ones
- `save_replay_buffer` now receives the file path as argument instead of the folder path (@tirafesi)
- Refactored the `Critic` class for `TD3` and `SAC`; it is now called `ContinuousCritic` and has an additional parameter `n_critics`
- `SAC` and `TD3` now accept an arbitrary number of critics (e.g. `policy_kwargs=dict(n_critics=3)`) instead of only 2 previously
- Added `DQN` algorithm (@Artemis-Skade)
- Buffer dtype is now set according to action and observation spaces for `ReplayBuffer`
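One reason a configurable `n_critics` is useful: clipped double-Q targets take the minimum over the critic ensemble to curb overestimation, and a larger ensemble makes that minimum more conservative. A toy pure-Python sketch of that reduction (the critics here are stand-in functions, not SB3 code; with SB3 you would instead pass e.g. `policy_kwargs=dict(n_critics=3)`):

```python
def make_critics(offsets):
    """Build toy critics; each returns a (differently biased) Q-value estimate."""
    return [lambda obs_action, b=b: sum(obs_action) + b for b in offsets]

def min_q_target(critics, obs_action):
    """Clipped double-Q style target: take the minimum across the ensemble."""
    return min(q(obs_action) for q in critics)

critics = make_critics([0.0, 0.5, -0.25])  # an ensemble of 3 critics
target = min_q_target(critics, (1.0, 2.0))  # min(3.0, 3.5, 2.75) == 2.75
```

The `b=b` default argument freezes each offset at lambda-creation time, a standard Python idiom for building closures in a loop.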
- Added a warning when allocation of a buffer may exceed the available memory of the system when `psutil` is available
- Added `DDPG` algorithm as a special case of `TD3`.
- Introduced `BaseModel` abstract parent for `BasePolicy`, which critics inherit from.
- Fixed a bug in the `close()` method of `SubprocVecEnv`, causing wrappers further down in the wrapper stack to not be closed. (@NeoExtended)
- Use `cloudpickle.load` instead of `pickle.load` in `CloudpickleWrapper`. (@shwang)
- Fixed a bug with orthogonal initialization when `bias=False` in custom policy (@rk37)
- Fixed storing correct `dones` in on-policy algorithm rollout collection. (@andyshih12)
- Refactored off-policy algorithms to share the same `.learn()` method
- Split the `collect_rollout()` method for off-policy algorithms
- Added `_on_step()` for off-policy base class
- Optimized replay buffer size by removing the need of the `next_observations` numpy array
- Switched to `black` codestyle and added `make format`, `make check-codestyle` and `commit-checks`
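The shared `.learn()` refactor follows a template-method pattern: one loop alternates rollout collection and gradient updates, and subclasses override the pieces. A schematic sketch with a simplified stand-in class (the counters are illustrative, not the real `OffPolicyAlgorithm` logic):

```python
class OffPolicySketch:
    """Template-method style loop: learn() alternates collect_rollouts() and train()."""

    def __init__(self, train_freq=4):
        self.train_freq = train_freq
        self.num_timesteps = 0
        self.updates = 0

    def collect_rollouts(self):
        # In SB3 this would step the environment and fill the replay buffer;
        # here we only advance the step counter.
        self.num_timesteps += self.train_freq

    def train(self):
        # One gradient update per collected step in this toy version.
        self.updates += self.train_freq

    def learn(self, total_timesteps):
        while self.num_timesteps < total_timesteps:
            self.collect_rollouts()
            self.train()
        return self

algo = OffPolicySketch().learn(total_timesteps=20)
```

Because the loop lives in one place, algorithm-specific classes only need to supply their own `collect_rollouts()` and `train()` behavior.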
- `gSDE`
- Added `common.sb2_compat.RMSpropTFLike` optimizer, which corresponds more closely to the TensorFlow implementation of RMSprop.
- The `render()` method of `VecEnvs` now only accepts one argument: `mode`
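A key difference of the TensorFlow-style RMSprop is where epsilon enters the denominator: inside the square root rather than outside, as in PyTorch's `torch.optim.RMSprop`. A toy scalar update showing the two variants (a sketch of the update rule only, not the actual optimizer class):

```python
def rmsprop_step_pytorch(x, grad, avg_sq, lr=0.01, alpha=0.99, eps=1e-8):
    """One RMSprop step, PyTorch style: epsilon added outside the sqrt."""
    avg_sq = alpha * avg_sq + (1 - alpha) * grad * grad
    return x - lr * grad / (avg_sq ** 0.5 + eps), avg_sq

def rmsprop_step_tf_like(x, grad, avg_sq, lr=0.01, alpha=0.99, eps=1e-8):
    """One RMSprop step, TF style: epsilon added inside the sqrt."""
    avg_sq = alpha * avg_sq + (1 - alpha) * grad * grad
    return x - lr * grad / ((avg_sq + eps) ** 0.5), avg_sq

# Same inputs, slightly different parameter updates.
x_pt, _ = rmsprop_step_pytorch(1.0, 0.5, 0.0)
x_tf, _ = rmsprop_step_tf_like(1.0, 0.5, 0.0)
```

The difference is tiny per step, but it compounds over training, which is why reproducing TF-trained baselines can require the TF-like variant.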
- Created new file common/torch_layers.py, similar to SB refactoring
  - Contains PyTorch layer definitions and feature extractors: `MlpExtractor`, `create_mlp`, `NatureCNN`
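`create_mlp` builds a stack of fully connected layers from an input dimension, an output dimension, and a list of hidden sizes. A dependency-free sketch of that layer-shape logic, returning `(in_features, out_features)` pairs instead of `torch.nn` modules (the helper name here is hypothetical):

```python
def mlp_layer_shapes(input_dim, output_dim, net_arch):
    """Return the (in_features, out_features) pair for each Linear layer."""
    dims = [input_dim] + list(net_arch) + [output_dim]
    return list(zip(dims[:-1], dims[1:]))

# A network with two hidden layers of 64 units each:
shapes = mlp_layer_shapes(input_dim=8, output_dim=2, net_arch=[64, 64])
# [(8, 64), (64, 64), (64, 2)]
```

The real function additionally interleaves activation functions between the linear layers; the chaining of dimensions is the same.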
- Renamed `BaseRLModel` to `BaseAlgorithm` (along with offpolicy and onpolicy variants)
- Moved on-policy and off-policy base algorithms to `common/on_policy_algorithm.py` and `common/off_policy_algorithm.py`, respectively.
- Moved `PPOPolicy` to `ActorCriticPolicy` in common/policies.py
- Moved `PPO` (algorithm class) into `OnPolicyAlgorithm` (`common/on_policy_algorithm.py`), to be shared with `A2C`
- Moved the following functions from `BaseAlgorithm`:
  - `_load_from_file` to `load_from_zip_file` (save_util.py)
  - `_save_to_file_zip` to `save_to_zip_file` (save_util.py)
  - `safe_mean` to `safe_mean` (utils.py)
  - `check_env` to `check_for_correct_spaces` (utils.py; renamed to avoid confusion with environment checker tools)
- Moved static function `_is_vectorized_observation` from common/policies.py to common/utils.py under the name `is_vectorized_observation`.
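The idea behind `is_vectorized_observation` can be sketched on shapes alone: a single observation matches the space's shape exactly, while a vectorized batch carries one extra leading dimension. A simplified stand-in (the real helper also validates the observation against the space itself):

```python
def is_vectorized_observation_sketch(obs_shape, space_shape):
    """Return True if obs carries a leading batch dimension over the space shape."""
    if obs_shape == space_shape:
        return False  # a single observation
    if len(obs_shape) == len(space_shape) + 1 and obs_shape[1:] == space_shape:
        return True  # a batch of observations
    raise ValueError(f"Unexpected observation shape {obs_shape} for space {space_shape}")

assert is_vectorized_observation_sketch((4,), (4,)) is False   # one obs
assert is_vectorized_observation_sketch((8, 4), (4,)) is True  # batch of 8
```

Distinguishing the two cases lets `predict()` accept both a single observation and a batch without extra arguments.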
- Removed `{save,load}_running_average` functions of `VecNormalize` in favor of `load/save`.
- Removed `use_gae` parameter from `RolloutBuffer.compute_returns_and_advantage`.
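With the `use_gae` flag removed, advantages are always computed via Generalized Advantage Estimation. A list-based sketch of the backward GAE(lambda) recursion (plain Python, with simplified handling of terminal states compared to the buffer code):

```python
def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """Backward GAE(lambda): delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        non_terminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        gae = delta + gamma * lam * non_terminal * gae
        advantages[t] = gae
    # Returns are recovered as advantage + value estimate.
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns

adv, ret = compute_gae([1.0, 1.0], [0.5, 0.5], last_value=0.0, dones=[0.0, 1.0])
```

Setting `lam=1` recovers Monte Carlo style advantages, while `lam=0` reduces to the one-step TD error, which is why the separate non-GAE code path was redundant.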
- Fixed `render()` method for `VecEnvs`
- Fixed `seed()` method for `SubprocVecEnv`
- Fixed loading on GPU for testing when using gSDE and `deterministic=False`
- Fixed `register_policy` to allow re-registering the same policy for the same sub-class (i.e. assigning the same value to the same key).
- Fixed a bug affecting `gSDE` with `PPO`/`A2C`; this does not affect `SAC`
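The fixed `register_policy` behavior can be sketched as an idempotent registry: re-registering the same class under the same key is a no-op, while a genuinely conflicting registration still fails (hypothetical names, not the SB3 function, which keys per algorithm sub-class):

```python
_policy_registry = {}

def register_policy_sketch(name, policy_class):
    """Allow re-registering the same class under the same key; reject conflicts."""
    existing = _policy_registry.get(name)
    if existing is not None and existing is not policy_class:
        raise ValueError(f"Policy {name} is already registered with a different class")
    _policy_registry[name] = policy_class

class MlpPolicy:
    pass

register_policy_sketch("MlpPolicy", MlpPolicy)
register_policy_sketch("MlpPolicy", MlpPolicy)  # same key, same value: fine
```

Idempotent registration matters when a module is imported twice (e.g. under test reloading), which previously raised a spurious error.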
- Re-enabled `fork` start method in the tests (was causing a deadlock with tensorflow)
- Added a test for seeding `SubprocVecEnv` and rendering
- Renamed `progress` (value from 1 at start of training to 0 at the end) to `progress_remaining`.
- Added `policies.py` files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies).
- `VecNormalize`, `VecCheckNan` and `PPO`.
- `TD3`
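`progress_remaining` starts at 1 and decays to 0 over training, which makes hyperparameter schedules straightforward closures. A small sketch of a linear learning-rate schedule built on it (the helper name is illustrative):

```python
def linear_schedule(initial_value):
    """Return a function mapping progress_remaining in [0, 1] to a learning rate."""
    def schedule(progress_remaining):
        # progress_remaining == 1.0 at the start of training, 0.0 at the end.
        return progress_remaining * initial_value
    return schedule

lr_fn = linear_schedule(3e-4)  # full learning rate at the start, 0 at the end
```

The old name `progress` suggested a value growing toward 1; the rename makes the decaying direction explicit.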
- Renamed logger methods: `logkv` -> `record`, `writekvs` -> `write`, `writeseq` -> `write_sequence`, `logkvs` -> `record_dict`, `dumpkvs` -> `dump`, `getkvs` -> `get_log_dict`, `logkv_mean` -> `record_mean`
- Added `VecCheckNan` and `VecVideoRecorder` (sync with Stable Baselines)
- Added `cmd_util` and `atari_wrappers`
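The renamed logger API follows a record-then-dump pattern: `record` buffers key/value pairs, `record_mean` averages a key across calls, and `dump` flushes the buffer. A minimal sketch of those semantics (a toy stand-in, not the actual SB3 logger, which writes to stdout, CSV, TensorBoard, etc.):

```python
class MiniLogger:
    """Buffer key/value pairs via record()/record_mean(), flush via dump()."""

    def __init__(self):
        self.values = {}
        self.counts = {}

    def record(self, key, value):
        self.values[key] = value

    def record_mean(self, key, value):
        # Running mean across calls since the last dump().
        count = self.counts.get(key, 0)
        mean = self.values.get(key, 0.0)
        self.values[key] = (mean * count + value) / (count + 1)
        self.counts[key] = count + 1

    def dump(self):
        flushed = dict(self.values)
        self.values.clear()
        self.counts.clear()
        return flushed

logger = MiniLogger()
logger.record("train/loss", 0.5)
logger.record_mean("rollout/ep_rew_mean", 1.0)
logger.record_mean("rollout/ep_rew_mean", 3.0)
snapshot = logger.dump()  # averaged and scalar values, buffer now empty
```

Buffering until `dump` lets many call sites contribute to one consistent log line per training iteration.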
- Added support for `MultiDiscrete` and `MultiBinary` observation spaces (@rolandgvc)
- Added `MultiCategorical` and `Bernoulli` distributions for PPO/A2C (@rolandgvc)
- Added `VectorizedActionNoise` for continuous vectorized environments (@PartiallyTyped)
- Log evaluation in the `EvalCallback` using the logger
- Fixed `sde_sample_freq` that was not taken into account for SAC
- Pass the logger to `BaseCallback`; otherwise callbacks cannot write to the one used by the algorithms
- Synced `VecEnvs` with Stable-Baselines
- `gym>=0.17` is now required.
- Added `readthedoc.yml` file
- Added `flake8` and `make lint` command
- Added `train_freq` and `n_episodes_rollout` to Off-Policy Algorithms
- Fixed `TD3` example code block
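`train_freq` and `n_episodes_rollout` control when rollout collection pauses so training can run: after a number of environment steps, or after a number of complete episodes. A toy sketch of such a stopping rule, assuming a value of -1 disables a criterion (that convention is an assumption for this sketch, not the documented API):

```python
def should_stop_collecting(steps, episodes, train_freq=-1, n_episodes_rollout=-1):
    """Stop collecting once either configured threshold is reached (-1 disables one)."""
    if train_freq > 0 and steps >= train_freq:
        return True
    if n_episodes_rollout > 0 and episodes >= n_episodes_rollout:
        return True
    return False

# Step-based trigger: pause collection every 64 environment steps.
assert should_stop_collecting(steps=64, episodes=0, train_freq=64)
# Episode-based trigger: pause collection after each full episode.
assert should_stop_collecting(steps=130, episodes=1, n_episodes_rollout=1)
```

Step-based triggering gives a fixed update cadence, while episode-based triggering keeps whole trajectories together before each round of gradient updates.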