Fixed remaining notebooks and spelling
Michael Panchenko committed Dec 28, 2023
1 parent 4e86934 commit c5a6f0f
Showing 7 changed files with 81 additions and 37 deletions.
91 changes: 61 additions & 30 deletions docs/02_notebooks/L4_Policy.ipynb
@@ -54,16 +54,19 @@
},
"outputs": [],
"source": [
"from typing import Dict, List\n",
"from typing import cast\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import gymnasium as gym\n",
"\n",
"from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n",
"from tianshou.policy import BasePolicy\n",
"from dataclasses import dataclass\n",
"\n",
"from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as, SequenceSummaryStats\n",
"from tianshou.policy import BasePolicy, TrainingStats\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol"
]
},
{
@@ -102,12 +105,14 @@
"\n",
"\n",
"\n",
"1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n",
"2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n",
"1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, \n",
"also a Torch optimizer.\n",
"2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n",
"preprocess training data and computes quantities like episodic returns (gradient free), \n",
"then it will call `Policy.learn()` to perform the back-propagation.\n",
"3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` to store details of training.\n",
"\n",
"Then we get the implementation below.\n",
"\n",
"\n",
"\n"
]
},
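Before the skeleton that follows, point 2 can be pictured in heavily simplified form as the sketch below. This is an illustration only, not Tianshou's actual `update()` code; extra arguments and bookkeeping are omitted.

def update_sketch(policy: BasePolicy, buffer: ReplayBuffer, sample_size: int):
    # Simplified sketch of the process_fn -> learn flow described in point 2.
    batch, indices = buffer.sample(sample_size)        # draw training data from the buffer
    batch = policy.process_fn(batch, buffer, indices)  # gradient-free preprocessing, e.g. episodic returns
    return policy.learn(batch)                         # back-propagation (learn kwargs omitted here)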
Expand All @@ -119,7 +124,14 @@
},
"outputs": [],
"source": [
"class REINFORCEPolicy(BasePolicy):\n",
"@dataclass(kw_only=True)\n",
"class REINFORCETrainingStats(TrainingStats):\n",
" \"\"\"A dedicated class for REINFORCE training statistics.\"\"\"\n",
"\n",
" loss: SequenceSummaryStats\n",
"\n",
"\n",
"class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n",
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
"\n",
" def __init__(\n",
@@ -138,7 +150,7 @@
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
" pass\n",
"\n",
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n",
" \"\"\"Perform the back-propagation.\"\"\"\n",
" return"
]
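The skeleton above stores its loss history as a `SequenceSummaryStats` (imported in the first cell). Assuming the tianshou 1.x stats API, where the summary exposes aggregate fields such as `mean` (an assumption, not shown in this diff), it can be used roughly like this:

losses = [0.52, 0.47, 0.45]                                # e.g. per-minibatch losses collected in learn()
loss_summary = SequenceSummaryStats.from_sequence(losses)  # condense the sequence into summary statistics
stats = REINFORCETrainingStats(loss=loss_summary)          # assumes the TrainingStats base needs no extra fields
print(loss_summary.mean)                                   # assumed field: average loss over the update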
@@ -220,7 +232,9 @@
},
"source": [
"### Policy.learn()\n",
"Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Final we can construct our loss function and perform the back-propagation."
"Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n",
"we can construct our loss function and perform the back-propagation. The method \n",
"should look something like this:"
]
},
{
@@ -231,22 +245,24 @@
},
"outputs": [],
"source": [
"def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
"from tianshou.utils.optim import optim_step\n",
"\n",
"\n",
"def learn(self, batch: Batch, batch_size: int, repeat: int):\n",
" \"\"\"Perform the back-propagation.\"\"\"\n",
" logging_losses = []\n",
" train_losses = []\n",
" for _ in range(repeat):\n",
" for minibatch in batch.split(batch_size, merge_last=True):\n",
" self.optim.zero_grad()\n",
" result = self(minibatch)\n",
" dist = result.dist\n",
" act = to_torch_as(minibatch.act, result.act)\n",
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
" loss = -(log_prob * ret).mean()\n",
" loss.backward()\n",
" self.optim.step()\n",
" logging_losses.append(loss.item())\n",
" return {\"loss\": logging_losses}"
" optim_step(loss, self.optim)\n",
" train_losses.append(loss.item())\n",
"\n",
" return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))"
]
},
{
@@ -256,7 +272,12 @@
},
"source": [
"## Implementation\n",
"Finally we can assemble the implemented methods and form a REINFORCE Policy."
"Now we can assemble the methods and form a REINFORCE Policy. The outputs of\n",
"`learn` will be collected to a dedicated dataclass.\n",
"\n",
"We will also use protocols to specify what fields are expected and produced inside a `Batch` in\n",
"each processing step. By using protocols, we can get better type checking and IDE support \n",
"without having to implement a separate class for each combination of fields."
]
},
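The idea behind these protocols can be illustrated with a rough, hypothetical sketch; the real definitions live in `tianshou.data.types` and differ in detail:

from typing import Protocol

import numpy as np
import torch


class SketchBatchWithReturns(Protocol):
    """Hypothetical stand-in for BatchWithReturnsProtocol."""

    # fields a rollout batch is expected to carry after process_fn()
    obs: np.ndarray | torch.Tensor
    act: np.ndarray | torch.Tensor
    rew: np.ndarray
    returns: np.ndarray | torch.Tensor

In the implementation below, `cast(BatchWithReturnsProtocol, batch)` tells the type checker that the batch now carries the `returns` field added by `process_fn()`.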
{
@@ -290,30 +311,33 @@
" act = dist.sample()\n",
" return Batch(act=act, dist=dist)\n",
"\n",
" def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
" def process_fn(\n",
" self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray\n",
" ) -> BatchWithReturnsProtocol:\n",
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
" returns, _ = self.compute_episodic_return(\n",
" batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n",
" )\n",
" batch.returns = returns\n",
" return batch\n",
" return cast(BatchWithReturnsProtocol, batch)\n",
"\n",
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
" def learn(\n",
" self, batch: BatchWithReturnsProtocol, batch_size: int, repeat: int\n",
" ) -> REINFORCETrainingStats:\n",
" \"\"\"Perform the back-propagation.\"\"\"\n",
" logging_losses = []\n",
" train_losses = []\n",
" for _ in range(repeat):\n",
" for minibatch in batch.split(batch_size, merge_last=True):\n",
" self.optim.zero_grad()\n",
" result = self(minibatch)\n",
" dist = result.dist\n",
" act = to_torch_as(minibatch.act, result.act)\n",
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
" loss = -(log_prob * ret).mean()\n",
" loss.backward()\n",
" self.optim.step()\n",
" logging_losses.append(loss.item())\n",
" return {\"loss\": logging_losses}"
" optim_step(loss, self.optim)\n",
" train_losses.append(loss.item())\n",
"\n",
" return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))"
]
},
{
@@ -370,8 +394,8 @@
"source": [
"print(policy)\n",
"print(\"========================================\")\n",
"for para in policy.parameters():\n",
" print(para.shape)"
"for param in policy.parameters():\n",
" print(param.shape)"
]
},
{
@@ -831,6 +855,13 @@
"<img src=../_static/images/policy_table.svg></img>\n",
"</center>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
2 changes: 1 addition & 1 deletion docs/02_notebooks/L6_Trainer.ipynb
@@ -146,7 +146,7 @@
"replaybuffer.reset()\n",
"for i in range(10):\n",
" evaluation_result = test_collector.collect(n_episode=10)\n",
" print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n",
" print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n",
2 changes: 1 addition & 1 deletion docs/02_notebooks/L7_Experiment.ipynb
@@ -303,7 +303,7 @@
"# Let's watch its performance!\n",
"policy.eval()\n",
"result = test_collector.collect(n_episode=1, render=False)\n",
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
"print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")"
]
}
],
7 changes: 7 additions & 0 deletions docs/spelling_wordlist.txt
@@ -239,3 +239,10 @@ logp
autogenerated
subpackage
subpackages
recurse
rollout
rollouts
prepend
prepends
dict
dicts
3 changes: 2 additions & 1 deletion tianshou/policy/__init__.py
@@ -1,7 +1,7 @@
"""Policy package."""
# isort:skip_file

from tianshou.policy.base import BasePolicy
from tianshou.policy.base import BasePolicy, TrainingStats
from tianshou.policy.random import RandomPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.bdq import BranchingDQNPolicy
@@ -63,4 +63,5 @@
"PSRLPolicy",
"ICMPolicy",
"MultiAgentPolicyManager",
"TrainingStats",
]
2 changes: 1 addition & 1 deletion tianshou/policy/modelfree/pg.py
@@ -207,7 +207,7 @@ def forward(
# TODO: why does mypy complain?
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch: BatchWithReturnsProtocol,
batch_size: int | None,
repeat: int,
*args: Any,
11 changes: 8 additions & 3 deletions tianshou/utils/optim.py
@@ -8,19 +8,24 @@
def optim_step(
loss: torch.Tensor,
optim: torch.optim.Optimizer,
module: nn.Module,
module: nn.Module | None = None,
max_grad_norm: float | None = None,
) -> None:
"""Perform a single optimization step.
"""Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step.
:param loss:
:param optim:
:param module:
:param module: the module to optimize, required if max_grad_norm is passed
:param max_grad_norm: if passed, will clip gradients using this
"""
optim.zero_grad()
loss.backward()
if max_grad_norm:
if not module:
raise ValueError(
"module must be passed if max_grad_norm is passed. "
"Note: often the module will be the policy, i.e.`self`",
)
nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm)
optim.step()
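A typical call from inside a policy's `learn()`, based on the signature above, looks roughly like this (a usage sketch, with `self` standing for the policy):

# after computing `loss` for a minibatch
optim_step(loss, self.optim)                                  # zero_grad -> backward -> step
# or, with gradient clipping (the module is then typically the policy itself):
optim_step(loss, self.optim, module=self, max_grad_norm=0.5)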

Expand Down

0 comments on commit c5a6f0f

Please sign in to comment.