From 83359b673e8c1f04472b036787ef474d1fa322f6 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 15 Oct 2024 17:57:54 -0300 Subject: [PATCH 1/8] feat: Added collectables and sb3 solve files --- collectables_check_env.py | 24 + solve_collectables_sb3.py | 67 ++ urnai/sc2/actions/collectables.py | 143 ++++ urnai/sc2/actions/sc2_actions_aux.py | 946 +++++++++++++++++++++++++++ urnai/sc2/rewards/collectables.py | 56 ++ urnai/sc2/states/collectables.py | 224 +++++++ 6 files changed, 1460 insertions(+) create mode 100644 collectables_check_env.py create mode 100644 solve_collectables_sb3.py create mode 100644 urnai/sc2/actions/collectables.py create mode 100644 urnai/sc2/actions/sc2_actions_aux.py create mode 100644 urnai/sc2/rewards/collectables.py create mode 100644 urnai/sc2/states/collectables.py diff --git a/collectables_check_env.py b/collectables_check_env.py new file mode 100644 index 0000000..c357082 --- /dev/null +++ b/collectables_check_env.py @@ -0,0 +1,24 @@ +from gymnasium import spaces +from pysc2.env import sc2_env +from stable_baselines3.common.env_checker import check_env + +from urnai.sc2.actions.collectables import CollectablesActionSpace +from urnai.sc2.environments.sc2environment import SC2Env +from urnai.sc2.environments.stablebaselines3.custom_env import CustomEnv +from urnai.sc2.rewards.collectables import CollectablesReward +from urnai.sc2.states.collectables import CollectablesState + +players = [sc2_env.Agent(sc2_env.Race.terran)] +env = SC2Env(map_name='CollectMineralShards', visualize=False, + step_mul=16, players=players) +state = CollectablesState() +urnai_action_space = CollectablesActionSpace() +reward = CollectablesReward() + +# Define action and observation space +action_space = spaces.Discrete(n=4, start=0) +observation_space = spaces.Box(low=0, high=255, shape=(4096,), dtype=float) + +custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space, + action_space) +check_env(custom_env, warn=True) \ No newline at end 
of file diff --git a/solve_collectables_sb3.py b/solve_collectables_sb3.py new file mode 100644 index 0000000..8452669 --- /dev/null +++ b/solve_collectables_sb3.py @@ -0,0 +1,67 @@ +import os + +from gymnasium import spaces +from pysc2.env import sc2_env +from stable_baselines3 import PPO + +from urnai.sc2.actions.collectables import CollectablesActionSpace +from urnai.sc2.environments.sc2environment import SC2Env +from urnai.sc2.environments.stablebaselines3.custom_env import CustomEnv +from urnai.sc2.rewards.collectables import CollectablesReward +from urnai.sc2.states.collectables import CollectablesState + +players = [sc2_env.Agent(sc2_env.Race.terran)] +env = SC2Env(map_name='CollectMineralShards', visualize=False, + step_mul=16, players=players) +state = CollectablesState() +urnai_action_space = CollectablesActionSpace() +reward = CollectablesReward() + +# Define action and observation space +action_space = spaces.Discrete(n=4, start=0) +observation_space = spaces.Box(low=0, high=255, shape=(4096, ), dtype=float) + +# Create the custom environment +custom_env = CustomEnv(env, state, urnai_action_space, reward, observation_space, + action_space) + + +# models_dir = "saves/models/DQN" +models_dir = "saves/models/PPO" +logdir = "saves/logs" + +if not os.path.exists(models_dir): + os.makedirs(models_dir) + +if not os.path.exists(logdir): + os.makedirs(logdir) + +# If training from scratch, uncomment 1 +# If loading a model, uncomment 2 + +## 1 - Train and Save model + +# model=DQN("MlpPolicy",custom_env,buffer_size=100000,verbose=1,tensorboard_log=logdir) +model=PPO("MlpPolicy", custom_env, verbose=1, tensorboard_log=logdir) + +TIMESTEPS = 10000 +for i in range(1,30): + model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name="DQN") + model.save(f"{models_dir}/{TIMESTEPS*i}") + +## 1 - End + +## 2 - Load model +# model = PPO.load(f"{models_dir}/40000.zip", env = custom_env) +## 2 - End + +vec_env = model.get_env() +obs = vec_env.reset() + +# 
Test model +for _ in range(10000): + action, _state = model.predict(obs, deterministic=True) + # print(action) + obs, rewards, done, info = vec_env.step(action) + +env.close() \ No newline at end of file diff --git a/urnai/sc2/actions/collectables.py b/urnai/sc2/actions/collectables.py new file mode 100644 index 0000000..8211bc2 --- /dev/null +++ b/urnai/sc2/actions/collectables.py @@ -0,0 +1,143 @@ +from statistics import mean + +from pysc2.env import sc2_env +from pysc2.lib import actions + +from urnai.actions.action_space_base import ActionSpaceBase +from urnai.sc2.actions import sc2_actions_aux as scaux +from urnai.sc2.actions.sc2_action import SC2Action + + +class CollectablesActionSpace(ActionSpaceBase): + + def __init__(self): + self.noaction = [actions.RAW_FUNCTIONS.no_op()] + self.move_number = 0 + + self.hor_threshold = 2 + self.ver_threshold = 2 + + self.moveleft = 0 + self.moveright = 1 + self.moveup = 2 + self.movedown = 3 + + self.excluded_actions = [] + + self.actions = [self.moveleft, self.moveright, self.moveup, self.movedown] + self.named_actions = ['move_left', 'move_right', 'move_up', 'move_down'] + self.action_indices = range(len(self.actions)) + + self.pending_actions = [] + self.named_actions = None + + def is_action_done(self): + # return len(self.pending_actions) == 0 + return True + + def reset(self): + self.move_number = 0 + self.pending_actions = [] + + def get_actions(self): + return self.action_indices + + def get_excluded_actions(self, obs): + return [] + + def get_action(self, action_idx, obs): + action = None + if len(self.pending_actions) == 0: + action = [actions.RAW_FUNCTIONS.no_op()] + else: + action = [self.pending_actions.pop()] + self.solve_action(action_idx, obs) + return action + + def solve_action(self, action_idx, obs): + if action_idx is not None: + if action_idx is not self.noaction: + action = self.actions[action_idx] + if action == self.moveleft: + self.move_left(obs) + elif action == self.moveright: + 
self.move_right(obs) + elif action == self.moveup: + self.move_up(obs) + elif action == self.movedown: + self.move_down(obs) + else: + # if action_idx was None, this means that the actionwrapper + # was not resetted properly, so I will reset it here + # this is not the best way to fix this + # but until we cannot find why the agent is + # not resetting the action wrapper properly + # i'm gonna leave this here + self.reset() + + def move_left(self, obs): + army = scaux.select_army(obs, sc2_env.Race.terran) + xs = [unit.x for unit in army] + ys = [unit.y for unit in army] + + new_army_x = int(mean(xs)) - self.hor_threshold + new_army_y = int(mean(ys)) + + for unit in army: + self.pending_actions.append( + SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, + 'now', unit.tag, [new_army_x, new_army_y])) + + def move_right(self, obs): + army = scaux.select_army(obs, sc2_env.Race.terran) + xs = [unit.x for unit in army] + ys = [unit.y for unit in army] + + new_army_x = int(mean(xs)) + self.hor_threshold + new_army_y = int(mean(ys)) + + for unit in army: + self.pending_actions.append( + SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, + 'now', unit.tag, [new_army_x, new_army_y])) + + def move_down(self, obs): + army = scaux.select_army(obs, sc2_env.Race.terran) + xs = [unit.x for unit in army] + ys = [unit.y for unit in army] + + new_army_x = int(mean(xs)) + new_army_y = int(mean(ys)) + self.ver_threshold + + for unit in army: + self.pending_actions.append( + SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, + 'now', unit.tag, [new_army_x, new_army_y])) + + def move_up(self, obs): + army = scaux.select_army(obs, sc2_env.Race.terran) + xs = [unit.x for unit in army] + ys = [unit.y for unit in army] + + new_army_x = int(mean(xs)) + new_army_y = int(mean(ys)) - self.ver_threshold + + for unit in army: + self.pending_actions.append( + SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, + 'now', unit.tag, [new_army_x, new_army_y])) + + def get_action_name_str_by_int(self, action_int): + action_str = 
'' + for attrstr in dir(self): + attr = getattr(self, attrstr) + if action_int == attr: + action_str = attrstr + + return action_str + + def get_no_action(self): + return self.noaction + + def get_named_actions(self): + return self.named_actions diff --git a/urnai/sc2/actions/sc2_actions_aux.py b/urnai/sc2/actions/sc2_actions_aux.py new file mode 100644 index 0000000..3ab15f2 --- /dev/null +++ b/urnai/sc2/actions/sc2_actions_aux.py @@ -0,0 +1,946 @@ +import random + +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import actions, features, units + +""" +An action set defines all actions an agent can use. In the case of StarCraft 2 +using PySC2, some actions require extra processing to work, so it's up to the +developper to come up with a way to make them work. + +Even though this is not called an action_wrapper, it effectively acts as a wrapper + +e.g: actions.RAW_FUNCTIONS.Build_Barracks_pt is a function implemented in PySC2 that +requires some extra arguments to work, like whether to build it now or to queue the +action, which worker is going to perform this action, and the target (given by a +[x, y] position) + +In this file we sort all of this issues, like deciding when to do an action, which +units to use for it, and where to build. +The methods in here effectively serve as a bridge between our high level actions +defined in sc2_wrapper.py and the PySC2 library. 
+""" + +# Defining constants for action ids, so our agent can check if an action is valid, and +# pass these actions as arguments to functions easily +_NO_OP = actions.FUNCTIONS.no_op + +_BUILD_COMMAND_CENTER = actions.RAW_FUNCTIONS.Build_CommandCenter_pt +_BUILD_SUPPLY_DEPOT = actions.RAW_FUNCTIONS.Build_SupplyDepot_pt +_BUILD_REFINERY = actions.RAW_FUNCTIONS.Build_Refinery_pt +_BUILD_ENGINEERINGBAY = actions.RAW_FUNCTIONS.Build_EngineeringBay_pt +_BUILD_ARMORY = actions.RAW_FUNCTIONS.Build_Armory_pt +_BUILD_MISSILETURRET = actions.RAW_FUNCTIONS.Build_MissileTurret_pt +_BUILD_SENSORTOWER = actions.RAW_FUNCTIONS.Build_SensorTower_pt +_BUILD_BUNKER = actions.RAW_FUNCTIONS.Build_Bunker_pt +_BUILD_FUSIONCORE = actions.RAW_FUNCTIONS.Build_FusionCore_pt +_BUILD_GHOSTACADEMY = actions.RAW_FUNCTIONS.Build_GhostAcademy_pt +_BUILD_BARRACKS = actions.RAW_FUNCTIONS.Build_Barracks_pt +_BUILD_FACTORY = actions.RAW_FUNCTIONS.Build_Factory_pt +_BUILD_STARPORT = actions.RAW_FUNCTIONS.Build_Starport_pt +_BUILD_TECHLAB_BARRACKS = actions.RAW_FUNCTIONS.Build_TechLab_Barracks_quick +_BUILD_TECHLAB_FACTORY = actions.RAW_FUNCTIONS.Build_TechLab_Factory_quick +_BUILD_TECHLAB_STARPORT = actions.RAW_FUNCTIONS.Build_TechLab_Starport_quick +_BUILD_REACTOR_BARRACKS = actions.RAW_FUNCTIONS.Build_Reactor_Barracks_quick +_BUILD_REACTOR_FACTORY = actions.RAW_FUNCTIONS.Build_Reactor_Factory_quick +_BUILD_REACTOR_STARPORT = actions.RAW_FUNCTIONS.Build_Reactor_Starport_quick + +"""ENGINEERING BAY RESEARCH""" +_RESEARCH_TERRAN_INF_WEAPONS = \ + actions.RAW_FUNCTIONS.Research_TerranInfantryWeapons_quick +_RESEARCH_TERRAN_INF_ARMOR = actions.RAW_FUNCTIONS.Research_TerranInfantryArmor_quick +_RESEARCH_TERRAN_HISEC_AUTOTRACKING = \ + actions.RAW_FUNCTIONS.Research_HiSecAutoTracking_quick +_RESEARCH_TERRAN_NEOSTEEL_FRAME = actions.RAW_FUNCTIONS.Research_NeosteelFrame_quick +_RESEARCH_TERRAN_STRUCTURE_ARMOR = \ + actions.RAW_FUNCTIONS.Research_TerranStructureArmorUpgrade_quick + +"""ARMORY RESEARCH""" 
+_RESEARCH_TERRAN_SHIPS_WEAPONS = actions.RAW_FUNCTIONS.Research_TerranShipWeapons_quick +_RESEARCH_TERRAN_VEHIC_WEAPONS = \ + actions.RAW_FUNCTIONS.Research_TerranVehicleWeapons_quick +_RESEARCH_TERRAN_SHIPVEHIC_PLATES = \ + actions.RAW_FUNCTIONS.Research_TerranVehicleAndShipPlating_quick + +"""GHOST ACADEMY RESEARCH""" +_RESEARCH_TERRAN_GHOST_CLOAK = actions.RAW_FUNCTIONS.Research_PersonalCloaking_quick + +"""BARRACK RESEARCH""" +_RESEARCH_TERRAN_STIMPACK = actions.RAW_FUNCTIONS.Research_Stimpack_quick +_RESEARCH_TERRAN_COMBATSHIELD = actions.RAW_FUNCTIONS.Research_CombatShield_quick +_RESEARCH_TERRAN_CONCUSSIVESHELL = actions.RAW_FUNCTIONS.Research_ConcussiveShells_quick + +"""FACTORY RESEARCH""" +_RESEARCH_TERRAN_INFERNAL_PREIGNITER = \ + actions.RAW_FUNCTIONS.Research_InfernalPreigniter_quick +_RESEARCH_TERRAN_DRILLING_CLAWS = actions.RAW_FUNCTIONS.Research_DrillingClaws_quick +# check if these two following research options are actually from the factory building +_RESEARCH_TERRAN_CYCLONE_LOCKONDMG = \ + actions.RAW_FUNCTIONS.Research_CycloneLockOnDamage_quick +_RESEARCH_TERRAN_CYCLONE_RAPIDFIRE = \ + actions.RAW_FUNCTIONS.Research_CycloneRapidFireLaunchers_quick + +"""STARPORT RESEARCH""" +_RESEARCH_TERRAN_HIGHCAPACITYFUEL = \ + actions.RAW_FUNCTIONS.Research_HighCapacityFuelTanks_quick +_RESEARCH_TERRAN_CORVIDREACTOR = actions.RAW_FUNCTIONS.Research_RavenCorvidReactor_quick +_RESEARCH_TERRAN_BANSHEECLOAK = \ + actions.RAW_FUNCTIONS.Research_BansheeCloakingField_quick +_RESEARCH_TERRAN_BANSHEEHYPERFLIGHT = \ + actions.RAW_FUNCTIONS.Research_BansheeHyperflightRotors_quick +_RESEARCH_TERRAN_ADVANCEDBALLISTICS = \ + actions.RAW_FUNCTIONS.Research_AdvancedBallistics_quick + +"""FUSION CORE RESEARCH""" +_RESEARCH_TERRAN_BATTLECRUISER_WEAPONREFIT =\ + actions.RAW_FUNCTIONS.Research_BattlecruiserWeaponRefit_quick + +"""TRAINING ACTIONS""" +_TRAIN_SCV = actions.RAW_FUNCTIONS.Train_SCV_quick +_TRAIN_MARINE = actions.RAW_FUNCTIONS.Train_Marine_quick +_TRAIN_MARAUDER = 
actions.RAW_FUNCTIONS.Train_Marauder_quick +_TRAIN_REAPER = actions.RAW_FUNCTIONS.Train_Reaper_quick +_TRAIN_GHOST = actions.RAW_FUNCTIONS.Train_Ghost_quick +_TRAIN_HELLION = actions.RAW_FUNCTIONS.Train_Hellion_quick +_TRAIN_HELLBAT = actions.RAW_FUNCTIONS.Train_Hellbat_quick +_TRAIN_SIEGETANK = actions.RAW_FUNCTIONS.Train_SiegeTank_quick +_TRAIN_CYCLONE = actions.RAW_FUNCTIONS.Train_Cyclone_quick +_TRAIN_WIDOWMINE = actions.RAW_FUNCTIONS.Train_WidowMine_quick +_TRAIN_THOR = actions.RAW_FUNCTIONS.Train_Thor_quick +_TRAIN_VIKING = actions.RAW_FUNCTIONS.Train_VikingFighter_quick +_TRAIN_MEDIVAC = actions.RAW_FUNCTIONS.Train_Medivac_quick +_TRAIN_LIBERATOR = actions.RAW_FUNCTIONS.Train_Liberator_quick +_TRAIN_RAVEN = actions.RAW_FUNCTIONS.Train_Raven_quick +_TRAIN_BANSHEE = actions.RAW_FUNCTIONS.Train_Banshee_quick +_TRAIN_BATTLECRUISER = actions.RAW_FUNCTIONS.Train_Battlecruiser_quick + +"""CALL DOWN ACTIONS""" +_CALL_DOWN_MULE = actions.RAW_FUNCTIONS.Effect_CalldownMULE_unit + +"""MORPH ACTIONS""" +_MORPH_ORBITAL_COMMAND = actions.RAW_FUNCTIONS.Morph_OrbitalCommand_quick +_MORPH_SIEGEMODE_TANK = actions.RAW_FUNCTIONS.Morph_SiegeMode_quick +_MORPH_UNSIEGE_TANK = actions.RAW_FUNCTIONS.Morph_Unsiege_quick + +"""UNIT EFFECTS""" +_EFFECT_STIMPACK = actions.RAW_FUNCTIONS.Effect_Stim_quick + +# PROTOSS ACTIONS + +_BUILD_PYLON = actions.RAW_FUNCTIONS.Build_Pylon_pt + +"""CONSTANTS USED TO DO GENERAL CHECKS""" +_NO_UNITS = 'no_units' +_TERRAN = sc2_env.Race.terran +_PROTOSS = sc2_env.Race.protoss +_ZERG = sc2_env.Race.zerg + + +def no_op(): + return actions.RAW_FUNCTIONS.no_op() + + +def build_structure_by_type(obs, action_id, player_race, target=None): + if player_race == _TERRAN: + worker = select_random_unit_by_type(obs, units.Terran.SCV) + elif player_race == _PROTOSS: + worker = select_random_unit_by_type(obs, units.Protoss.Probe) + else: + worker = select_random_unit_by_type(obs, units.Zerg.Drone) + + if worker != _NO_UNITS and target != _NO_UNITS: + if ' raw_cmd ' in 
str( + action_id.function_type): + #Checking if the build action is of type RAW_CMD + return action_id('now', + target.tag), _NO_UNITS #RAW_CMD actions only need [0]queue, + # [1]unit_tags and doesn't use a worker + + elif ' raw_cmd_pt ' in str( + action_id.function_type): + # Checking if the build action is of type RAW_CMD_PT + return action_id('now', worker.tag, + target), worker # RAW_CMD_PT actions need [0]queue, + # [1]unit_tags and [2]world_point + + elif ' raw_cmd_unit ' in str( + action_id.function_type): + # Checking if the build action is of type RAW_CMD_UNIT + return action_id('now', worker.tag, + target.tag), worker # RAW_CMD_UNIT actions need [0]queue, + # [1]unit_tags and [2]unit_tags + return _NO_OP(), _NO_UNITS + + +def research_upgrade(obs, action_id, building_type): + if unit_exists(obs, building_type): + buildings = get_units_by_type(obs, building_type) + for building in buildings: + if building.build_progress == 100 and building.order_progress_0 == 0: + return action_id('now', building.tag) + return _NO_OP() + + +def effect_units(action_id, units): + if len(units) > 0: + unit_tags = [unit.tag for unit in units] + return action_id('now', unit_tags) + return _NO_OP() + + +def train_unit(obs, action_id, building_type): + buildings = get_units_by_type(obs, building_type) + if len(buildings) > 0: + building_tags = [building.tag for building in buildings] + return action_id('now', building_tags) + return _NO_OP() + + +def calldown_mule(obs): + # the upgraded version of command center is required for this unit + orbital_command = get_units_by_type(obs, units.Terran.OrbitalCommand) + orbital_command.extend(get_units_by_type(obs, units.Terran.OrbitalCommandFlying)) + + mineral_fields = get_neutral_units_by_type(obs, units.Neutral.MineralField) + if len(mineral_fields) > 0 and len(orbital_command) > 0: + # part necessary to not fall into dimensional tensor errors + orbital_indexes = [x for x in range(len(orbital_command))] + choosen_index = 
np.random.choice(orbital_indexes) + choosen_orbital_command = orbital_command[choosen_index] + + if (choosen_orbital_command.build_progress == 100 and + choosen_orbital_command.energy >= 50): + # the orbital command spends 50 energy to make a mule + target = [choosen_orbital_command.x, choosen_orbital_command.y] + closest_mineral = get_closest_unit(obs, target, units_list=mineral_fields) + + if closest_mineral != _NO_UNITS: + return _CALL_DOWN_MULE('queued', choosen_orbital_command.tag, + closest_mineral.tag) + + return _NO_OP() + + +def attack_target_point(obs, player_race, target, base_top_left): + if not base_top_left: + target = (63 - target[0] - 5, 63 - target[1] + 5) + army = select_army(obs, player_race) + if len(army) > 0: + actions_queue = [] + army_tags = [unit.tag for unit in army] + actions_queue.append(actions.RAW_FUNCTIONS.Attack_pt('now', army_tags, target)) + return actions_queue + return [_NO_OP()] + + +def attack_target_point_spatial(units, target): + if len(units) > 0: + unit_tags = [unit.tag for unit in units] + actions_queue = [] + actions_queue.append(actions.RAW_FUNCTIONS.Attack_pt('now', unit_tags, target)) + return actions_queue + return [_NO_OP()] + + +def move_target_point_spatial(units, target): + if len(units) > 0: + unit_tags = [unit.tag for unit in units] + actions_queue = [] + actions_queue.append(actions.RAW_FUNCTIONS.Move_pt('now', unit_tags, target)) + return actions_queue + return [_NO_OP()] + + +def attack_distribute_army(obs, player_race): + army = select_army(obs, player_race) + if len(army) > 0: + actions_queue = [] + while len(army) != 0: + x_offset = random.randint(-8, 8) + y_offset = random.randint(-8, 8) + target = [army[0].x + x_offset, army[0].y + y_offset] + actions_queue.append( + actions.RAW_FUNCTIONS.Attack_pt('now', army[0].tag, target)) + army.pop(0) + return actions_queue + return [_NO_OP()] + + +def harvest_gather_minerals_quick(obs, worker, player_race): + if player_race == _TERRAN: + townhalls = 
get_units_by_type(obs, units.Terran.CommandCenter) + townhalls.extend(get_units_by_type(obs, units.Terran.PlanetaryFortress)) + townhalls.extend(get_units_by_type(obs, units.Terran.OrbitalCommand)) + if player_race == _PROTOSS: + townhalls = get_units_by_type(obs, units.Protoss.Nexus) + if player_race == _ZERG: + townhalls = get_units_by_type(obs, units.Zerg.Hatchery) + + if worker != _NO_UNITS: + mineral_fields = get_neutral_units_by_type(obs, units.Neutral.MineralField) + if len(mineral_fields) > 0 and len(townhalls) > 0: + # Checks every townhall if it is able to receive workers. If it is, + # searches for the closest mineral field + # If we find one, send the worker to gather minerals there. + for townhall in townhalls: + if townhall.build_progress == 100: + target = [townhall.x, townhall.y] + closest_mineral = get_closest_unit( + obs, target, units_list=mineral_fields) + if closest_mineral != _NO_UNITS: + return actions.RAW_FUNCTIONS.Harvest_Gather_unit( + 'queued', worker.tag, closest_mineral.tag) + + return _NO_OP() + + +def harvest_gather_minerals(obs, player_race): + if player_race == _TERRAN: + townhalls = get_units_by_type(obs, units.Terran.CommandCenter) + townhalls.extend(get_units_by_type(obs, units.Terran.PlanetaryFortress)) + townhalls.extend(get_units_by_type(obs, units.Terran.OrbitalCommand)) + workers = get_units_by_type(obs, units.Terran.SCV) + if player_race == _PROTOSS: + townhalls = get_units_by_type(obs, units.Protoss.Nexus) + workers = get_units_by_type(obs, units.Protoss.Probe) + if player_race == _ZERG: + townhalls = get_units_by_type(obs, units.Zerg.Hatchery) + workers = get_units_by_type(obs, units.Zerg.Drone) + + mineral_fields = get_neutral_units_by_type(obs, units.Neutral.MineralField) + if len(mineral_fields) > 0 and len(townhalls) > 0: + # Checks every townhall if it is able to receive workers. If it is, searches for + # closest mineral field + # If we find one, send the worker to gather minerals there. 
+ for townhall in townhalls: + if townhall.build_progress == 100: + target = [townhall.x, townhall.y] + if len(workers) > 0: + distances = list(get_distances(obs, workers, target)) + while len(workers) != 0: + index = np.argmin(distances) + if (workers[index].order_id_0 == 362 + or workers[index].order_length == 0) \ + and distances[index] >= 2: + closest_mineral = get_closest_unit(obs, target, + units_list=mineral_fields) + if closest_mineral != _NO_UNITS: + return actions.RAW_FUNCTIONS.Harvest_Gather_unit( + 'queued', + workers[index].tag, + closest_mineral.tag) + else: + workers.pop(index) + distances.pop(index) + return _NO_OP() + + +def harvest_gather_minerals_idle(obs, player_race, idle_workers): + if player_race == _TERRAN: + townhalls = get_units_by_type(obs, units.Terran.CommandCenter) + townhalls.extend(get_units_by_type(obs, units.Terran.PlanetaryFortress)) + townhalls.extend(get_units_by_type(obs, units.Terran.OrbitalCommand)) + if player_race == _PROTOSS: + townhalls = get_units_by_type(obs, units.Protoss.Nexus) + if player_race == _ZERG: + townhalls = get_units_by_type(obs, units.Zerg.Hatchery) + + mineral_fields = get_neutral_units_by_type(obs, units.Neutral.MineralField) + if len(mineral_fields) > 0 and len(townhalls) > 0: + for townhall in townhalls: + if townhall.build_progress == 100: + target = [townhall.x, townhall.y] + worker = get_closest_unit(obs, target, units_list=idle_workers) + if worker != _NO_UNITS: + distances = get_distances(obs, mineral_fields, target) + closest_mineral_to_townhall = mineral_fields[np.argmin(distances)] + return actions.RAW_FUNCTIONS.Harvest_Gather_unit( + 'now', worker.tag, + closest_mineral_to_townhall.tag) + return _NO_OP() + + +def harvest_gather_gas(obs, player_race): + if player_race == _TERRAN: + gas_colectors = get_units_by_type(obs, units.Terran.Refinery) + workers = get_units_by_type(obs, units.Terran.SCV) + if player_race == _PROTOSS: + gas_colectors = get_units_by_type(obs, units.Protoss.Assimilator) + 
workers = get_units_by_type(obs, units.Protoss.Probe) + if player_race == _ZERG: + gas_colectors = get_units_by_type(obs, units.Zerg.Extractor) + workers = get_units_by_type(obs, units.Zerg.Drone) + + if len(gas_colectors) > 0 and len(workers) > 0: + for gas_colector in gas_colectors: + if (0 <= gas_colector.assigned_harvesters < 4 + and gas_colector.build_progress == 100): + target = [gas_colector.x, gas_colector.y] + + if len(workers) > 0: + distances = list(get_distances(obs, workers, target)) + + while len(workers) != 0: + index = np.argmin(distances) + if (workers[index].order_id_0 == 362 + or workers[index].order_length == 0) \ + and distances[index] >= 3: + return actions.RAW_FUNCTIONS.Harvest_Gather_unit( + 'queued', + workers[index].tag, + gas_colector.tag) + else: + workers.pop(index) + distances.pop(index) + return _NO_OP() + + +def harvest_gather_gas_idle(obs, player_race, idle_workers): + if player_race == _TERRAN: + # the terran townhall and its upgradable versions + townhalls = get_units_by_type(obs, units.Terran.CommandCenter) + townhalls.extend(get_units_by_type(obs, units.Terran.PlanetaryFortress)) + townhalls.extend(get_units_by_type(obs, units.Terran.OrbitalCommand)) + if player_race == _PROTOSS: + townhalls = get_units_by_type(obs, units.Protoss.Nexus) + if player_race == _ZERG: + townhalls = get_units_by_type(obs, units.Zerg.Hatchery) + + # sources of minerals (which are to harvest) + vespene_geysers = get_neutral_units_by_type(obs, units.Neutral.VespeneGeyser) + if len(vespene_geysers) > 0 and len(townhalls) > 0: + for townhall in townhalls: + if townhall.build_progress == 100: + target = [townhall.x, townhall.y] + worker = get_closest_unit(obs, target, units_list=idle_workers) + + if worker != _NO_UNITS: + distances = get_distances(obs, vespene_geysers, target) + closest_vespene_to_townhall = vespene_geysers[np.argmin(distances)] + return actions.RAW_FUNCTIONS.Harvest_Gather_unit( + 'now', worker.tag, + closest_vespene_to_townhall.tag) + 
return _NO_OP() + + +def harvest_return(obs, worker): + if worker != _NO_UNITS: + return actions.RAW_FUNCTIONS.Harvest_Return_quick('queued', worker.tag) + return _NO_OP() + + +def build_structure_raw(obs, building_type, building_action, max_amount=999): + player_race = get_unit_race(building_type) + + if get_my_units_amount(obs, building_type) < max_amount: + buildings = get_units_by_type(obs, building_type) + if len(buildings) > 0: + target = random.choice(buildings) + action_one, last_worker = build_structure_by_type(obs, building_action, + player_race, target) + action_two = harvest_gather_minerals_quick(obs, last_worker, player_race) + actions_queue = [action_one, action_two] + return actions_queue + + return [_NO_OP()] + + +def build_structure_raw_pt(obs, building_type, building_action, base_top_left, + max_amount=999, targets=None): + if targets is None: + targets = [] + ybrange = 0 if base_top_left else 32 + ytrange = 32 if base_top_left else 63 + + player_race = get_unit_race(building_type) + + building_amount = get_my_units_amount(obs, building_type) + if len(targets) == 0 or building_amount >= len(targets): + target = [random.randint(0, 63), random.randint(ybrange, ytrange)] + else: + target = targets[building_amount] + if not base_top_left: + target = (63 - target[0] - 5, 63 - target[1] + 5) + + if building_amount < max_amount: + action_one, last_worker = build_structure_by_type(obs, building_action, + player_race, target) + action_two = harvest_gather_minerals_quick(obs, last_worker, player_race) + actions_queue = [action_one, action_two] + return actions_queue + + return [_NO_OP()] + + +def build_structure_raw_pt_spatial(obs, building_type, building_action, target): + player_race = get_unit_race(building_type) + + try: + action_one, last_worker = build_structure_by_type(obs, building_action, + player_race, target) + action_two = harvest_gather_minerals_quick(obs, last_worker, player_race) + actions_queue = [action_one, action_two] + return 
actions_queue + except Exception: + return [_NO_OP()] + + +def build_gas_structure_raw_unit(obs, building_type, building_action, + player_race, max_amount=999): + player_race = get_unit_race(building_type) + if get_my_units_amount(obs, building_type) < max_amount: + chosen_geyser = get_exploitable_geyser(obs, player_race) + action_one, last_worker = build_structure_by_type(obs, building_action, + player_race, chosen_geyser) + action_two = harvest_gather_minerals_quick(obs, last_worker, player_race) + actions_queue = [action_one, action_two] + return actions_queue + return [_NO_OP()] + + +""" +The following methods are used to aid in various mechanical operations the agent has +to perform, such as: getting all units from a certain type, counting the amount of +free supply, etc +""" + + +def select_random_unit_by_type(obs, unit_type): + units = get_units_by_type(obs, unit_type) + + if len(units) > 0: + random_unit = random.choice(units) + return random_unit + return _NO_UNITS + + +def get_random_idle_worker(obs, player_race): + if player_race == _PROTOSS: + workers = get_units_by_type(obs, units.Protoss.Probe) + elif player_race == _TERRAN: + workers = get_units_by_type(obs, units.Terran.SCV) + elif player_race == _ZERG: + workers = get_units_by_type(obs, units.Zerg.Drone) + + if len(workers) > 0: + for worker in workers: + if worker.order_length == 0: # checking if worker is idle + return worker + return _NO_UNITS + + +def get_all_idle_workers(obs, player_race): + if player_race == _PROTOSS: + workers = get_units_by_type(obs, units.Protoss.Probe) + elif player_race == _TERRAN: + workers = get_units_by_type(obs, units.Terran.SCV) + elif player_race == _ZERG: + workers = get_units_by_type(obs, units.Zerg.Drone) + + idle_workers = [] + + if len(workers) > 0: + for worker in workers: + if worker.order_length == 0: # checking if worker is idle + idle_workers.append(worker) + return idle_workers + return _NO_UNITS + + +def get_closest_unit(obs, target_xy, 
unit_type=_NO_UNITS, units_list=_NO_UNITS): + if unit_type != _NO_UNITS: + units = get_units_by_type(obs, unit_type) + if len(units) > 0: + distances = get_distances(obs, units, target_xy) + min_dist_index = np.argmin(distances) + unit = units[min_dist_index] + return unit + + elif units_list != _NO_UNITS: + if len(units_list) != 0: + distances = get_distances(obs, units_list, target_xy) + min_dist_index = np.argmin(distances) + unit = units_list[min_dist_index] + return unit + return _NO_UNITS + + +# def get_my_units_by_type(obs, unit_type): +# return [unit for unit in obs.raw_units +# if unit.unit_type == unit_type +# and unit.alliance == features.PlayerRelative.SELF] + +def get_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): + return [unit for unit in obs.raw_units + if unit.unit_type == unit_type + and unit.alliance == alliance + and unit.build_progress == 100] + + +def can_queue_unit_terran(obs, unit_type): + structures = get_units_by_type(obs, unit_type) + for structure in structures: + # if we have less than 5 units on queue, we can queue another unit + if (structure.order_length < 5 or structure.addon_unit_type == 38 + and structure.order_length < 10): + return True + return False + + +def get_neutral_units_by_type(obs, unit_type): + return [unit for unit in obs.raw_units + if unit.unit_type == unit_type + and unit.alliance == features.PlayerRelative.NEUTRAL] + + +def get_all_neutral_units(obs): + return [unit for unit in obs.raw_units + if unit.alliance == features.PlayerRelative.NEUTRAL] + + +def get_free_supply(obs): + return obs.player.food_cap - obs.player.food_used + + +def get_unit_amount(obs, unit_type, player): + return len(get_units_by_type(obs, unit_type, player)) + + +def get_my_units_amount(obs, unit_type): + return len(get_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) + + +def get_enemy_units_amount(obs, unit_type): + return len(get_units_by_type(obs, unit_type, features.PlayerRelative.ENEMY)) + + +def 
unit_exists(obs, unit_type): + if get_my_units_amount(obs, unit_type) > 0: + return True + return False + + +def get_exploitable_geyser(obs, player_race): + if player_race == _PROTOSS: + townhalls = get_units_by_type(obs, units.Protoss.Nexus) + elif player_race == _TERRAN: + townhalls = get_units_by_type(obs, units.Terran.CommandCenter) + townhalls.extend(get_units_by_type(obs, units.Terran.OrbitalCommand)) + townhalls.extend(get_units_by_type(obs, units.Terran.PlanetaryFortress)) + elif player_race == _ZERG: + townhalls = get_units_by_type(obs, units.Zerg.Hatchery) + townhalls.extend(get_units_by_type(obs, units.Zerg.Lair)) + townhalls.extend(get_units_by_type(obs, units.Zerg.Hive)) + geysers = get_neutral_units_by_type(obs, units.Neutral.VespeneGeyser) + if len(geysers) > 0 and len(townhalls) > 0: + for geyser in geysers: + for townhall in townhalls: + if get_euclidean_distance( + [geyser.x, geyser.y], [townhall.x, townhall.y]) < 10: + return geyser + return _NO_UNITS + + +def get_distances(obs, units, xy): + if len(units) > 0: + units_xy = [(unit.x, unit.y) for unit in units] + return np.linalg.norm(np.array(units_xy) - np.array(xy), axis=1) + pass + + +def get_euclidean_distance(unit_xy, xy): + return np.linalg.norm(np.array(unit_xy) - np.array(xy)) + + +def organize_queue(actions, actions_queue): + action = actions.pop(0) + while len(actions) > 0: + actions_queue.append(actions.pop(0)) + return action, actions_queue + + +# TO DO: Implement the following methods to facilitate checks and overall code reuse: + +# Create a 'get my units by types' where we pass instead of a single type an array of +# unit types and the return is an array of those units from the chosen types: +# possible function prototype: get_units_by_types(obs, unit_types) (maybe we can just +# reuse the get_units_by_type function and create a verification if unit_type is a +# single type or array of types) + +# check_unit_validity (should check if the object being received is a proper unit +# 
from pysc2) + +def select_all_race_units(obs, player_race): + army = [] + if player_race == _PROTOSS: + army.extend(get_units_by_type(obs, units.Protoss.Adept)) + army.extend(get_units_by_type(obs, units.Protoss.AdeptPhaseShift)) + army.extend(get_units_by_type(obs, units.Protoss.Archon)) + army.extend(get_units_by_type(obs, units.Protoss.Assimilator)) + army.extend(get_units_by_type(obs, units.Protoss.AssimilatorRich)) + army.extend(get_units_by_type(obs, units.Protoss.Carrier)) + army.extend(get_units_by_type(obs, units.Protoss.Colossus)) + army.extend(get_units_by_type(obs, units.Protoss.CyberneticsCore)) + army.extend(get_units_by_type(obs, units.Protoss.DarkShrine)) + army.extend(get_units_by_type(obs, units.Protoss.DarkTemplar)) + army.extend(get_units_by_type(obs, units.Protoss.Disruptor)) + army.extend(get_units_by_type(obs, units.Protoss.DisruptorPhased)) + army.extend(get_units_by_type(obs, units.Protoss.FleetBeacon)) + army.extend(get_units_by_type(obs, units.Protoss.ForceField)) + army.extend(get_units_by_type(obs, units.Protoss.Forge)) + army.extend(get_units_by_type(obs, units.Protoss.Gateway)) + army.extend(get_units_by_type(obs, units.Protoss.HighTemplar)) + army.extend(get_units_by_type(obs, units.Protoss.Immortal)) + army.extend(get_units_by_type(obs, units.Protoss.Interceptor)) + army.extend(get_units_by_type(obs, units.Protoss.Mothership)) + army.extend(get_units_by_type(obs, units.Protoss.MothershipCore)) + army.extend(get_units_by_type(obs, units.Protoss.Nexus)) + army.extend(get_units_by_type(obs, units.Protoss.Observer)) + army.extend(get_units_by_type(obs, units.Protoss.ObserverSurveillanceMode)) + army.extend(get_units_by_type(obs, units.Protoss.Oracle)) + army.extend(get_units_by_type(obs, units.Protoss.Phoenix)) + army.extend(get_units_by_type(obs, units.Protoss.PhotonCannon)) + army.extend(get_units_by_type(obs, units.Protoss.Probe)) + army.extend(get_units_by_type(obs, units.Protoss.Pylon)) + army.extend(get_units_by_type(obs, 
units.Protoss.PylonOvercharged)) + army.extend(get_units_by_type(obs, units.Protoss.RoboticsBay)) + army.extend(get_units_by_type(obs, units.Protoss.RoboticsFacility)) + army.extend(get_units_by_type(obs, units.Protoss.Sentry)) + army.extend(get_units_by_type(obs, units.Protoss.ShieldBattery)) + army.extend(get_units_by_type(obs, units.Protoss.Stalker)) + army.extend(get_units_by_type(obs, units.Protoss.Stargate)) + army.extend(get_units_by_type(obs, units.Protoss.StasisTrap)) + army.extend(get_units_by_type(obs, units.Protoss.Tempest)) + army.extend(get_units_by_type(obs, units.Protoss.TemplarArchive)) + army.extend(get_units_by_type(obs, units.Protoss.TwilightCouncil)) + army.extend(get_units_by_type(obs, units.Protoss.VoidRay)) + army.extend(get_units_by_type(obs, units.Protoss.WarpGate)) + army.extend(get_units_by_type(obs, units.Protoss.WarpPrism)) + army.extend(get_units_by_type(obs, units.Protoss.WarpPrismPhasing)) + army.extend(get_units_by_type(obs, units.Protoss.Zealot)) + elif player_race == _TERRAN: + army.extend(get_units_by_type(obs, units.Terran.Armory)) + army.extend(get_units_by_type(obs, units.Terran.AutoTurret)) + army.extend(get_units_by_type(obs, units.Terran.Banshee)) + army.extend(get_units_by_type(obs, units.Terran.Barracks)) + army.extend(get_units_by_type(obs, units.Terran.BarracksFlying)) + army.extend(get_units_by_type(obs, units.Terran.BarracksReactor)) + army.extend(get_units_by_type(obs, units.Terran.BarracksTechLab)) + army.extend(get_units_by_type(obs, units.Terran.Battlecruiser)) + army.extend(get_units_by_type(obs, units.Terran.Bunker)) + army.extend(get_units_by_type(obs, units.Terran.CommandCenter)) + army.extend(get_units_by_type(obs, units.Terran.CommandCenterFlying)) + army.extend(get_units_by_type(obs, units.Terran.Cyclone)) + army.extend(get_units_by_type(obs, units.Terran.EngineeringBay)) + army.extend(get_units_by_type(obs, units.Terran.Factory)) + army.extend(get_units_by_type(obs, units.Terran.FactoryFlying)) + 
army.extend(get_units_by_type(obs, units.Terran.FactoryReactor)) + army.extend(get_units_by_type(obs, units.Terran.FactoryTechLab)) + army.extend(get_units_by_type(obs, units.Terran.FusionCore)) + army.extend(get_units_by_type(obs, units.Terran.Ghost)) + army.extend(get_units_by_type(obs, units.Terran.GhostAcademy)) + army.extend(get_units_by_type(obs, units.Terran.GhostAlternate)) + army.extend(get_units_by_type(obs, units.Terran.GhostNova)) + army.extend(get_units_by_type(obs, units.Terran.Hellion)) + army.extend(get_units_by_type(obs, units.Terran.Hellbat)) + army.extend(get_units_by_type(obs, units.Terran.KD8Charge)) + army.extend(get_units_by_type(obs, units.Terran.Liberator)) + army.extend(get_units_by_type(obs, units.Terran.LiberatorAG)) + army.extend(get_units_by_type(obs, units.Terran.MULE)) + army.extend(get_units_by_type(obs, units.Terran.Marauder)) + army.extend(get_units_by_type(obs, units.Terran.Marine)) + army.extend(get_units_by_type(obs, units.Terran.Medivac)) + army.extend(get_units_by_type(obs, units.Terran.MissileTurret)) + army.extend(get_units_by_type(obs, units.Terran.Nuke)) + army.extend(get_units_by_type(obs, units.Terran.OrbitalCommand)) + army.extend(get_units_by_type(obs, units.Terran.OrbitalCommandFlying)) + army.extend(get_units_by_type(obs, units.Terran.PlanetaryFortress)) + army.extend(get_units_by_type(obs, units.Terran.PointDefenseDrone)) + army.extend(get_units_by_type(obs, units.Terran.Raven)) + army.extend(get_units_by_type(obs, units.Terran.Reactor)) + army.extend(get_units_by_type(obs, units.Terran.Reaper)) + army.extend(get_units_by_type(obs, units.Terran.Refinery)) + army.extend(get_units_by_type(obs, units.Terran.RefineryRich)) + army.extend(get_units_by_type(obs, units.Terran.RepairDrone)) + army.extend(get_units_by_type(obs, units.Terran.SCV)) + army.extend(get_units_by_type(obs, units.Terran.SensorTower)) + army.extend(get_units_by_type(obs, units.Terran.SiegeTank)) + army.extend(get_units_by_type(obs, 
units.Terran.SiegeTankSieged)) + army.extend(get_units_by_type(obs, units.Terran.Starport)) + army.extend(get_units_by_type(obs, units.Terran.StarportFlying)) + army.extend(get_units_by_type(obs, units.Terran.StarportReactor)) + army.extend(get_units_by_type(obs, units.Terran.StarportTechLab)) + army.extend(get_units_by_type(obs, units.Terran.SupplyDepot)) + army.extend(get_units_by_type(obs, units.Terran.SupplyDepotLowered)) + army.extend(get_units_by_type(obs, units.Terran.TechLab)) + army.extend(get_units_by_type(obs, units.Terran.Thor)) + army.extend(get_units_by_type(obs, units.Terran.ThorHighImpactMode)) + army.extend(get_units_by_type(obs, units.Terran.VikingAssault)) + army.extend(get_units_by_type(obs, units.Terran.VikingFighter)) + army.extend(get_units_by_type(obs, units.Terran.WidowMine)) + army.extend(get_units_by_type(obs, units.Terran.WidowMineBurrowed)) + elif player_race == _ZERG: + army.extend(get_units_by_type(obs, units.Zerg.Roach)) + army.extend(get_units_by_type(obs, units.Zerg.Baneling)) + army.extend(get_units_by_type(obs, units.Zerg.BanelingBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.BanelingCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.BanelingNest)) + army.extend(get_units_by_type(obs, units.Zerg.BroodLord)) + army.extend(get_units_by_type(obs, units.Zerg.BroodLordCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.Broodling)) + army.extend(get_units_by_type(obs, units.Zerg.BroodlingEscort)) + army.extend(get_units_by_type(obs, units.Zerg.Changeling)) + army.extend(get_units_by_type(obs, units.Zerg.ChangelingMarine)) + army.extend(get_units_by_type(obs, units.Zerg.ChangelingMarineShield)) + army.extend(get_units_by_type(obs, units.Zerg.ChangelingZealot)) + army.extend(get_units_by_type(obs, units.Zerg.ChangelingZergling)) + army.extend(get_units_by_type(obs, units.Zerg.ChangelingZerglingWings)) + army.extend(get_units_by_type(obs, units.Zerg.Cocoon)) + army.extend(get_units_by_type(obs, 
units.Zerg.Corruptor)) + army.extend(get_units_by_type(obs, units.Zerg.CreepTumor)) + army.extend(get_units_by_type(obs, units.Zerg.CreepTumorBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.CreepTumorQueen)) + army.extend(get_units_by_type(obs, units.Zerg.Drone)) + army.extend(get_units_by_type(obs, units.Zerg.DroneBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.EvolutionChamber)) + army.extend(get_units_by_type(obs, units.Zerg.Extractor)) + army.extend(get_units_by_type(obs, units.Zerg.ExtractorRich)) + army.extend(get_units_by_type(obs, units.Zerg.GreaterSpire)) + army.extend(get_units_by_type(obs, units.Zerg.Hatchery)) + army.extend(get_units_by_type(obs, units.Zerg.Hive)) + army.extend(get_units_by_type(obs, units.Zerg.Hydralisk)) + army.extend(get_units_by_type(obs, units.Zerg.HydraliskBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.HydraliskDen)) + army.extend(get_units_by_type(obs, units.Zerg.InfestationPit)) + army.extend(get_units_by_type(obs, units.Zerg.InfestedTerran)) + army.extend(get_units_by_type(obs, units.Zerg.InfestedTerranBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.InfestedTerranCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.Infestor)) + army.extend(get_units_by_type(obs, units.Zerg.InfestorBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.Lair)) + army.extend(get_units_by_type(obs, units.Zerg.Larva)) + army.extend(get_units_by_type(obs, units.Zerg.Locust)) + army.extend(get_units_by_type(obs, units.Zerg.LocustFlying)) + army.extend(get_units_by_type(obs, units.Zerg.Lurker)) + army.extend(get_units_by_type(obs, units.Zerg.LurkerBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.LurkerDen)) + army.extend(get_units_by_type(obs, units.Zerg.LurkerCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.Mutalisk)) + army.extend(get_units_by_type(obs, units.Zerg.NydusCanal)) + army.extend(get_units_by_type(obs, units.Zerg.NydusNetwork)) + army.extend(get_units_by_type(obs, 
units.Zerg.Overlord)) + army.extend(get_units_by_type(obs, units.Zerg.OverlordTransport)) + army.extend(get_units_by_type(obs, units.Zerg.OverlordTransportCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.Overseer)) + army.extend(get_units_by_type(obs, units.Zerg.OverseerCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.OverseerOversightMode)) + army.extend(get_units_by_type(obs, units.Zerg.ParasiticBombDummy)) + army.extend(get_units_by_type(obs, units.Zerg.Queen)) + army.extend(get_units_by_type(obs, units.Zerg.QueenBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.Ravager)) + army.extend(get_units_by_type(obs, units.Zerg.RavagerBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.RavagerCocoon)) + army.extend(get_units_by_type(obs, units.Zerg.Roach)) + army.extend(get_units_by_type(obs, units.Zerg.RoachBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.RoachWarren)) + army.extend(get_units_by_type(obs, units.Zerg.SpawningPool)) + army.extend(get_units_by_type(obs, units.Zerg.SpineCrawler)) + army.extend(get_units_by_type(obs, units.Zerg.SpineCrawlerUprooted)) + army.extend(get_units_by_type(obs, units.Zerg.Spire)) + army.extend(get_units_by_type(obs, units.Zerg.SporeCrawler)) + army.extend(get_units_by_type(obs, units.Zerg.SporeCrawlerUprooted)) + army.extend(get_units_by_type(obs, units.Zerg.SwarmHost)) + army.extend(get_units_by_type(obs, units.Zerg.SwarmHostBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.Ultralisk)) + army.extend(get_units_by_type(obs, units.Zerg.UltraliskBurrowed)) + army.extend(get_units_by_type(obs, units.Zerg.UltraliskCavern)) + army.extend(get_units_by_type(obs, units.Zerg.Viper)) + army.extend(get_units_by_type(obs, units.Zerg.Zergling)) + army.extend(get_units_by_type(obs, units.Zerg.ZerglingBurrowed)) + return army + + +def select_army(obs, player_race): + army = [] + if player_race == _PROTOSS: + army_unit_types = [ + units.Protoss.Adept, units.Protoss.AdeptPhaseShift, 
units.Protoss.Archon, + units.Protoss.Carrier, units.Protoss.Colossus, units.Protoss.DarkTemplar, + units.Protoss.Disruptor, units.Protoss.DisruptorPhased, + units.Protoss.HighTemplar, units.Protoss.Immortal, units.Protoss.Mothership, + units.Protoss.Observer, units.Protoss.ObserverSurveillanceMode, + units.Protoss.Oracle, units.Protoss.Phoenix, units.Protoss.Sentry, + units.Protoss.Stalker, units.Protoss.Tempest, + units.Protoss.VoidRay, units.Protoss.Zealot, + ] + + army = [unit for unit in obs.raw_units if + (unit.alliance == features.PlayerRelative.SELF + and unit.unit_type in army_unit_types)] + + elif player_race == _TERRAN: + army_unit_types = [units.Terran.Marine, units.Terran.Marauder, + units.Terran.Reaper, units.Terran.Ghost, + units.Terran.Hellion, units.Terran.Hellbat, + units.Terran.SiegeTank, units.Terran.Cyclone, + units.Terran.WidowMine, units.Terran.Thor, + units.Terran.ThorHighImpactMode, units.Terran.VikingAssault, + units.Terran.VikingFighter, units.Terran.Medivac, + units.Terran.Liberator, units.Terran.LiberatorAG, + units.Terran.Raven, units.Terran.Banshee, + units.Terran.Battlecruiser] + + army = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.SELF \ + and unit.unit_type in army_unit_types] + + elif player_race == _ZERG: + army_unit_types = [ + units.Zerg.Baneling, units.Zerg.BanelingBurrowed, units.Zerg.BanelingCocoon, + units.Zerg.BroodLord, units.Zerg.BroodLordCocoon, units.Zerg.Broodling, + units.Zerg.BroodlingEscort, units.Zerg.Changeling, + units.Zerg.ChangelingMarine, units.Zerg.ChangelingMarineShield, + units.Zerg.ChangelingZealot, units.Zerg.ChangelingZergling, + units.Zerg.ChangelingZerglingWings, units.Zerg.Corruptor, + units.Zerg.Hydralisk, units.Zerg.HydraliskBurrowed, units.Zerg.Infestor, + units.Zerg.InfestorBurrowed, units.Zerg.Locust, units.Zerg.LocustFlying, + units.Zerg.Lurker, units.Zerg.LurkerBurrowed, units.Zerg.LurkerCocoon, + units.Zerg.Mutalisk, units.Zerg.Overseer, 
units.Zerg.OverseerCocoon, + units.Zerg.OverseerOversightMode, units.Zerg.Queen, + units.Zerg.QueenBurrowed, units.Zerg.Ravager, units.Zerg.RavagerBurrowed, + units.Zerg.RavagerCocoon, units.Zerg.Roach, units.Zerg.RoachBurrowed, + units.Zerg.SwarmHost, units.Zerg.SwarmHostBurrowed, + units.Zerg.Ultralisk, units.Zerg.UltraliskBurrowed, units.Zerg.Viper, + units.Zerg.Zergling, units.Zerg.ZerglingBurrowed, + ] + + army = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.SELF \ + and unit.unit_type in army_unit_types] + + return army + + +def get_unit_race(unit_type): + if unit_type in units.Terran: + return _TERRAN + if unit_type in units.Protoss: + return _PROTOSS + if unit_type in units.Zerg: + return _ZERG + + +""" +move a unit to a new position based on +https://gist.github.com/fyr91/ +168996a23f5675536dbf6f1cf75b30d6#file-defeat_zerglings_banelings_env_5-py-L41 +""" + + +def move_to(obs, unit, dest_x, dest_y): + target = [dest_x, dest_y] + try: + return actions.RAW_FUNCTIONS.Move_pt('now', unit.tag, target) + except Exception: + return no_op() diff --git a/urnai/sc2/rewards/collectables.py b/urnai/sc2/rewards/collectables.py new file mode 100644 index 0000000..af2a67c --- /dev/null +++ b/urnai/sc2/rewards/collectables.py @@ -0,0 +1,56 @@ +import numpy as np + +import urnai.sc2.actions.sc2_actions_aux as sc2aux +from urnai.rewards.reward_base import RewardBase + +STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS = 20 + + +class CollectablesReward(RewardBase): + + def __init__(self): + self.previous_state = None + self.old_collectable_counter = STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS + + def get(self, obs, default_reward, terminated, truncated) -> int: + + reward = 0 + if(self.previous_state is not None): + # layer 4 is units (1 friendly, 2 enemy, 16 mineral shards, 3 neutral + current = self.filter_non_mineral_shard_units(obs) + curr = np.count_nonzero(current == 1) + if curr != self.old_collectable_counter: + self.old_collectable_counter = curr + 
reward = 10 + else: + reward = -1 + + if(truncated or terminated): + + if(self.old_collectable_counter == 0): + reward = 1000 + elif(self.old_collectable_counter >= 15): + reward = 500 + elif(self.old_collectable_counter >= 10): + reward = 100 + elif(self.old_collectable_counter >= 5): + reward = -100 + elif(self.old_collectable_counter >= 1): + reward = -500 + else: + reward = -1000 + + self.previous_state = obs + return reward + + def reset(self) -> None: + self.previous_state = None + self.old_collectable_counter = STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS + + def filter_non_mineral_shard_units(self, obs): + filtered_map = np.zeros((len(obs.feature_minimap[0]), + len(obs.feature_minimap[0][0]))) + for unit in sc2aux.get_all_neutral_units(obs): + filtered_map[unit.y][unit.x] = 1 + + return filtered_map \ No newline at end of file diff --git a/urnai/sc2/states/collectables.py b/urnai/sc2/states/collectables.py new file mode 100644 index 0000000..478b0f6 --- /dev/null +++ b/urnai/sc2/states/collectables.py @@ -0,0 +1,224 @@ +import math +from statistics import mean + +import numpy as np +from pysc2.lib import units as sc2units + +import urnai.sc2.actions.sc2_actions_aux as sc2aux +from urnai.states.state_base import StateBase + +# from urnai.utils.constants import Games, RTSGeneralization + +STATE_MAP = 'map' +STATE_NON_SPATIAL = 'non_spatial_only' +STATE_BOTH = 'map_and_non_spatial' +STATE_MAP_DEFAULT_REDUCTIONFACTOR = 1 +STATE_MAX_COLL_DIST = 15 + + +class CollectablesState(StateBase): + + def __init__(self, trim_map=False): + self.previous_state = None + self.method = STATE_MAP + # number of quadrants is the amount of parts + # the map should be reduced + # this helps the agent to + # deal with the big size + # of state space + # if -1 (default value), the map + # wont be reduced + self.map_reduction_factor = STATE_MAP_DEFAULT_REDUCTIONFACTOR + self.non_spatial_maximums = [ + STATE_MAX_COLL_DIST, + STATE_MAX_COLL_DIST, + # 
RTSGeneralization.STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS, + ] + self.non_spatial_minimums = [ + 0, + 0, + # 0, + ] + # non-spatial is composed of + # X distance to next mineral shard + # Y distance to next mineral shard + # number of mineral shards left + self.non_spatial_state = [ + 0, + 0, + # 0, + ] + self.trim_map = trim_map + self.reset() + + def update(self, obs): + state = [] + if self.method == STATE_MAP: + state = self.build_map(obs) + elif self.method == STATE_NON_SPATIAL: + state = self.build_non_spatial_state(obs) + elif self.method == STATE_BOTH: + state = self.build_map(obs) + state += self.build_non_spatial_state(obs) + + state = np.asarray(state).flatten() + self._dimension = len(state) + # state = state.reshape((1, len(state))) + self._state = state + + return state + + def build_map(self, obs): + map_ = self.build_basic_map(obs) + map_ = self.reduce_map(map_) + map_ = self.normalize_map(map_) + + return map_ + + def build_basic_map(self, obs): + + map_ = np.zeros(obs.feature_minimap[0].shape) + marines = sc2aux.get_units_by_type(obs, sc2units.Terran.Marine) + shards = sc2aux.get_all_neutral_units(obs) + + for marine in marines: + map_[marine.y][marine.x] = 7 + + for shard in shards: + map_[shard.y][shard.x] = 100 + + return map_ + + def normalize_map(self, map_): + # map = (map_ - map_.min()) / (map_.max() - map_.min()) + return map_ + + def normalize_non_spatial_list(self): + for i in range(len(self.non_spatial_state)): + value = self.non_spatial_state[i] + max_ = self.non_spatial_maximums[i] + min_ = self.non_spatial_minimums[i] + value = self.normalize_value(value, max_, min_) + self.non_spatial_state[i] = value + + def normalize_value(self, value, max_, min_=0): + return (value - min_) / (max_ - min_) + + # TODO: Remove magic numbers + @property + def dimension(self): + if self.method == STATE_MAP: + if self.trim_map: + a = int(22 / self.map_reduction_factor) + b = int(16 / self.map_reduction_factor) + else: + a = int(64 / 
self.map_reduction_factor) + b = int(64 / self.map_reduction_factor) + return int(a * b) + elif self.method == STATE_NON_SPATIAL: + return len(self.non_spatial_state) + + @property + def state(self): + return self._state + + def get_marine_mean(self, obs): + xs = [] + ys = [] + + for unit in sc2aux.get_units_by_type(obs, sc2units.Terran.Marine): + xs.append(unit.x) + ys.append(unit.y) + + x_mean = mean(xs) + y_mean = mean(ys) + + return x_mean, y_mean + + def get_closest_mineral_shard_x_y(self, obs): + closest_distance = STATE_MAX_COLL_DIST + x, y = self.get_marine_mean(obs) + x_closest_distance, y_closest_distance = -1, -1 + for mineral_shard in sc2aux.get_all_neutral_units(obs): + mineral_shard_x = mineral_shard.x + mineral_shard_y = mineral_shard.y + dist = self.calculate_distance(x, y, mineral_shard_x, mineral_shard_y) + if dist < closest_distance: + closest_distance = dist + x_closest_distance = x - mineral_shard_x + y_closest_distance = y - mineral_shard_y + + return abs(x_closest_distance), abs(y_closest_distance) + + def build_non_spatial_state(self, obs): + x, y = self.get_closest_mineral_shard_x_y(obs) + # position 0: distance x to closest shard + self.non_spatial_state[0] = int(x) + # position 1: distance y to closest shard + self.non_spatial_state[1] = int(y) + # position 2: number of remaining shards + # self.non_spatial_state[2]=np.count_nonzero(obs.feature_minimap[4]==16) + self.normalize_non_spatial_list() + return self.non_spatial_state + + def calculate_distance(self, x1, y1, x2, y2): + dist = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + return dist + + def reduce_map(self, map_): + if self.trim_map: + x1, y1 = 22, 28 + x2, y2 = 43, 43 + map_ = self.trim_matrix(map_, x1, y1, x2, y2) + return self.lower_featuremap_resolution(map_, self.map_reduction_factor) + + def reset(self): + self._state = None + self._dimension = None + + def trim_matrix(matrix, x1, y1, x2, y2): + """ + If you have a 2D numpy array + and you want a submatrix of that array, + 
you can use this function to extract it. + You just need to tell this function + what are the top-left and bottom-right + corners of this submatrix, by setting + x1, y1 and x2, y2. + For example: some maps of StarCraft II + have parts that are not walkable, this + happens specially in PySC2 mini-games + where only a small portion of the map + is walkable. So, you may want to trim + this big map (generally a 64x64 matrix) + and leave only the useful parts. + """ + matrix = np.delete(matrix, np.s_[0:x1:1], 1) + matrix = np.delete(matrix, np.s_[0:y1:1], 0) + matrix = np.delete(matrix, np.s_[x2 - x1 + 1::1], 1) + matrix = np.delete(matrix, np.s_[y2 - y1 + 1::1], 0) + return matrix + + def lower_featuremap_resolution(self, map, rf): # rf = reduction_factor + """ + Reduces a matrix "resolution" by a reduction factor. If we have a 64x64 matrix + and rf=4 the map will be reduced to 16x16 in which every new element of the + matrix is an average from 4x4=16 elements from the original matrix + """ + if rf == 1: + return map + + N, M = map.shape + N = N // rf + M = M // rf + + reduced_map = np.empty((N, M)) + for i in range(N): + for j in range(M): + # reduction_array = map[rf*i:rf*i+rf, rf*j:rf*j+rf].flatten() + # reduced_map[i,j] = Counter(reduction_array).most_common(1)[0][0] + + reduced_map[i, j] = ((map[rf * i:rf * i + rf, rf * j:rf * j + rf].sum()) + / (rf * rf)) + + return reduced_map \ No newline at end of file From 51b5095dc6d0b21fdc9ff2685bb8ed172447213a Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:29:31 -0300 Subject: [PATCH 2/8] feat: Split part of ActionBase into ActionBaseStrict --- urnai/actions/action_base.py | 14 ++------------ urnai/actions/action_base_strict.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 12 deletions(-) create mode 100644 urnai/actions/action_base_strict.py diff --git a/urnai/actions/action_base.py b/urnai/actions/action_base.py index 
bcfb43a..63a722c 100644 --- a/urnai/actions/action_base.py +++ b/urnai/actions/action_base.py @@ -1,21 +1,11 @@ from abc import ABC, abstractmethod +from typing import Any class ActionBase(ABC): __id__ = None @abstractmethod - def run(self) -> None: + def run(*args) -> Any: """Executing the action""" ... - - @abstractmethod - def check(self, obs) -> bool: - """Returns whether the action can be executed or not""" - ... - - @property - @abstractmethod - def is_complete(self) -> bool: - """Returns whether the action has finished or not""" - ... diff --git a/urnai/actions/action_base_strict.py b/urnai/actions/action_base_strict.py new file mode 100644 index 0000000..3b98995 --- /dev/null +++ b/urnai/actions/action_base_strict.py @@ -0,0 +1,21 @@ +from abc import abstractmethod +from typing import Any +from urnai.actions.action_base import ActionBase + +class ActionBaseStrict(ActionBase): + + @abstractmethod + def run(*args) -> Any: + """Executing the action""" + ... + + @abstractmethod + def check(self, obs) -> bool: + """Returns whether the action can be executed or not""" + ... + + @property + @abstractmethod + def is_complete(self) -> bool: + """Returns whether the action has finished or not""" + ... 
From 5ad5104541994293d0fac259ba08fa8ede9157b5 Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:32:10 -0300 Subject: [PATCH 3/8] feat: Made automatic class generation for sc2 actions All of these classes inherit from ActionBase --- urnai/sc2/actions/collectables.py | 24 +++++++++++++----------- urnai/sc2/actions/sc2_action.py | 9 --------- urnai/sc2/actions/sc2_actions.py | 22 ++++++++++++++++++++++ 3 files changed, 35 insertions(+), 20 deletions(-) delete mode 100644 urnai/sc2/actions/sc2_action.py create mode 100644 urnai/sc2/actions/sc2_actions.py diff --git a/urnai/sc2/actions/collectables.py b/urnai/sc2/actions/collectables.py index 8211bc2..4c1f183 100644 --- a/urnai/sc2/actions/collectables.py +++ b/urnai/sc2/actions/collectables.py @@ -5,13 +5,13 @@ from urnai.actions.action_space_base import ActionSpaceBase from urnai.sc2.actions import sc2_actions_aux as scaux -from urnai.sc2.actions.sc2_action import SC2Action +from urnai.sc2.actions.sc2_actions import sc2_raw_action_classes as sc2_actions class CollectablesActionSpace(ActionSpaceBase): def __init__(self): - self.noaction = [actions.RAW_FUNCTIONS.no_op()] + self.noaction = [sc2_actions["no_op"]().run()] #actions.RAW_FUNCTIONS.no_op() self.move_number = 0 self.hor_threshold = 2 @@ -48,7 +48,7 @@ def get_excluded_actions(self, obs): def get_action(self, action_idx, obs): action = None if len(self.pending_actions) == 0: - action = [actions.RAW_FUNCTIONS.no_op()] + action = self.noaction else: action = [self.pending_actions.pop()] self.solve_action(action_idx, obs) @@ -85,8 +85,10 @@ def move_left(self, obs): for unit in army: self.pending_actions.append( - SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, - 'now', unit.tag, [new_army_x, new_army_y])) + sc2_actions["Move_pt"]().run( + 'now', unit.tag,[new_army_x, new_army_y])) + #SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, + #'now', unit.tag, [new_army_x, new_army_y])) def move_right(self, 
obs): army = scaux.select_army(obs, sc2_env.Race.terran) @@ -98,8 +100,8 @@ def move_right(self, obs): for unit in army: self.pending_actions.append( - SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, - 'now', unit.tag, [new_army_x, new_army_y])) + sc2_actions["Move_pt"]().run( + 'now', unit.tag,[new_army_x, new_army_y])) def move_down(self, obs): army = scaux.select_army(obs, sc2_env.Race.terran) @@ -111,8 +113,8 @@ def move_down(self, obs): for unit in army: self.pending_actions.append( - SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, - 'now', unit.tag, [new_army_x, new_army_y])) + sc2_actions["Move_pt"]().run( + 'now', unit.tag,[new_army_x, new_army_y])) def move_up(self, obs): army = scaux.select_army(obs, sc2_env.Race.terran) @@ -124,8 +126,8 @@ def move_up(self, obs): for unit in army: self.pending_actions.append( - SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, - 'now', unit.tag, [new_army_x, new_army_y])) + sc2_actions["Move_pt"]().run( + 'now', unit.tag,[new_army_x, new_army_y])) def get_action_name_str_by_int(self, action_int): action_str = '' diff --git a/urnai/sc2/actions/sc2_action.py b/urnai/sc2/actions/sc2_action.py deleted file mode 100644 index d2d1047..0000000 --- a/urnai/sc2/actions/sc2_action.py +++ /dev/null @@ -1,9 +0,0 @@ -from pysc2.lib.actions import FunctionCall - - -class SC2Action: - - """This class encapsulates the usage of actions from pysc2""" - - def run(action_function, *args) -> FunctionCall: - return action_function(*args) diff --git a/urnai/sc2/actions/sc2_actions.py b/urnai/sc2/actions/sc2_actions.py new file mode 100644 index 0000000..2e8600b --- /dev/null +++ b/urnai/sc2/actions/sc2_actions.py @@ -0,0 +1,22 @@ +from pysc2.lib import actions +from urnai.actions.action_base import ActionBase + +sc2_raw_action_classes = {} + +def constructor(self): + ... 
+ +def run_method(self, *args) -> actions.FunctionCall: + return self.my_action_function(*args) + +for sc2_action in actions.RAW_FUNCTIONS: + + sc2_raw_action_class = type(sc2_action.name, (ActionBase,), { + + "__init__" : constructor, + "my_action_function": sc2_action, + "run": run_method, + + }) + + sc2_raw_action_classes[sc2_action.name] = sc2_raw_action_class \ No newline at end of file From 22bac1986b1e4e4478cce351cd8b7986c348297d Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:27:06 -0300 Subject: [PATCH 4/8] feat: Change inherited classe's methods to static --- urnai/sc2/actions/collectables.py | 10 +++++----- urnai/sc2/actions/sc2_actions.py | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/urnai/sc2/actions/collectables.py b/urnai/sc2/actions/collectables.py index 4c1f183..2735a64 100644 --- a/urnai/sc2/actions/collectables.py +++ b/urnai/sc2/actions/collectables.py @@ -11,7 +11,7 @@ class CollectablesActionSpace(ActionSpaceBase): def __init__(self): - self.noaction = [sc2_actions["no_op"]().run()] #actions.RAW_FUNCTIONS.no_op() + self.noaction = [sc2_actions["no_op"].run()] #actions.RAW_FUNCTIONS.no_op() self.move_number = 0 self.hor_threshold = 2 @@ -85,7 +85,7 @@ def move_left(self, obs): for unit in army: self.pending_actions.append( - sc2_actions["Move_pt"]().run( + sc2_actions["Move_pt"].run( 'now', unit.tag,[new_army_x, new_army_y])) #SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, #'now', unit.tag, [new_army_x, new_army_y])) @@ -100,7 +100,7 @@ def move_right(self, obs): for unit in army: self.pending_actions.append( - sc2_actions["Move_pt"]().run( + sc2_actions["Move_pt"].run( 'now', unit.tag,[new_army_x, new_army_y])) def move_down(self, obs): @@ -113,7 +113,7 @@ def move_down(self, obs): for unit in army: self.pending_actions.append( - sc2_actions["Move_pt"]().run( + sc2_actions["Move_pt"].run( 'now', unit.tag,[new_army_x, new_army_y])) def 
move_up(self, obs): @@ -126,7 +126,7 @@ def move_up(self, obs): for unit in army: self.pending_actions.append( - sc2_actions["Move_pt"]().run( + sc2_actions["Move_pt"].run( 'now', unit.tag,[new_army_x, new_army_y])) def get_action_name_str_by_int(self, action_int): diff --git a/urnai/sc2/actions/sc2_actions.py b/urnai/sc2/actions/sc2_actions.py index 2e8600b..e3b70c5 100644 --- a/urnai/sc2/actions/sc2_actions.py +++ b/urnai/sc2/actions/sc2_actions.py @@ -3,11 +3,12 @@ sc2_raw_action_classes = {} -def constructor(self): +def constructor(): ... -def run_method(self, *args) -> actions.FunctionCall: - return self.my_action_function(*args) +@classmethod +def run_method(cls, *args) -> actions.FunctionCall: + return cls.my_action_function(*args) for sc2_action in actions.RAW_FUNCTIONS: From a2ff779cf99472acff2ecdff72aef1ed0fb5fa66 Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:35:02 -0300 Subject: [PATCH 5/8] feat: Add documentation about action usage --- docs/pysc2_usage.md | 23 +++++++++++++++++++++++ urnai/sc2/actions/sc2_actions.py | 6 +++--- 2 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 docs/pysc2_usage.md diff --git a/docs/pysc2_usage.md b/docs/pysc2_usage.md new file mode 100644 index 0000000..16afa1c --- /dev/null +++ b/docs/pysc2_usage.md @@ -0,0 +1,23 @@ +# PySC2 usage in URNAI + +## Actions in URNAI vs Actions in PySC2 + +In PySC2 you would normally call an action in the following manner: `actions.X.F`. Where `X` is the action function set, such as `RAW_FUNCTIONS` or `FUNCTIONS`, and `F` is the action function you are calling, such as `no_op` or `Move_pt`. + +But in URNAI, due to encapsulation, the call is done in the following manner: `X[F].run()`. Where `X` is a dict containing classes for each of the action functions in PySC2, and `F` is a string with the name of the action function you are calling. 
The dictionaries mentioned above are stored in the file `sc2_actions.py`, and can be imported such as in the example below: + +```python +from urnai.sc2.actions.sc2_actions import sc2_raw_action_classes as sc2_actions +``` + +### Examples: +```python +actions.RAW_FUNCTIONS.no_op() #PySC2 +sc2_actions["no_op"].run() #URNAI + +actions.RAW_FUNCTIONS.Move_pt('now', unit.tag, [new_army_x, new_army_y]) #PySC2 +sc2_actions["Move_pt"].run('now', unit.tag,[new_army_x, new_army_y]) #URNAI +``` + +### Why? +URNAI chooses to represent actions as classes so they can be better organized and tested. Therefore, this encapsulation, which transforms each PySC2 action into a class, contributes to all environments being done under a single system, something that is desired. \ No newline at end of file diff --git a/urnai/sc2/actions/sc2_actions.py b/urnai/sc2/actions/sc2_actions.py index e3b70c5..eeeb731 100644 --- a/urnai/sc2/actions/sc2_actions.py +++ b/urnai/sc2/actions/sc2_actions.py @@ -14,9 +14,9 @@ def run_method(cls, *args) -> actions.FunctionCall: sc2_raw_action_class = type(sc2_action.name, (ActionBase,), { - "__init__" : constructor, - "my_action_function": sc2_action, - "run": run_method, + "__init__" : constructor, + "my_action_function": sc2_action, + "run": run_method, }) From 5957af822ac5cfdd2d6cb8b5bc64d4f68f79f1f6 Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Thu, 7 Nov 2024 15:01:10 -0300 Subject: [PATCH 6/8] refactor: Adjusting to lint's suggestions --- docs/pysc2_usage.md | 21 ++++++++++++++++----- urnai/actions/action_base_strict.py | 2 ++ urnai/sc2/actions/collectables.py | 7 ++----- urnai/sc2/actions/sc2_actions.py | 24 +++++++++++++++++++++--- 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/docs/pysc2_usage.md b/docs/pysc2_usage.md index 16afa1c..072b13c 100644 --- a/docs/pysc2_usage.md +++ b/docs/pysc2_usage.md @@ -2,15 +2,23 @@ ## Actions in URNAI vs Actions in PySC2 -In PySC2 you would 
normally call an action in the following manner: `actions.X.F`. Where `X` is the action function set, such as `RAW_FUNCTIONS` or `FUNCTIONS`, and `F` is the action function you are calling, such as `no_op` or `Move_pt`. +In PySC2 you would normally call an action in the following manner: +`actions.X.F`, where `X` is the action function set, such as `RAW_FUNCTIONS` +or `FUNCTIONS`, and `F` is the action function you are calling, +such as `no_op` or `Move_pt`. -But in URNAI, due to encapsulation, the call is done in the following manner: `X[F].run()`. Where `X` is a dict containing classes for each of the action functions in PySC2, and `F` is a string with the name of the action function you are calling. The dictionaries mentioned above are stored in the file `sc2_actions.py`, and can be imported such as in the example below: +But in URNAI, due to encapsulation, the call is done in the following manner: +`X[F].run()`, where `X` is a dict containing classes for each of the action functions +in PySC2, and `F` is a string with the name of the action function you are calling. +The dictionaries mentioned above are stored in the file `sc2_actions.py`, and can +be imported as in the example below: ```python -from urnai.sc2.actions.sc2_actions import sc2_raw_action_classes as sc2_actions +from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions ``` -### Examples: +### Examples + ```python actions.RAW_FUNCTIONS.no_op() #PySC2 sc2_actions["no_op"].run() #URNAI @@ -20,4 +28,7 @@ sc2_actions["Move_pt"].run('now', unit.tag,[new_army_x, new_army_y]) #URNAI ``` ### Why? -URNAI chooses to represent actions as classes so they can be better organized and tested. Therefore, this encapsulation, which transforms each PySC2 action into a class, contributes to all environments being done under a single system, something that is desired. \ No newline at end of file + +URNAI chooses to represent actions as classes so they can be better organized and tested.
+Therefore, this encapsulation, which transforms each PySC2 action into a class, contributes +to all environments being done under a single system, something that is desired. diff --git a/urnai/actions/action_base_strict.py b/urnai/actions/action_base_strict.py index 3b98995..f5d5181 100644 --- a/urnai/actions/action_base_strict.py +++ b/urnai/actions/action_base_strict.py @@ -1,7 +1,9 @@ from abc import abstractmethod from typing import Any + from urnai.actions.action_base import ActionBase + class ActionBaseStrict(ActionBase): @abstractmethod diff --git a/urnai/sc2/actions/collectables.py b/urnai/sc2/actions/collectables.py index 2735a64..33e6117 100644 --- a/urnai/sc2/actions/collectables.py +++ b/urnai/sc2/actions/collectables.py @@ -1,17 +1,16 @@ from statistics import mean from pysc2.env import sc2_env -from pysc2.lib import actions from urnai.actions.action_space_base import ActionSpaceBase from urnai.sc2.actions import sc2_actions_aux as scaux -from urnai.sc2.actions.sc2_actions import sc2_raw_action_classes as sc2_actions +from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions class CollectablesActionSpace(ActionSpaceBase): def __init__(self): - self.noaction = [sc2_actions["no_op"].run()] #actions.RAW_FUNCTIONS.no_op() + self.noaction = [sc2_actions["no_op"].run()] self.move_number = 0 self.hor_threshold = 2 @@ -87,8 +86,6 @@ def move_left(self, obs): self.pending_actions.append( sc2_actions["Move_pt"].run( 'now', unit.tag,[new_army_x, new_army_y])) - #SC2Action.run(actions.RAW_FUNCTIONS.Move_pt, - #'now', unit.tag, [new_army_x, new_army_y])) def move_right(self, obs): army = scaux.select_army(obs, sc2_env.Race.terran) diff --git a/urnai/sc2/actions/sc2_actions.py b/urnai/sc2/actions/sc2_actions.py index eeeb731..bdab4f6 100644 --- a/urnai/sc2/actions/sc2_actions.py +++ b/urnai/sc2/actions/sc2_actions.py @@ -1,7 +1,13 @@ from pysc2.lib import actions + from urnai.actions.action_base import ActionBase -sc2_raw_action_classes = {} 
+""" +This file creates a dict which stores a class for each of the actions in PySC2. +""" + +raw_functions_classes = {} +functions_classes = {} def constructor(): ... @@ -12,7 +18,19 @@ def run_method(cls, *args) -> actions.FunctionCall: for sc2_action in actions.RAW_FUNCTIONS: - sc2_raw_action_class = type(sc2_action.name, (ActionBase,), { + raw_function_class = type(sc2_action.name, (ActionBase,), { + + "__init__" : constructor, + "my_action_function": sc2_action, + "run": run_method, + + }) + + raw_functions_classes[sc2_action.name] = raw_function_class + +for sc2_action in actions.FUNCTIONS: + + functions_class = type(sc2_action.name, (ActionBase,), { "__init__" : constructor, "my_action_function": sc2_action, @@ -20,4 +38,4 @@ def run_method(cls, *args) -> actions.FunctionCall: }) - sc2_raw_action_classes[sc2_action.name] = sc2_raw_action_class \ No newline at end of file + functions_classes[sc2_action.name] = functions_class From e0fadbab28125c843beaadd17ea13c769c5bc340 Mon Sep 17 00:00:00 2001 From: CinquilCinquil <106356391+CinquilCinquil@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:44:20 -0300 Subject: [PATCH 7/8] test(issue-115): Adapted tests for new ActionBase #115 --- tests/units/actions/test_action_base.py | 4 --- .../units/actions/test_action_base_strict.py | 29 +++++++++++++++++++ tests/units/sc2/actions/test_sc2_action.py | 8 ++--- 3 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 tests/units/actions/test_action_base_strict.py diff --git a/tests/units/actions/test_action_base.py b/tests/units/actions/test_action_base.py index b8fd7d0..26d7ed8 100644 --- a/tests/units/actions/test_action_base.py +++ b/tests/units/actions/test_action_base.py @@ -18,12 +18,8 @@ def test_abstract_methods(self): # WHEN run_return = fake_action.run() - check_return = fake_action.check("observation") - is_complete_return = fake_action.is_complete # THEN assert fake_action.__id__ is None assert isinstance(ActionBase, ABCMeta) assert run_return is 
None - assert check_return is None - assert is_complete_return is None diff --git a/tests/units/actions/test_action_base_strict.py b/tests/units/actions/test_action_base_strict.py new file mode 100644 index 0000000..7989449 --- /dev/null +++ b/tests/units/actions/test_action_base_strict.py @@ -0,0 +1,29 @@ +import unittest +from abc import ABCMeta + +from urnai.actions.action_base_strict import ActionBaseStrict + + +class FakeActionStrict(ActionBaseStrict): + ActionBaseStrict.__abstractmethods__ = set() + __id__ = None + ... + +class TestActionBase(unittest.TestCase): + + def test_abstract_methods(self): + + # GIVEN + fake_action = FakeActionStrict() + + # WHEN + run_return = fake_action.run() + check_return = fake_action.check("observation") + is_complete_return = fake_action.is_complete + + # THEN + assert fake_action.__id__ is None + assert isinstance(ActionBaseStrict, ABCMeta) + assert run_return is None + assert check_return is None + assert is_complete_return is None diff --git a/tests/units/sc2/actions/test_sc2_action.py b/tests/units/sc2/actions/test_sc2_action.py index 865b3a9..af41e1e 100644 --- a/tests/units/sc2/actions/test_sc2_action.py +++ b/tests/units/sc2/actions/test_sc2_action.py @@ -2,7 +2,7 @@ from pysc2.lib import actions -from urnai.sc2.actions.sc2_action import SC2Action +from urnai.sc2.actions.sc2_actions import raw_functions_classes as sc2_actions _BUILD_REFINERY = actions.RAW_FUNCTIONS.Build_Refinery_pt _NO_OP = actions.FUNCTIONS.no_op @@ -11,9 +11,9 @@ class TestSC2Action(unittest.TestCase): def test_run(self): - run_no_op = SC2Action.run(_NO_OP) - run_build_refinery = SC2Action.run(_BUILD_REFINERY, 'now', 0) - + run_no_op = sc2_actions["no_op"].run() + run_build_refinery = sc2_actions["Build_Refinery_pt"].run('now', 0) + self.assertEqual(run_no_op.function, _NO_OP.id) self.assertEqual(run_no_op.arguments, []) From 4eb86abe40fc7693052563b0a5d60e37cbec0073 Mon Sep 17 00:00:00 2001 From: CinquilCinquil 
<106356391+CinquilCinquil@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:45:32 -0300 Subject: [PATCH 8/8] feat: Remove constructors from action classes #115 --- urnai/sc2/actions/sc2_actions.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/urnai/sc2/actions/sc2_actions.py b/urnai/sc2/actions/sc2_actions.py index bdab4f6..2df9880 100644 --- a/urnai/sc2/actions/sc2_actions.py +++ b/urnai/sc2/actions/sc2_actions.py @@ -9,9 +9,6 @@ raw_functions_classes = {} functions_classes = {} -def constructor(): - ... - @classmethod def run_method(cls, *args) -> actions.FunctionCall: return cls.my_action_function(*args) @@ -20,7 +17,6 @@ def run_method(cls, *args) -> actions.FunctionCall: raw_function_class = type(sc2_action.name, (ActionBase,), { - "__init__" : constructor, "my_action_function": sc2_action, "run": run_method, @@ -32,7 +28,6 @@ def run_method(cls, *args) -> actions.FunctionCall: functions_class = type(sc2_action.name, (ActionBase,), { - "__init__" : constructor, "my_action_function": sc2_action, "run": run_method,