Modular Deep Reinforcement Learning framework in PyTorch. Companion library of the book "Foundations of Deep Reinforcement Learning".
Full Changelog: https://github.com/kengz/SLM-Lab/compare/v4.2.3...v4.2.4
Added the `reinforce_pong.json` spec to prevent confusion (https://github.com/kengz/SLM-Lab/pull/499).
Full Changelog: https://github.com/kengz/SLM-Lab/compare/v4.2.2...v4.2.3
:raised_hands: Thanks to @Nickfagiano for help with debugging.
Fixed the `atari-py` version to 0.2.6 for safety. :raised_hands: Thanks to @piosif97 for helping.
:raised_hands: Thanks to @vladimirnitu and @steindaian for providing the PDF.
Dependencies and systems around SLM Lab have changed and caused some breakages. This release fixes these installation issues.
Fixed the `homebrew/cask` installation step (thanks @ben-e, @amjadmajid).

This release adds the `train@` resume mode and refactors the `enjoy` mode. See the PR for detailed info.

`train@` usage example: specify train mode as `train@{predir}`, where `{predir}` is the data directory of the last training run, or simply use `train@latest` to use the latest run, e.g.:
```shell
python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole train
# terminate run before its completion
# optionally edit the spec file in a past-future-consistent manner
# run resume with either of the commands:
python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole train@latest
# or to use a specific run folder
python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole train@data/reinforce_cartpole_2020_04_13_232521
```
`enjoy` mode refactor: the `train@` resume mode API allows the `enjoy` mode to be refactored, and both now share a similar syntax. Continuing with the example above, to enjoy a trained model, we now use:
```shell
python run_lab.py slm_lab/spec/benchmark/reinforce/reinforce_cartpole.json reinforce_cartpole enjoy@data/reinforce_cartpole_2020_04_13_232521/reinforce_cartpole_t0_s0_spec.json
```
Added the `OnPolicyCrossEntropy` memory class. See the PR for details. Credits to @ingambe.
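As a rough sketch of how a memory class is selected in a spec file (the surrounding fields here are illustrative assumptions, not the actual spec from the PR — only the memory `name` is the new class):

```json
{
  "agent": [{
    "name": "Reinforce",
    "memory": {
      "name": "OnPolicyCrossEntropy"
    }
  }]
}
```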
| Env. \ Alg. | DQN | DDQN+PER | A2C (GAE) | A2C (n-step) | PPO | SAC |
|---|---|---|---|---|---|---|
| Breakout | 80.88 | 182 | 377 | 398 | 443 | 3.51* |
| Pong | 18.48 | 20.5 | 19.31 | 19.56 | 20.58 | 19.87* |
| Seaquest | 1185 | 4405 | 1070 | 1684 | 1715 | 171* |
| Qbert | 5494 | 11426 | 12405 | 13590 | 13460 | 923* |
| LunarLander | 192 | 233 | 25.21 | 68.23 | 214 | 276 |
| UnityHallway | -0.32 | 0.27 | 0.08 | -0.96 | 0.73 | 0.01 |
| UnityPushBlock | 4.88 | 4.93 | 4.68 | 4.93 | 4.97 | -0.70 |
Episode score at the end of training attained by SLM Lab implementations on discrete-action control problems. Reported episode scores are the average over the last 100 checkpoints, and then averaged over 4 Sessions. A Random baseline with score averaged over 100 episodes is included. Results marked with * were trained using the hybrid synchronous/asynchronous version of SAC to parallelize and speed up training time. For SAC, Breakout, Pong and Seaquest were trained for 2M frames instead of 10M frames.
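For concreteness, here is a rough sketch of how such a reported score could be computed (array names and shapes are illustrative assumptions, not SLM Lab's actual code):

```python
import numpy as np

# Hypothetical array: scores[session, checkpoint] holds the evaluation
# episode score at each checkpoint, for each of the 4 Sessions.
scores = np.random.rand(4, 1000)  # stand-in data for illustration

last_100 = scores[:, -100:].mean(axis=1)  # average over the last 100 checkpoints
reported = last_100.mean()                # then average over the 4 Sessions
print(round(reported, 2))
```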
For the full Atari benchmark, see Atari Benchmark
This marks a stable release of SLM Lab with full benchmark results. TensorBoard logging is also available: run `tensorboard --logdir=data` after a session/trial is completed to view the logged metrics.

| Env. \ Alg. | DQN | DDQN+PER | A2C (GAE) | A2C (n-step) | PPO | SAC |
|---|---|---|---|---|---|---|
| Breakout | 80.88 | 182 | 377 | 398 | 443 | - |
| Pong | 18.48 | 20.5 | 19.31 | 19.56 | 20.58 | 19.87* |
| Seaquest | 1185 | 4405 | 1070 | 1684 | 1715 | - |
| Qbert | 5494 | 11426 | 12405 | 13590 | 13460 | 214* |
| LunarLander | 192 | 233 | 25.21 | 68.23 | 214 | 276 |
| UnityHallway | -0.32 | 0.27 | 0.08 | -0.96 | 0.73 | - |
| UnityPushBlock | 4.88 | 4.93 | 4.68 | 4.93 | 4.97 | - |
Episode score at the end of training attained by SLM Lab implementations on discrete-action control problems. Reported episode scores are the average over the last 100 checkpoints, and then averaged over 4 Sessions. Results marked with * were trained using the hybrid synchronous/asynchronous version of SAC to parallelize and speed up training time.
For the full Atari benchmark, see Atari Benchmark
| Env. \ Alg. | A2C (GAE) | A2C (n-step) | PPO | SAC |
|---|---|---|---|---|
| RoboschoolAnt | 787 | 1396 | 1843 | 2915 |
| RoboschoolAtlasForwardWalk | 59.87 | 88.04 | 172 | 800 |
| RoboschoolHalfCheetah | 712 | 439 | 1960 | 2497 |
| RoboschoolHopper | 710 | 285 | 2042 | 2045 |
| RoboschoolInvertedDoublePendulum | 996 | 4410 | 8076 | 8085 |
| RoboschoolInvertedPendulum | 995 | 978 | 986 | 941 |
| RoboschoolReacher | 12.9 | 10.16 | 19.51 | 19.99 |
| RoboschoolWalker2d | 280 | 220 | 1660 | 1894 |
| RoboschoolHumanoid | 99.31 | 54.58 | 2388 | 2621* |
| RoboschoolHumanoidFlagrun | 73.57 | 178 | 2014 | 2056* |
| RoboschoolHumanoidFlagrunHarder | -429 | 253 | 680 | 280* |
| Unity3DBall | 33.48 | 53.46 | 78.24 | 98.44 |
| Unity3DBallHard | 62.92 | 71.92 | 91.41 | 97.06 |
Episode score at the end of training attained by SLM Lab implementations on continuous control problems. Reported episode scores are the average over the last 100 checkpoints, and then averaged over 4 Sessions. Results marked with * require 50M-100M frames, so we use the hybrid synchronous/asynchronous version of SAC to parallelize and speed up training time.
| Env. \ Alg. | DQN | DDQN+PER | A2C (GAE) | A2C (n-step) | PPO |
|---|---|---|---|---|---|
| Adventure | -0.94 | -0.92 | -0.77 | -0.85 | -0.3 |
| AirRaid | 1876 | 3974 | 4202 | 3557 | 4028 |
| Alien | 822 | 1574 | 1519 | 1627 | 1413 |
| Amidar | 90.95 | 431 | 577 | 418 | 795 |
| Assault | 1392 | 2567 | 3366 | 3312 | 3619 |
| Asterix | 1253 | 6866 | 5559 | 5223 | 6132 |
| Asteroids | 439 | 426 | 2951 | 2147 | 2186 |
| Atlantis | 68679 | 644810 | 2747371 | 2259733 | 2148077 |
| BankHeist | 131 | 623 | 855 | 1170 | 1183 |
| BattleZone | 6564 | 6395 | 4336 | 4533 | 13649 |
| BeamRider | 2799 | 5870 | 2659 | 4139 | 4299 |
| Berzerk | 319 | 401 | 1073 | 763 | 860 |
| Bowling | 30.29 | 39.5 | 24.51 | 23.75 | 31.64 |
| Boxing | 72.11 | 90.98 | 1.57 | 1.26 | 96.53 |
| Breakout | 80.88 | 182 | 377 | 398 | 443 |
| Carnival | 4280 | 4773 | 2473 | 1827 | 4566 |
| Centipede | 1899 | 2153 | 3909 | 4202 | 5003 |
| ChopperCommand | 1083 | 4020 | 3043 | 1280 | 3357 |
| CrazyClimber | 46984 | 88814 | 106256 | 109998 | 116820 |
| Defender | 281999 | 313018 | 665609 | 657823 | 534639 |
| DemonAttack | 1705 | 19856 | 23779 | 19615 | 121172 |
| DoubleDunk | -21.44 | -22.38 | -5.15 | -13.3 | -6.01 |
| ElevatorAction | 32.62 | 17.91 | 9966 | 8818 | 6471 |
| Enduro | 437 | 959 | 787 | 0.0 | 1926 |
| FishingDerby | -88.14 | -1.7 | 16.54 | 1.65 | 36.03 |
| Freeway | 24.46 | 30.49 | 30.97 | 0.0 | 32.11 |
| Frostbite | 98.8 | 2497 | 277 | 261 | 1062 |
| Gopher | 1095 | 7562 | 929 | 1545 | 2933 |
| Gravitar | 87.34 | 258 | 313 | 433 | 223 |
| Hero | 1051 | 12579 | 16502 | 19322 | 17412 |
| IceHockey | -14.96 | -14.24 | -5.79 | -6.06 | -6.43 |
| Jamesbond | 44.87 | 702 | 521 | 453 | 561 |
| JourneyEscape | -4818 | -2003 | -921 | -2032 | -1094 |
| Kangaroo | 1965 | 8897 | 67.62 | 554 | 4989 |
| Krull | 5522 | 6650 | 7785 | 6642 | 8477 |
| KungFuMaster | 2288 | 16547 | 31199 | 25554 | 34523 |
| MontezumaRevenge | 0.0 | 0.02 | 0.08 | 0.19 | 1.08 |
| MsPacman | 1175 | 2215 | 1965 | 2158 | 2350 |
| NameThisGame | 3915 | 4474 | 5178 | 5795 | 6386 |
| Phoenix | 2909 | 8179 | 16345 | 13586 | 30504 |
| Pitfall | -68.83 | -73.65 | -101 | -31.13 | -35.93 |
| Pong | 18.48 | 20.5 | 19.31 | 19.56 | 20.58 |
| Pooyan | 1958 | 2741 | 2862 | 2531 | 6799 |
| PrivateEye | 784 | 303 | 93.22 | 78.07 | 50.12 |
| Qbert | 5494 | 11426 | 12405 | 13590 | 13460 |
| Riverraid | 953 | 10492 | 8308 | 7565 | 9636 |
| RoadRunner | 15237 | 29047 | 30152 | 31030 | 32956 |
| Robotank | 3.43 | 9.05 | 2.98 | 2.27 | 2.27 |
| Seaquest | 1185 | 4405 | 1070 | 1684 | 1715 |
| Skiing | -14094 | -12883 | -19481 | -14234 | -24713 |
| Solaris | 612 | 1396 | 2115 | 2236 | 1892 |
| SpaceInvaders | 451 | 670 | 733 | 750 | 797 |
| StarGunner | 3565 | 38238 | 44816 | 48410 | 60579 |
| Tennis | -23.78 | -10.33 | -22.42 | -19.06 | -11.52 |
| TimePilot | 2819 | 1884 | 3331 | 3440 | 4398 |
| Tutankham | 35.03 | 159 | 161 | 175 | 211 |
| UpNDown | 2043 | 11632 | 89769 | 18878 | 262208 |
| Venture | 4.56 | 9.61 | 0.0 | 0.0 | 11.84 |
| VideoPinball | 8056 | 79730 | 35371 | 40423 | 58096 |
| WizardOfWor | 869 | 328 | 1516 | 1247 | 4283 |
| YarsRevenge | 5816 | 15698 | 27097 | 11742 | 10114 |
| Zaxxon | 442 | 54.28 | 64.72 | 24.7 | 641 |
The table above presents results for 62 Atari games. All agents were trained for 10M frames (40M including skipped frames). Reported results are the episode score at the end of training, averaged over the previous 100 evaluation checkpoints with each checkpoint averaged over 4 Sessions. Agents were checkpointed every 10k training frames.
This release adds a new algorithm: Soft Actor-Critic (SAC).
- Implements the original paper: "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor" (https://arxiv.org/abs/1801.01290) #398
- Adds a custom `GumbelSoftmax` distribution to support discrete actions (a rough sketch of the idea follows the results table below)

Note that the Roboschool reward scales are different from MuJoCo's.
| Env. \ Alg. | SAC |
|---|---|
| RoboschoolAnt | 2451.55 |
| RoboschoolHalfCheetah | 2004.27 |
| RoboschoolHopper | 2090.52 |
| RoboschoolWalker2d | 1711.92 |
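As a minimal sketch of the Gumbel-Softmax idea for discrete-action SAC (this is an illustrative stand-in, not SLM Lab's actual `GumbelSoftmax` class): sample actions through the reparameterizable relaxation so gradients flow to the policy, while acting and scoring with hard one-hot actions.

```python
import torch
import torch.nn.functional as F

class GumbelSoftmaxSketch:
    """Illustrative stand-in for a Gumbel-Softmax action distribution."""
    def __init__(self, logits, tau=1.0):
        self.logits = logits  # unnormalized action preferences
        self.tau = tau        # relaxation temperature

    def rsample(self):
        # Straight-through trick: hard one-hot in the forward pass,
        # differentiable soft relaxation in the backward pass
        return F.gumbel_softmax(self.logits, tau=self.tau, hard=True)

    def log_prob(self, one_hot_action):
        # Score the discrete action under the underlying categorical
        return (F.log_softmax(self.logits, dim=-1) * one_hot_action).sum(-1)

logits = torch.randn(4, 6, requires_grad=True)  # batch of 4 states, 6 actions
dist = GumbelSoftmaxSketch(logits)
action = dist.rsample()                 # differentiable one-hot action sample
neg_log_prob = -dist.log_prob(action)   # enters SAC's entropy-regularized loss
neg_log_prob.sum().backward()           # gradients reach the policy logits
```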
This release corrects and optimizes all the algorithms via benchmarking on Atari. New metrics are introduced, and the lab's API is redesigned for simplicity.
The benchmark table below is pulled from PR #396, which also contains the full benchmark results.
| Env. \ Alg. | A2C (GAE) | A2C (n-step) | PPO | DQN | DDQN+PER |
|---|---|---|---|---|---|
| Breakout | 389.99 | 391.32 | 425.89 | 65.04 | 181.72 |
| Pong | 20.04 | 19.66 | 20.09 | 18.34 | 20.44 |
| Qbert | 13,328.32 | 13,259.19 | 13,691.89 | 4,787.79 | 11,673.52 |
| Seaquest | 892.68 | 1,686.08 | 1,583.04 | 1,118.50 | 3,751.34 |
Now the full list of algorithms spans those benchmarked above: DQN, DDQN+PER, A2C (GAE), A2C (n-step), and PPO.
Evaluation is also improved, with `rigorous_eval` #390 and with using inference for fast eval #391 (a rough sketch of the latter idea is below).
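As a guess at the idea behind using inference for fast eval (the agent/env interface below is a hypothetical gym-style one, not SLM Lab's actual API): running evaluation episodes under `torch.no_grad()` skips autograd bookkeeping and speeds up rollouts.

```python
import torch

@torch.no_grad()  # no gradient tracking is needed during evaluation
def fast_eval(policy_net, env, num_episodes=4):
    """Hypothetical evaluation loop; returns the mean episode score."""
    total = 0.0
    for _ in range(num_episodes):
        state, done, score = env.reset(), False, 0.0
        while not done:
            logits = policy_net(torch.as_tensor(state, dtype=torch.float32))
            action = logits.argmax().item()  # act greedily for evaluation
            state, reward, done, _ = env.step(action)
            score += reward
        total += score
    return total / num_episodes
```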