-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from strakam/registry-creation
Registry creation
- Loading branch information
Showing
21 changed files
with
277 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,50 @@ | ||
from generals import pz_generals | ||
from generals.agents import RandomAgent, ExpanderAgent | ||
from generals import GridFactory | ||
|
||
# Initialize agents - their names are then called for actions | ||
randomer = RandomAgent("Random1", color=(255, 125, 0)) | ||
expander = ExpanderAgent("BigBoy") | ||
|
||
agents = { | ||
randomer.name: randomer, | ||
expander.name: expander, | ||
} | ||
import gymnasium as gym | ||
from generals import AgentFactory, GridFactory | ||
|
||
# Initialize agents -- see generals/agents/agent_factory.py for more options | ||
agent = AgentFactory.make_agent("expander") | ||
npc = AgentFactory.make_agent("random") | ||
|
||
# Initialize grid factory | ||
grid_factory = GridFactory( | ||
grid_dims=(10, 10), # Grid height and width | ||
mountain_density=0.2, # Expected percentage of mountains | ||
city_density=0.05, # Expected percentage of cities | ||
general_positions=[(1, 2), (7, 8)], # Positions of the generals | ||
seed=38 # Seed to generate the same map every time | ||
) | ||
|
||
gf = GridFactory( | ||
grid_dims=(4, 8), # height x width | ||
mountain_density=0.2, | ||
city_density=0.05, | ||
general_positions=[(0, 0), (3, 3)], | ||
env = gym.make( | ||
"gym-generals-v0", # Environment name | ||
grid_factory=grid_factory, # Grid factory | ||
agent=agent, # Your agent (used to get metadata like name and color) | ||
npc=npc, # NPC that will play against the agent | ||
render_mode="human", # "human" mode is for rendering, None is for no rendering | ||
) | ||
|
||
# Custom map that will override GridFactory for this game | ||
map = """ | ||
A..# | ||
.#3# | ||
...# | ||
##B# | ||
# We can draw custom maps - see symbol explanations in README | ||
grid = """ | ||
..#...##.. | ||
..A.#..4.. | ||
.3...1.... | ||
...###.... | ||
####...9.B | ||
...###.... | ||
.2...5.... | ||
....#..6.. | ||
..#...##.. | ||
""" | ||
|
||
# Create environment | ||
env = pz_generals(gf, agents, render_mode=None) # Disable rendering | ||
|
||
# Options are used only for the next game | ||
options = { | ||
"grid": map, | ||
"replay_file": "replay", | ||
"replay_file": "my_replay", # Save replay as my_replay.pkl | ||
"grid": grid # Use the custom map | ||
} | ||
|
||
observations, info = env.reset(options=options) | ||
done = False | ||
observation, info = env.reset(options=options) | ||
|
||
while not done: | ||
actions = {} | ||
for agent in env.agents: | ||
# Ask agent for action | ||
actions[agent] = agents[agent].play(observations[agent]) | ||
# All agents perform their actions | ||
observations, rewards, terminated, truncated, info = env.step(actions) | ||
done = any(terminated.values()) | ||
terminated = truncated = False | ||
while not (terminated or truncated): | ||
action = agent.act(observation) | ||
observation, reward, terminated, truncated, info = env.step(action) | ||
env.render() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,21 @@ | ||
from generals import gym_generals | ||
from generals.agents import RandomAgent, ExpanderAgent | ||
import gymnasium as gym | ||
from generals import AgentFactory | ||
|
||
# Initialize agents | ||
agent = RandomAgent() | ||
npc = ExpanderAgent() | ||
agent = AgentFactory.make_agent("expander") | ||
npc = AgentFactory.make_agent("random") | ||
|
||
# Create environment -- render modes: {None, "human"} | ||
env = gym_generals(agent=agent, npc=npc, render_mode="human") | ||
observation, info = env.reset() | ||
env = gym.make( | ||
"gym-generals-v0", | ||
agent=agent, | ||
npc=npc, | ||
render_mode="human", | ||
) | ||
|
||
done = False | ||
observation, info = env.reset() | ||
|
||
while not done: | ||
action = agent.play(observation) | ||
terminated = truncated = False | ||
while not (terminated or truncated): | ||
action = agent.act(observation) | ||
observation, reward, terminated, truncated, info = env.step(action) | ||
done = terminated or truncated | ||
env.render(fps=6) | ||
env.render() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import gymnasium as gym | ||
from generals import AgentFactory, GridFactory | ||
|
||
# Initialize agents -- see generals/agents/agent_factory.py for more options | ||
agent = AgentFactory.make_agent("expander") | ||
npc = AgentFactory.make_agent("random") | ||
|
||
# Initialize grid factory | ||
grid_factory = GridFactory( | ||
grid_dims=(5, 5), # Grid height and width | ||
mountain_density=0.2, # Expected percentage of mountains | ||
city_density=0.05, # Expected percentage of cities | ||
general_positions=[(1, 2), (3, 4)], # Positions of the generals | ||
seed=38 # Seed to generate the same map every time | ||
) | ||
|
||
env = gym.make( | ||
"gym-generals-v0", # Environment name | ||
grid_factory=grid_factory, # Grid factory | ||
agent=agent, # Your agent (used to get metadata like name and color) | ||
npc=npc, # NPC that will play against the agent | ||
) | ||
|
||
# Options are used only for the next game | ||
options = { | ||
"replay_file": "my_replay", # Save replay as my_replay.pkl | ||
} | ||
|
||
observation, info = env.reset(options=options) | ||
|
||
terminated = truncated = False | ||
while not (terminated or truncated): | ||
action = agent.act(observation) | ||
observation, reward, terminated, truncated, info = env.step(action) | ||
env.render() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from generals import Replay | ||
|
||
replay = Replay.load("replay.pkl") | ||
replay = Replay.load("my_replay.pkl") | ||
replay.play() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,29 @@ | ||
from .core.grid import GridFactory, Grid | ||
from .envs.env import pz_generals, gym_generals | ||
from .core.replay import Replay | ||
from .agents.agent_factory import AgentFactory | ||
from gymnasium.envs.registration import register | ||
|
||
|
||
__all__ = ['GridFactory', 'Grid', 'Replay', pz_generals, gym_generals] | ||
__all__ = [ | ||
"AgentFactory", | ||
"GridFactory", | ||
"Grid", | ||
"Replay", | ||
] | ||
|
||
|
||
def _register_generals_envs(): | ||
register( | ||
id="gym-generals-v0", | ||
entry_point="generals.envs.env:gym_generals_v0", | ||
) | ||
|
||
register( | ||
id="pz-generals-v0", | ||
entry_point="generals.envs.env:pz_generals_v0", | ||
disable_env_checker=True, | ||
) | ||
|
||
|
||
|
||
_register_generals_envs() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,7 @@ | ||
# agents/__init__.py | ||
|
||
from .random_agent import RandomAgent | ||
from .expander_agent import ExpanderAgent | ||
from .agent import Agent | ||
from .agent_factory import AgentFactory | ||
|
||
# You can also define an __all__ list if you want to restrict what gets imported with * | ||
__all__ = ["Agent", "RandomAgent", "ExpanderAgent"] | ||
__all__ = ["Agent", "AgentFactory"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.