Docs/fix trainer fct notebooks #1009

Merged
18 changes: 9 additions & 9 deletions docs/01_tutorials/00_dqn.rst
@@ -41,11 +41,11 @@ First of all, you have to make an environment for your agent to interact with.
import gymnasium as gym
import tianshou as ts

-env = gym.make('CartPole-v0')
+env = gym.make('CartPole-v1')

-CartPole-v0 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, to which DQN applies. You have to identify whether the action space is continuous or discrete and apply an eligible algorithm. DDPG :cite:`DDPG`, for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both.
+CartPole-v1 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, to which DQN applies. You have to identify whether the action space is continuous or discrete and apply an eligible algorithm. DDPG :cite:`DDPG`, for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both.

-Here are the details of the useful fields of CartPole-v0:
+Here are the details of the useful fields of CartPole-v1:

- ``state``: the position of the cart, the velocity of the cart, the angle of the pole, and the angular velocity of the pole;
- ``action``: can only be one of ``[0, 1]``, for pushing the cart left or right;
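
To act on the discrete-versus-continuous distinction above, you can inspect ``env.action_space`` directly. A minimal sketch using standard Gymnasium space checks (not from the PR itself):
::

import gymnasium as gym

env = gym.make('CartPole-v1')
# Discrete action space -> DQN and most policy gradient methods apply.
if isinstance(env.action_space, gym.spaces.Discrete):
    print("discrete actions:", env.action_space.n)  # CartPole-v1 -> 2
# Continuous (Box) action space -> DDPG-style algorithms apply.
elif isinstance(env.action_space, gym.spaces.Box):
    print("continuous actions, shape:", env.action_space.shape)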
@@ -62,8 +62,8 @@ Setup Vectorized Environment
If you want to use the original ``gym.Env``:
::

-train_envs = gym.make('CartPole-v0')
-test_envs = gym.make('CartPole-v0')
+train_envs = gym.make('CartPole-v1')
+test_envs = gym.make('CartPole-v1')

Tianshou supports vectorized environments for all algorithms. It provides four types of vectorized environment wrappers:

@@ -74,8 +74,8 @@

::

-train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
-test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
+train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)])
+test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(100)])

Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``.
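
If environment stepping dominates run time, the same factory list also works with Tianshou's process-based wrapper. A sketch assuming the standard ``ts.env`` API (``SubprocVectorEnv`` is one of the four wrapper types mentioned above):
::

# Each environment runs in its own worker process; the interface is
# the same as DummyVectorEnv, only the execution model differs.
train_envs = ts.env.SubprocVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)])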

@@ -84,8 +84,8 @@ You can also try the super-fast vectorized environment `EnvPool <https://github.com/sail-sg/envpool/>`_
::

import envpool
-train_envs = envpool.make_gymnasium("CartPole-v0", num_envs=10)
-test_envs = envpool.make_gymnasium("CartPole-v0", num_envs=100)
+train_envs = envpool.make_gymnasium("CartPole-v1", num_envs=10)
+test_envs = envpool.make_gymnasium("CartPole-v1", num_envs=100)

For this demonstration, we use the second code block.

4 changes: 2 additions & 2 deletions docs/01_tutorials/01_concepts.rst
@@ -353,7 +353,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collector
::

policy = PGPolicy(...) # or other policies if you wish
env = gym.make("CartPole-v0")
env = gym.make("CartPole-v1")

replay_buffer = ReplayBuffer(size=10000)

@@ -363,7 +363,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collector
# the collector supports vectorized environments as well
vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3)
# buffer_num should be equal to (suggested) or larger than the number of envs
-envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)])
+envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)])
collector = Collector(policy, envs, buffer=vec_buffer)

# collect 3 episodes
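
The hunk ends right after the comment; for reference, the collection call itself would look roughly like this, assuming Tianshou's standard ``Collector.collect`` signature:
::

# Roll out 3 complete episodes with the current policy; the
# transitions are stored in vec_buffer.
result = collector.collect(n_episode=3)
# Alternatively, collect a fixed number of environment steps:
# result = collector.collect(n_step=1000)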
2 changes: 1 addition & 1 deletion docs/01_tutorials/07_cheatsheet.rst
@@ -159,7 +159,7 @@ toy_text and classic_control environments. For more information, please refer to
# install envpool: pip3 install envpool

import envpool
-envs = envpool.make_gymnasium("CartPole-v0", num_envs=10)
+envs = envpool.make_gymnasium("CartPole-v1", num_envs=10)
collector = Collector(policy, envs, buffer)

Here are some other `examples <https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples>`_.
17 changes: 15 additions & 2 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -180,7 +180,10 @@
"base_uri": "https://localhost:8080/"
},
"id": "vcvw9J8RNtFE",
"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5"
"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5",
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
@@ -200,7 +203,17 @@
" episode_per_test=10,\n",
" step_per_collect=2000,\n",
" batch_size=512,\n",
")\n",
").run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"print(result)"
]
},
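
The net effect of this hunk: the trainer expression now ends in ``.run()``, and ``print(result)`` moves into a separate cell. Reassembled as plain Python, the pattern is roughly the following sketch, where the trainer class name and every parameter not visible in the diff are assumptions:
::

result = OffpolicyTrainer(  # assumed class, constructed earlier in the notebook
    policy=policy,  # assumed to be defined in earlier cells
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,  # placeholder value
    step_per_epoch=10000,  # placeholder value
    episode_per_test=10,
    step_per_collect=2000,
    batch_size=512,
).run()  # build and run training in one expression

# Inspecting the outcome now happens in its own cell:
print(result)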
14 changes: 8 additions & 6 deletions docs/02_notebooks/L7_Experiment.ipynb
@@ -59,9 +59,6 @@
"metadata": {
"editable": true,
"id": "ao9gWJDiHgG-",
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-cell",
"remove-output"
@@ -233,8 +230,12 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
"editable": true,
"id": "i45EDnpxQ8gu",
"outputId": "b1666b88-0bfa-4340-868e-58611872d988"
"outputId": "b1666b88-0bfa-4340-868e-58611872d988",
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
@@ -249,7 +250,7 @@
" batch_size=256,\n",
" step_per_collect=2000,\n",
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
")"
").run()"
]
},
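
For reference, ``stop_fn`` receives the mean test reward and ends training early once it returns ``True``. The lambda above, written as a named function (the 475 threshold for CartPole-v1 is an assumption drawn from Gymnasium's registry, not from this PR):
::

def stop_fn(mean_reward: float) -> bool:
    # 195 is CartPole-v0's classic solve threshold; CartPole-v1
    # is registered with reward_threshold=475.
    return mean_reward >= 195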
{
@@ -270,7 +271,8 @@
"base_uri": "https://localhost:8080/"
},
"id": "tJCPgmiyiaaX",
"outputId": "40123ae3-3365-4782-9563-46c43812f10f"
"outputId": "40123ae3-3365-4782-9563-46c43812f10f",
"tags": []
},
"outputs": [],
"source": [