Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolves #115 #116

Open
wants to merge 8 commits into
base: solve-collectables-sb3
Choose a base branch
from
24 changes: 24 additions & 0 deletions collectables_check_env.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this file need to exist or be here?

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from gymnasium import spaces
from pysc2.env import sc2_env
from stable_baselines3.common.env_checker import check_env

from urnai.sc2.actions.collectables import CollectablesActionSpace
from urnai.sc2.environments.sc2environment import SC2Env
from urnai.sc2.environments.stablebaselines3.custom_env import CustomEnv
from urnai.sc2.rewards.collectables import CollectablesReward
from urnai.sc2.states.collectables import CollectablesState

players = [sc2_env.Agent(sc2_env.Race.terran)]
env = SC2Env(map_name='CollectMineralShards', visualize=False,
step_mul=16, players=players)
state = CollectablesState()
urnai_action_space = CollectablesActionSpace()
reward = CollectablesReward()

# Define action and observation space
action_space = spaces.Discrete(n=4, start=0)
observation_space = spaces.Box(low=0, high=255, shape=(4096,), dtype=float)

custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
action_space)
check_env(custom_env, warn=True)
34 changes: 34 additions & 0 deletions docs/pysc2_usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# PySC2 usage in URNAI

## Actions in URNAI vs Actions in PySC2

In PySC2 you would normally call an action in the following manner:
`actions.X.F`. Where `X` is the action function set, such as `RAW_FUNCTIONS`
or `FUNCTIONS`, and `F` is the action function you are calling,
such as `no_op` or `Move_pt`.

But in URNAI, due to encapsulation, the call is done in the following manner:
`X[F].run()`. Where `X` is a dict containing classes for each of the action functions
in PySC2, and `F` is a string with the name of the action function you are calling.
The dictionaries mentioned above are stored in the file `sc2_actions.py`, and can
be imported such as in the example below:

```python
from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions
```

### Examples

```python
actions.RAW_FUNCTIONS.no_op() #PySC2
sc2_actions["no_op"].run() #URNAI

actions.RAW_FUNCTIONS.Move_pt('now', unit.tag, [new_army_x, new_army_y]) #PySC2
sc2_actions["Move_pt"].run('now', unit.tag,[new_army_x, new_army_y]) #URNAI
Comment on lines +23 to +27
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
actions.RAW_FUNCTIONS.no_op() #PySC2
sc2_actions["no_op"].run() #URNAI
actions.RAW_FUNCTIONS.Move_pt('now', unit.tag, [new_army_x, new_army_y]) #PySC2
sc2_actions["Move_pt"].run('now', unit.tag,[new_army_x, new_army_y]) #URNAI
# PySC2
actions.RAW_FUNCTIONS.no_op()
actions.RAW_FUNCTIONS.Move_pt('now', unit.tag, [new_army_x, new_army_y])
# URNAI
sc2_actions["no_op"].run()
sc2_actions["Move_pt"].run('now', unit.tag,[new_army_x, new_army_y])

```

### Why?

URNAI chooses to represent actions as classes so they can be better organized and tested.
Therefore, this encapsulation, which transforms each PySC2 action into a class, contributes
to all environments being done under a single system, something that is desired.
67 changes: 67 additions & 0 deletions solve_collectables_sb3.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this file need to exist or be here?

Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os

from gymnasium import spaces
from pysc2.env import sc2_env
from stable_baselines3 import PPO

from urnai.sc2.actions.collectables import CollectablesActionSpace
from urnai.sc2.environments.sc2environment import SC2Env
from urnai.sc2.environments.stablebaselines3.custom_env import CustomEnv
from urnai.sc2.rewards.collectables import CollectablesReward
from urnai.sc2.states.collectables import CollectablesState

players = [sc2_env.Agent(sc2_env.Race.terran)]
env = SC2Env(map_name='CollectMineralShards', visualize=False,
step_mul=16, players=players)
state = CollectablesState()
urnai_action_space = CollectablesActionSpace()
reward = CollectablesReward()

# Define action and observation space
action_space = spaces.Discrete(n=4, start=0)
observation_space = spaces.Box(low=0, high=255, shape=(4096, ), dtype=float)

# Create the custom environment
custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space,
action_space)


# models_dir = "saves/models/DQN"
models_dir = "saves/models/PPO"
logdir = "saves/logs"

if not os.path.exists(models_dir):
os.makedirs(models_dir)

if not os.path.exists(logdir):
os.makedirs(logdir)

# If training from scratch, uncomment 1
# If loading a model, uncomment 2

## 1 - Train and Save model

# model=DQN("MlpPolicy",custom_env,buffer_size=100000,verbose=1,tensorboard_log=logdir)
model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir)

TIMESTEPS = 10000
for i in range(1,30):
model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="DQN")
model.save(f"{models_dir}/{TIMESTEPS*i}")

## 1 - End

## 2 - Load model
# model = PPO.load(f"{models_dir}/40000.zip", env = custom_env)
## 2 - End

vec_env = model.get_env()
obs = vec_env.reset()

# Test model
for _ in range(10000):
action, _state = model.predict(obs, deterministic=True)
# print(action)
obs, rewards, done, info = vec_env.step(action)

env.close()
4 changes: 0 additions & 4 deletions tests/units/actions/test_action_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@ def test_abstract_methods(self):

# WHEN
run_return = fake_action.run()
check_return = fake_action.check("observation")
is_complete_return = fake_action.is_complete

# THEN
assert fake_action.__id__ is None
assert isinstance(ActionBase, ABCMeta)
assert run_return is None
assert check_return is None
assert is_complete_return is None
29 changes: 29 additions & 0 deletions tests/units/actions/test_action_base_strict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest
from abc import ABCMeta

from urnai.actions.action_base_strict import ActionBaseStrict


class FakeActionStrict(ActionBaseStrict):
ActionBaseStrict.__abstractmethods__ = set()
__id__ = None
...

class TestActionBase(unittest.TestCase):

def test_abstract_methods(self):

# GIVEN
fake_action = FakeActionStrict()

# WHEN
run_return = fake_action.run()
check_return = fake_action.check("observation")
is_complete_return = fake_action.is_complete

# THEN
assert fake_action.__id__ is None
assert isinstance(ActionBaseStrict, ABCMeta)
assert run_return is None
assert check_return is None
assert is_complete_return is None
8 changes: 4 additions & 4 deletions tests/units/sc2/actions/test_sc2_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pysc2.lib import actions

from urnai.sc2.actions.sc2_action import SC2Action
from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not:

Suggested change
from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions
from urnai.sc2.actions import raw_functions_classes

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think using it with a smaller name is better. Is this a valid enough reason?


_BUILD_REFINERY = actions.RAW_FUNCTIONS.Build_Refinery_pt
_NO_OP = actions.FUNCTIONS.no_op
Expand All @@ -11,9 +11,9 @@ class TestSC2Action(unittest.TestCase):

def test_run(self):

run_no_op = SC2Action.run(_NO_OP)
run_build_refinery = SC2Action.run(_BUILD_REFINERY, 'now', 0)
run_no_op = sc2_actions["no_op"].run()
run_build_refinery = sc2_actions["Build_Refinery_pt"].run('now', 0)

self.assertEqual(run_no_op.function, _NO_OP.id)
self.assertEqual(run_no_op.arguments, [])

Expand Down
14 changes: 2 additions & 12 deletions urnai/actions/action_base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any


class ActionBase(ABC):
__id__ = None

@abstractmethod
def run(self) -> None:
def run(*args) -> Any:
"""Executing the action"""
...

@abstractmethod
def check(self, obs) -> bool:
"""Returns whether the action can be executed or not"""
...

@property
@abstractmethod
def is_complete(self) -> bool:
"""Returns whether the action has finished or not"""
...
23 changes: 23 additions & 0 deletions urnai/actions/action_base_strict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from abc import abstractmethod
from typing import Any
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to avoid using typing because of that PEP we discussed earlier.


from urnai.actions.action_base import ActionBase


class ActionBaseStrict(ActionBase):

@abstractmethod
def run(*args) -> Any:
"""Executing the action"""
...

@abstractmethod
def check(self, obs) -> bool:
"""Returns whether the action can be executed or not"""
...

@property
@abstractmethod
def is_complete(self) -> bool:
"""Returns whether the action has finished or not"""
...
142 changes: 142 additions & 0 deletions urnai/sc2/actions/collectables.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this file from an experiment?

Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from statistics import mean

from pysc2.env import sc2_env

from urnai.actions.action_space_base import ActionSpaceBase
from urnai.sc2.actions import sc2_actions_aux as scaux
from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions


class CollectablesActionSpace(ActionSpaceBase):

def __init__(self):
self.noaction = [sc2_actions["no_op"].run()]
self.move_number = 0

self.hor_threshold = 2
self.ver_threshold = 2

self.moveleft = 0
self.moveright = 1
self.moveup = 2
self.movedown = 3

self.excluded_actions = []

self.actions = [self.moveleft, self.moveright, self.moveup, self.movedown]
self.named_actions = ['move_left', 'move_right', 'move_up', 'move_down']
self.action_indices = range(len(self.actions))

self.pending_actions = []
self.named_actions = None

def is_action_done(self):
# return len(self.pending_actions) == 0
return True
Comment on lines +33 to +35
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either remove this function or put some logic in it.


def reset(self):
self.move_number = 0
self.pending_actions = []

def get_actions(self):
return self.action_indices
Comment on lines +41 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't understand that. Shouldn't it return the actions? Why is it returning an iterator?


def get_excluded_actions(self, obs):
return []
Comment on lines +44 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either remove this function or put some logic in it.


def get_action(self, action_idx, obs):
action = None
if len(self.pending_actions) == 0:
action = self.noaction
else:
action = [self.pending_actions.pop()]
Comment on lines +49 to +52
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if len(self.pending_actions) == 0:
action = self.noaction
else:
action = [self.pending_actions.pop()]
action = self.noaction
if len(self.pending_actions) > 0:
action = [self.pending_actions.pop()]

self.solve_action(action_idx, obs)
return action

def solve_action(self, action_idx, obs):
if action_idx is not None:
if action_idx is not self.noaction:
action = self.actions[action_idx]
if action == self.moveleft:
self.move_left(obs)
elif action == self.moveright:
self.move_right(obs)
elif action == self.moveup:
self.move_up(obs)
elif action == self.movedown:
self.move_down(obs)
else:
# if action_idx was None, this means that the actionwrapper
# was not resetted properly, so I will reset it here
# this is not the best way to fix this
# but until we cannot find why the agent is
# not resetting the action wrapper properly
# i'm gonna leave this here
self.reset()
Comment on lines +56 to +75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def solve_action(self, action_idx, obs):
if action_idx is not None:
if action_idx is not self.noaction:
action = self.actions[action_idx]
if action == self.moveleft:
self.move_left(obs)
elif action == self.moveright:
self.move_right(obs)
elif action == self.moveup:
self.move_up(obs)
elif action == self.movedown:
self.move_down(obs)
else:
# if action_idx was None, this means that the actionwrapper
# was not resetted properly, so I will reset it here
# this is not the best way to fix this
# but until we cannot find why the agent is
# not resetting the action wrapper properly
# i'm gonna leave this here
self.reset()
def solve_action(self, action_idx, obs):
if action_idx is None:
self.reset()
return
...

I didn't understand the rest of the function. Could you add a description to the function describing what you are trying to do?


def move_left(self, obs):
army = scaux.select_army(obs, sc2_env.Race.terran)
xs = [unit.x for unit in army]
ys = [unit.y for unit in army]

new_army_x = int(mean(xs)) - self.hor_threshold
new_army_y = int(mean(ys))

for unit in army:
self.pending_actions.append(
sc2_actions["Move_pt"].run(
'now', unit.tag,[new_army_x, new_army_y]))

def move_right(self, obs):
army = scaux.select_army(obs, sc2_env.Race.terran)
xs = [unit.x for unit in army]
ys = [unit.y for unit in army]

new_army_x = int(mean(xs)) + self.hor_threshold
new_army_y = int(mean(ys))

for unit in army:
self.pending_actions.append(
sc2_actions["Move_pt"].run(
'now', unit.tag,[new_army_x, new_army_y]))

def move_down(self, obs):
army = scaux.select_army(obs, sc2_env.Race.terran)
xs = [unit.x for unit in army]
ys = [unit.y for unit in army]

new_army_x = int(mean(xs))
new_army_y = int(mean(ys)) + self.ver_threshold

for unit in army:
self.pending_actions.append(
sc2_actions["Move_pt"].run(
'now', unit.tag,[new_army_x, new_army_y]))

def move_up(self, obs):
army = scaux.select_army(obs, sc2_env.Race.terran)
xs = [unit.x for unit in army]
ys = [unit.y for unit in army]

new_army_x = int(mean(xs))
new_army_y = int(mean(ys)) - self.ver_threshold

for unit in army:
self.pending_actions.append(
sc2_actions["Move_pt"].run(
'now', unit.tag,[new_army_x, new_army_y]))

def get_action_name_str_by_int(self, action_int):
action_str = ''
for attrstr in dir(self):
attr = getattr(self, attrstr)
if action_int == attr:
action_str = attrstr

return action_str

def get_no_action(self):
return self.noaction

def get_named_actions(self):
return self.named_actions
9 changes: 0 additions & 9 deletions urnai/sc2/actions/sc2_action.py

This file was deleted.

Loading