From c5a6f0f9d649739f1cf9e6e2ed1cfe1203a7828e Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 28 Dec 2023 23:02:52 +0100 Subject: [PATCH] Fixed remaining notebooks and spelling --- docs/02_notebooks/L4_Policy.ipynb | 91 ++++++++++++++++++--------- docs/02_notebooks/L6_Trainer.ipynb | 2 +- docs/02_notebooks/L7_Experiment.ipynb | 2 +- docs/spelling_wordlist.txt | 7 +++ tianshou/policy/__init__.py | 3 +- tianshou/policy/modelfree/pg.py | 2 +- tianshou/utils/optim.py | 11 +++- 7 files changed, 81 insertions(+), 37 deletions(-) diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 4a815d07a..e31214fb4 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -54,16 +54,19 @@ }, "outputs": [], "source": [ - "from typing import Dict, List\n", + "from typing import cast\n", "\n", "import numpy as np\n", "import torch\n", "import gymnasium as gym\n", "\n", - "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n", - "from tianshou.policy import BasePolicy\n", + "from dataclasses import dataclass\n", + "\n", + "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as, SequenceSummaryStats\n", + "from tianshou.policy import BasePolicy, TrainingStats\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol" ] }, { @@ -102,12 +105,14 @@ "\n", "\n", "\n", - "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n", - "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n", + "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, \n", + "also a Torch optimizer.\n", + "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n", + "preprocess training data and computes quantities like episodic returns (gradient free), \n", + "then it will call `Policy.learn()` to perform the back-propagation.\n", + "3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` to store details of training.\n", "\n", "Then we get the implementation below.\n", - "\n", - "\n", "\n" ] }, @@ -119,7 +124,14 @@ }, "outputs": [], "source": [ - "class REINFORCEPolicy(BasePolicy):\n", + "@dataclass(kw_only=True)\n", + "class REINFORCETrainingStats(TrainingStats):\n", + " \"\"\"A dedicated class for REINFORCE training statistics.\"\"\"\n", + "\n", + " loss: SequenceSummaryStats\n", + "\n", + "\n", + "class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n", " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", "\n", " def __init__(\n", @@ -138,7 +150,7 @@ " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", " pass\n", "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n", " \"\"\"Perform the back-propagation.\"\"\"\n", " return" ] @@ -220,7 +232,9 @@ }, "source": [ "### Policy.learn()\n", - "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Final we can construct our loss function and perform the back-propagation." 
+ "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n", + "we can construct our loss function and perform the back-propagation. The method \n", + "should look something like this:" ] }, { @@ -231,22 +245,24 @@ }, "outputs": [], "source": [ - "def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + "from tianshou.utils.optim import optim_step\n", + "\n", + "\n", + "def learn(self, batch: Batch, batch_size: int, repeat: int):\n", " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", + " train_losses = []\n", " for _ in range(repeat):\n", " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", " result = self(minibatch)\n", " dist = result.dist\n", " act = to_torch_as(minibatch.act, result.act)\n", " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" + " optim_step(loss, self.optim)\n", + " train_losses.append(loss.item())\n", + "\n", + " return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))" ] }, { @@ -256,7 +272,12 @@ }, "source": [ "## Implementation\n", - "Finally we can assemble the implemented methods and form a REINFORCE Policy." + "Now we can assemble the methods and form a REINFORCE Policy. The outputs of\n", + "`learn` will be collected to a dedicated dataclass.\n", + "\n", + "We will also use protocols to specify what fields are expected and produced inside a `Batch` in\n", + "each processing step. By using protocols, we can get better type checking and IDE support \n", + "without having to implement a separate class for each combination of fields." 
] }, { @@ -290,30 +311,33 @@ " act = dist.sample()\n", " return Batch(act=act, dist=dist)\n", "\n", - " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", + " def process_fn(\n", + " self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray\n", + " ) -> BatchWithReturnsProtocol:\n", " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", " returns, _ = self.compute_episodic_return(\n", " batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n", " )\n", " batch.returns = returns\n", - " return batch\n", + " return cast(BatchWithReturnsProtocol, batch)\n", "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " def learn(\n", + " self, batch: BatchWithReturnsProtocol, batch_size: int, repeat: int\n", + " ) -> REINFORCETrainingStats:\n", " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", + " train_losses = []\n", " for _ in range(repeat):\n", " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", " result = self(minibatch)\n", " dist = result.dist\n", " act = to_torch_as(minibatch.act, result.act)\n", " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" + " optim_step(loss, self.optim)\n", + " train_losses.append(loss.item())\n", + "\n", + " return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))" ] }, { @@ -370,8 +394,8 @@ "source": [ "print(policy)\n", "print(\"========================================\")\n", - "for para in policy.parameters():\n", - " print(para.shape)" + "for param in policy.parameters():\n", + " print(param.shape)" ] }, { @@ -831,6 +855,13 @@ "\n", "" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index db6b0fb86..da4010b70 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -146,7 +146,7 @@ "replaybuffer.reset()\n", "for i in range(10):\n", " evaluation_result = test_collector.collect(n_episode=10)\n", - " print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n", + " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", " train_collector.collect(n_step=2000)\n", " # 0 means taking all data stored in train_collector.buffer\n", " policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 55a3be144..ad450374e 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -303,7 +303,7 @@ "# Let's watch its performance!\n", "policy.eval()\n", "result = test_collector.collect(n_episode=1, render=False)\n", - "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" + "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" ] } ], diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 9066e8694..1849f2efc 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -239,3 +239,10 @@ logp autogenerated 
subpackage subpackages +recurse +rollout +rollouts +prepend +prepends +dict +dicts diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index c8fa45e8e..5e6967ad7 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,7 +1,7 @@ """Policy package.""" # isort:skip_file -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import BasePolicy, TrainingStats from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy @@ -63,4 +63,5 @@ "PSRLPolicy", "ICMPolicy", "MultiAgentPolicyManager", + "TrainingStats", ] diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index e16895484..7db588be6 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -207,7 +207,7 @@ def forward( # TODO: why does mypy complain? def learn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, *args: Any, diff --git a/tianshou/utils/optim.py b/tianshou/utils/optim.py index 0c1093cc9..c69ef71db 100644 --- a/tianshou/utils/optim.py +++ b/tianshou/utils/optim.py @@ -8,19 +8,24 @@ def optim_step( loss: torch.Tensor, optim: torch.optim.Optimizer, - module: nn.Module, + module: nn.Module | None = None, max_grad_norm: float | None = None, ) -> None: - """Perform a single optimization step. + """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step. :param loss: :param optim: - :param module: + :param module: the module to optimize, required if max_grad_norm is passed :param max_grad_norm: if passed, will clip gradients using this """ optim.zero_grad() loss.backward() if max_grad_norm: + if not module: + raise ValueError( + "module must be passed if max_grad_norm is passed. " + "Note: often the module will be the policy, i.e.`self`", + ) nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) optim.step()
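
A quick illustration of the `optim_step` change above (a sketch, not part of the patch): `module` is now optional and only needs to be supplied when `max_grad_norm` is given, in which case the module's parameters are clipped before the optimizer step. The model and optimizer below are made up purely for the example.

```python
import torch
from torch import nn

from tianshou.utils.optim import optim_step

# Toy module/optimizer, used only to exercise the helper.
model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Plain step: zero_grad -> backward -> step, no clipping, no module needed.
loss = model(torch.randn(8, 4)).pow(2).mean()
optim_step(loss, optim)

# With gradient clipping, the module that owns the parameters must be passed;
# otherwise the helper raises a ValueError.
loss = model(torch.randn(8, 4)).pow(2).mean()
optim_step(loss, optim, module=model, max_grad_norm=0.5)
```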
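Similarly, a small sketch of the `TrainingStats`-based return value that the notebook's `learn()` now produces instead of a dict of lists. The dataclass below mirrors the one defined in the L4_Policy notebook cell, the loss values are invented, and it assumes `SequenceSummaryStats` exposes aggregate fields such as `mean` and `std`.

```python
from dataclasses import dataclass

from tianshou.data import SequenceSummaryStats
from tianshou.policy import TrainingStats


@dataclass(kw_only=True)
class REINFORCETrainingStats(TrainingStats):
    """Mirrors the dataclass defined in the L4_Policy notebook."""

    loss: SequenceSummaryStats


# learn() now returns a structured stats object rather than {"loss": [...]}.
stats = REINFORCETrainingStats(
    loss=SequenceSummaryStats.from_sequence([0.52, 0.47, 0.41]),
)
print(stats.loss.mean, stats.loss.std)  # summary of the per-minibatch losses
```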