Skip to content

Commit

Permalink
Divide into reinforcement learning and transformers to separate repos…
Browse files Browse the repository at this point in the history
…itories for efficiency and less bugs when size increases.
  • Loading branch information
hallvardnmbu committed Mar 12, 2024
1 parent 0ded3f3 commit 4cae69d
Show file tree
Hide file tree
Showing 36 changed files with 38 additions and 53 deletions.
39 changes: 11 additions & 28 deletions README.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
Modern applied deep learning with reinforcement and Transformer model methodology
Modern applied deep learning with reinforcement methodology

---

Special syllabus at NMBU (Norwegian University of Life Sciences)
Spring 2024

* Hallvard H. Lavik
* Leo Q. T. Bækholt
Special syllabus Spring 2024
Norwegian University of Life Sciences (NMBU)

---

Expand All @@ -15,33 +10,21 @@ used.

---

Syllabus:

Reinforcement Learning:
- "Human-level control through deep reinforcement learning" (doi:10.1038/nature14236)
- "Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm" (arXiv:1712.01815v1)

Transformer:
- "Geometry of deep learning" (ISBN 978-981-16-6046-7)
- Chapter 9.3 ("Attention")
- Chapter 9.4.5 ("Transformer")
- Chapter 9.4.7 ("Generative Pre-trained Transformer (GPT)")
- "Attention Is All You Need" (arXiv:1706.03762v7)
- "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (arXiv:1810.04805v2)
- "An image is worth 16x16 words: Transformers for image recognition at scale" (arXiv:2010.11929v2)
Reinforcement learning:
- "Human-level control through deep reinforcement learning"
doi:10.1038/nature14236
- "Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm"
arXiv:1712.01815v1

---

Learning goals:

- Understand and know how to build, use and deploy reinforcement learning algorithms
* Experiment with reinforcement agent(s) (for instance playing chess)
- Understand and know how to build, use and deploy Transformer architectures
* Experiment with architectures and applications (for instance, a language translator)

---
* Experiment with reinforcement agent(s) (for instance playing video-games)

Learning outcomes:

- Be competent in modern deep learning situations
* Understand (and to some extent be able to reproduce) cutting-edge “artificial intelligence” models
* Understand (and to some extent be able to reproduce) cutting-edge “artificial intelligence”
models
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"cell_type": "code",
"outputs": [],
"source": [
"WEIGHTS = './weights-10000' # NB: without '.pth'\n",
"METRICS = './metrics.csv'"
"WEIGHTS = './_output/weights-15000.pth'\n",
"METRICS = './_output/metrics.csv'"
],
"metadata": {
"collapsed": false
Expand All @@ -37,8 +37,8 @@
"from DQN import VisionDeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import visualise_csv_grouped_rewards # noqa\n",
"from utilities.visualisation.gif import gif_stacked # noqa"
"from utilities.visualisation.plot import graph # noqa\n",
"from utilities.visualisation.movie import movie # noqa"
],
"metadata": {
"collapsed": false
Expand Down Expand Up @@ -66,15 +66,17 @@
" \"kernels\": [8, 4, 3],\n",
" \"padding\": [\"valid\", \"valid\", \"valid\"],\n",
" \"strides\": [4, 2, 1],\n",
" \"nodes\": [128],\n",
" \"nodes\": [],\n",
"}\n",
"optimizer = {\n",
" \"optimizer\": torch.optim.Adam,\n",
" \"lr\": 0.0000625,\n",
" \"hyperparameters\": {\"eps\": 1.5e-4}\n",
" \"lr\": 1e-5,\n",
" \"hyperparameters\": {}\n",
"}\n",
"shape = {\n",
" \"original\": (1, 1, 210, 160),\n",
" \"width\": slice(7, -7),\n",
" \"height\": slice(31, -17),\n",
" \"max_pooling\": 2,\n",
"}\n",
"skip = 4"
Expand All @@ -101,10 +103,10 @@
"source": [
"value_agent = VisionDeepQ(\n",
" network=network, optimizer=optimizer, shape=shape,\n",
" exploration_rate=1.0,\n",
" exploration_rate=0.002,\n",
")\n",
"\n",
"weights = torch.load(f'{WEIGHTS}.pth', map_location=torch.device('cpu'))\n",
"weights = torch.load(WEIGHTS, map_location=torch.device('cpu'))\n",
"value_agent.load_state_dict(weights)\n",
"\n",
"environment = gym.make('ALE/Breakout-v5', render_mode=\"rgb_array\",\n",
Expand Down Expand Up @@ -141,7 +143,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"visualise_csv_grouped_rewards(METRICS, title=\"Training history\", window=20) if METRICS else None\n",
"graph(METRICS, title=\"Training history\", window=20) if METRICS else None\n",
"plt.show() if METRICS else None"
],
"metadata": {
Expand All @@ -164,7 +166,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"gif_stacked(environment, value_agent, f'./{WEIGHTS}.gif', skip)"
"movie(environment, value_agent, './_output/breakout.avi', fps=60)"
],
"metadata": {
"collapsed": false
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"cell_type": "code",
"outputs": [],
"source": [
"WEIGHTS = './weights-10000' # NB: without '.pth'\n",
"METRICS = './metrics.csv'"
"WEIGHTS = './_output/weights-0.pth'\n",
"METRICS = None #'./_output/metrics.csv'"
],
"metadata": {
"collapsed": false
Expand All @@ -37,8 +37,8 @@
"from DQN import VisionDeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import visualise_csv # noqa\n",
"from utilities.visualisation.gif import gif_stacked # noqa"
"from utilities.visualisation.plot import visualise_csv # noqa\n",
"from utilities.visualisation.gif import gif # noqa"
],
"metadata": {
"collapsed": false
Expand All @@ -61,16 +61,16 @@
"outputs": [],
"source": [
"network = {\n",
" \"input_channels\": 10, \"outputs\": 8,\n",
" \"input_channels\": 2, \"outputs\": 9,\n",
" \"channels\": [32, 64, 64],\n",
" \"kernels\": [8, 4, 3],\n",
" \"padding\": [\"valid\", \"valid\", \"valid\"],\n",
" \"strides\": [4, 2, 1],\n",
" \"nodes\": [512],\n",
"}\n",
"optimizer = {\n",
" \"optimizer\": torch.optim.Adam,\n",
" \"lr\": 0.00025,\n",
" \"optimizer\": torch.optim.RMSprop,\n",
" \"lr\": 0.0001,\n",
" \"hyperparameters\": {}\n",
"}\n",
"shape = {\n",
Expand Down Expand Up @@ -105,7 +105,7 @@
" exploration_rate=0.01,\n",
")\n",
"\n",
"weights = torch.load(f'{WEIGHTS}.pth', map_location=torch.device('cpu'))\n",
"weights = torch.load(WEIGHTS, map_location=torch.device('cpu'))\n",
"value_agent.load_state_dict(weights)\n",
"\n",
"environment = gym.make('ALE/Enduro-v5', render_mode=\"rgb_array\",\n",
Expand Down Expand Up @@ -165,7 +165,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"gif_stacked(environment, value_agent, f'./{WEIGHTS}.gif', skip)"
"gif(environment, value_agent, './_output/enduro-0.gif', skip, 25)"
],
"metadata": {
"collapsed": false
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file removed reinforcement-learning/tetris/_output/tetris.gif
Binary file not shown.
File renamed without changes.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Might need: brew install swig libjpeg libpng

gymnasium[all]
gymnasium[atari]
autorom[accept-rom-license]

torch
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
"from DQN import DeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import visualise_csv_grouped_rewards # noqa\n",
"from utilities.visualisation.gif import gif_stacked # noqa"
"from utilities.visualisation.plot import graph # noqa\n",
"from utilities.visualisation.gif import gif # noqa"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -160,7 +160,7 @@
}
],
"source": [
"visualise_csv_grouped_rewards(METRICS, title=\"Tetris (RAM)\", window=50) if METRICS else None\n",
"graph(METRICS, title=\"Tetris (RAM)\", window=50) if METRICS else None\n",
"plt.savefig('./_output/metrics.png') if METRICS else None\n",
"plt.show() if METRICS else None"
],
Expand Down Expand Up @@ -188,7 +188,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"gif_stacked(environment, value_agent, './_output/tetris.gif', skip)"
"gif(environment, value_agent, './_output/tetris.gif', skip)"
],
"metadata": {
"collapsed": false,
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file removed transformer/empty
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 4cae69d

Please sign in to comment.