Update gymnasium example with actual RL training #156

Merged 2 commits on Sep 19, 2023
@@ -12,7 +12,8 @@
"\n",
"**This notebook shows you how to log your Gymnasium metrics with Comet.** For more information about Comet's integration with Gymnasium, visit our [Docs](https://www.comet.com/docs/v2/integrations/ml-frameworks/gymnasium/?utm_source=gymnasium&utm_medium=partner&utm_campaign=partner_gymnasium_2023&utm_content=comet_colab) page.\n",
"\n",
"If you prefer to preview what's to come, check out a completed experiment created from this notebook [here](https://www.comet.com/examples/comet-examples-gymnasium-notebook/58a1e400d18342fdabb4ddbbb07c9802?utm_source=gymnasium&utm_medium=partner&utm_campaign=partner_gymnasium_2023&utm_content=comet_colab)."
"If you prefer to preview what's to come, check out completed experiments created from this notebook [here](https://www.comet.com/examples/comet-examples-gymnasium-notebook/?utm_source=gymnasium&utm_medium=partner&utm_campaign=partner_gymnasium_2023&utm_content=comet_colab).\n",
"\n"
]
},
{
@@ -32,7 +33,7 @@
},
"outputs": [],
"source": [
"%pip install gymnasium[classic_control] comet_ml"
"%pip install 'gymnasium[classic-control]' comet_ml stable-baselines3"
]
},
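Note: the Experiment objects created later in this notebook assume Comet credentials are already configured. A minimal sketch of that setup (the project name below is illustrative, not taken from this PR):

    import comet_ml

    # Prompts for an API key when run in a notebook; alternatively set the
    # COMET_API_KEY environment variable. The project name is only an example.
    comet_ml.init(project_name="comet-examples-gymnasium-notebook")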
{
@@ -65,7 +66,7 @@
"id": "031ezY2Dr2n4"
},
"source": [
"# Import Gymnasium and Initialize Your Enviornment"
"# Train an Agent using StableBaselines3 A2C Algorithm"
]
},
{
@@ -76,100 +77,72 @@
},
"outputs": [],
"source": [
"from comet_ml.integration.gymnasium import CometLogger\n",
"from stable_baselines3 import A2C\n",
"import gymnasium as gym\n",
"\n",
"env = gym.make(\"Acrobot-v1\", render_mode=\"rgb_array\")\n",
"\n",
"# Uncomment if you want to Upload Videos of your enviornment with Comet\n",
"# env = gym.wrappers.RecordVideo(env, 'test')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "g4c6nL7ysczO"
},
"source": [
"# Initialize your Comet Experiment and Wrap your Environment with the Comet Logger"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bxUWMLHJtCxw"
},
"outputs": [],
"source": [
"from comet_ml.integration.gymnasium import CometLogger\n",
"# Uncomment if you want to Upload Videos of your environment to Comet\n",
"# env = gym.wrappers.RecordVideo(env, 'test')\n",
"\n",
"experiment = comet_ml.Experiment()\n",
"\n",
"env = CometLogger(env, experiment)"
"env = CometLogger(env, experiment)\n",
"\n",
"model = A2C(\"MlpPolicy\", env, verbose=0)\n",
"model.learn(total_timesteps=10000)\n",
"\n",
"env.close()\n",
"experiment.end()"
]
},
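The cell above closes the environment immediately after training. If you also want the CometLogger-wrapped environment to record evaluation episodes, a rollout like the following could be added before env.close() (a sketch under the same Acrobot-v1 setup, not part of this PR):

    # Roll out the trained A2C policy for a few episodes so the wrapped
    # environment logs episode reward and episode length to Comet.
    for _ in range(5):
        obs, info = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(int(action))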
{
"cell_type": "markdown",
"metadata": {
"id": "RkHkaVn5t8O5"
},
"metadata": {},
"source": [
"# Step Through The Environment Randomly For 20 Episodes \n"
"# Train an Agent using StableBaselines3 PPO Algorithm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Go-xDU-7uLl0"
},
"metadata": {},
"outputs": [],
"source": [
"for x in range(20):\n",
"\n",
" observation, info = env.reset()\n",
" truncated = False\n",
" terminated = False\n",
" while not (truncated or terminated):\n",
" observation, reward, terminated, truncated, info = env.step(\n",
" env.action_space.sample()\n",
" )\n",
" env.render()\n",
"\n",
"env.close() # Will Upload videos to Comet if RecordVideo was used"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EzpGq4xJuWcg"
},
"source": [
"# View Metrics like Cumulative Episode Reward and Episode Length in Comet\n",
"from stable_baselines3 import PPO\n",
"\n",
"\n",
"env = gym.make(\"Acrobot-v1\", render_mode=\"rgb_array\")\n",
"\n",
"# Uncomment if you want to Upload Videos of your environment to Comet\n",
"# env = gym.wrappers.RecordVideo(env, 'test')\n",
"\n",
"experiment = comet_ml.Experiment()\n",
"\n",
"env = CometLogger(env, experiment)\n",
"\n",
"After running an experiment, run this cell to view the Comet UI in this notebook. "
"model = PPO(\"MlpPolicy\", env, verbose=0)\n",
"model.learn(total_timesteps=10000)\n",
"\n",
"env.close()\n",
"experiment.end()"
]
},
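As with the A2C run, the PPO experiment could be tagged before experiment.end() so the two runs are easy to tell apart when compared in the Comet UI; a sketch using Comet's log_parameters (values mirror the cell above):

    # Record which algorithm and training budget produced this run.
    experiment.log_parameters({"algorithm": "PPO", "total_timesteps": 10000})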
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wEowdeOxuqnH"
},
"outputs": [],
"cell_type": "markdown",
"metadata": {},
"source": [
"experiment.display(tab=\"charts\")"
"# Use Comet's UI to Benchmark Different RL Algorithims "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yy6oDoYiKuSQ"
},
"metadata": {},
"outputs": [],
"source": [
"experiment.end()"
"experiment.display()"
]
}
],
@@ -192,9 +165,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.12"
},
"vscode": {
"interpreter": {
"hash": "8c9587381b2341d562742e36a89690be32a732b11830813473890249dd40a07d"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
"nbformat_minor": 4
}