Skip to content

Commit

Permalink
[Tutorial] PettingZoo Parallel competitive tutorial (#2047)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vmoens@meta.com>
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
3 people authored Apr 24, 2024
1 parent d1b0e2b commit 934100f
Show file tree
Hide file tree
Showing 6 changed files with 960 additions and 17 deletions.
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ imageio[ffmpeg,pyav]
memory_profiler
pyrender
pytest
vmas==1.2.11
vmas
pettingzoo[mpe]==1.24.3
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Advanced
.. toctree::
:maxdepth: 1

tutorials/multiagent_competitive_ddpg
tutorials/multi_task
tutorials/coding_ddpg
tutorials/coding_dqn
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/libs/vmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _get_default_group_map(self, agent_names: List[str]):
group_map = MarlGroupMapType.ALL_IN_ONE_GROUP.get_group_map(agent_names)

# For BC-compatibility rename the "agent" group to "agents"
if "agent" in group_map:
if "agent" in group_map and len(group_map) == 1:
agent_group = group_map["agent"]
group_map["agents"] = agent_group
del group_map["agent"]
Expand Down
5 changes: 5 additions & 0 deletions torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import tensordict.utils
import torch

from tensordict import MemoryMappedTensor
Expand Down Expand Up @@ -196,3 +197,7 @@ def __repr__(self) -> str:

def log_histogram(self, name: str, data: Sequence, **kwargs):
raise NotImplementedError("Logging histograms in cvs is not permitted.")

def print_log_dir(self):
"""Prints the log directory content."""
tensordict.utils.print_directory_tree(self.log_dir)
Loading

0 comments on commit 934100f

Please sign in to comment.