Skip to content

Commit

Permalink
POMCP [WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
Limmen committed Jan 19, 2024
1 parent 01e7f30 commit b6aa05c
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 30 deletions.
7 changes: 6 additions & 1 deletion examples/manual_play/cyborg_restore_defender.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
csle_cyborg_env = CyborgScenarioTwoDefender(config=config)
o, info = csle_cyborg_env.reset()
initial_state_id = info[env_constants.ENV_METRICS.STATE]
csle_cyborg_env.step(1)
o, r, done, _, info = csle_cyborg_env.step(1)
# csle_cyborg_env.get_table()
obs = info[env_constants.ENV_METRICS.OBSERVATION]
# print("FIRST OBS:")
# print(csle_cyborg_env.get_observation_from_id(obs_id=obs))
csle_cyborg_env.set_state(state=initial_state_id)
# print(csle_cyborg_env.cyborg_challenge_env.env.env.env.env.env.environment_controller.observation["Red"].data["User0"])
csle_cyborg_env.step(1)
csle_cyborg_env.get_table()

# print("INITIAL2 STATE")
# print(csle_cyborg_env.get_true_table())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
descr="the planning time"),
agents_constants.POMCP.MAX_PARTICLES: HParam(value=1000, name=agents_constants.POMCP.MAX_PARTICLES,
descr="the maximum number of belief particles"),
agents_constants.POMCP.MAX_DEPTH: HParam(value=500, name=agents_constants.POMCP.MAX_DEPTH,
agents_constants.POMCP.MAX_DEPTH: HParam(value=100, name=agents_constants.POMCP.MAX_DEPTH,
descr="the maximum depth for planning"),
agents_constants.POMCP.C: HParam(value=0.35, name=agents_constants.POMCP.C,
descr="the weighting factor for UCB exploration"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,
done = False
eval_env = gym.make(self.simulation_env_config.gym_env_name, config=config)
train_env: BaseEnv = gym.make(self.simulation_env_config.gym_env_name, config=config)
eval_env.reset()
_, info = eval_env.reset()
s = info[agents_constants.COMMON.STATE]
train_env.reset()
belief = b1.copy()
pomcp = POMCP(A=A, gamma=gamma, env=train_env, c=c, initial_belief=belief,
Expand All @@ -218,7 +219,7 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,
R = 0
t = 1
if t % log_steps_frequency == 0:
Logger.__call__().get_logger().info(f"[POMCP] t: {t}, b: {belief}")
Logger.__call__().get_logger().info(f"[POMCP] t: {t}, b: {belief}, s: {s}")
# Run episode
while not done and t <= max_env_steps:
pomcp.solve(max_depth=max_depth)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,21 @@ def populate_info(self, info: Dict[str, Any], obs: npt.NDArray[Any], reset: bool
(deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.state),
deepcopy(self.cyborg_challenge_env.env.env.env.env.scanned_ips),
agent_interfaces_copy,
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.done,
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.done),
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.reward),
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.actions),
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.step,
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.step),
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.hostname_ip_map),
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.subnet_cidr_map),
deepcopy(self.cyborg_challenge_env.env.env.env.env.env.environment_controller.observation),
self.cyborg_challenge_env.env.env.env.env.step_counter
deepcopy(self.cyborg_challenge_env.env.env.env.env.step_counter),
deepcopy(self.cyborg_challenge_env.env.env.env.success),
deepcopy(self.cyborg_challenge_env.env.env.env.baseline),
deepcopy(self.cyborg_challenge_env.env.env.env.info),
deepcopy(self.cyborg_challenge_env.env.env.env.blue_info)
)
self.visited_scanned_states[state_id] = self.scan_state.copy()
self.visited_decoy_states[state_id] = self.decoy_state.copy()
self.visited_scanned_states[state_id] = deepcopy(self.scan_state)
self.visited_decoy_states[state_id] = deepcopy(self.decoy_state)
return info

def get_table(self) -> PrettyTable:
Expand Down Expand Up @@ -411,32 +415,35 @@ def set_state(self, state: Any) -> None:
s = int(state)
if s in self.visited_cyborg_states:
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.state = \
self.visited_cyborg_states[s][0]
self.cyborg_challenge_env.env.env.env.env.scanned_ips = self.visited_cyborg_states[s][1]
deepcopy(self.visited_cyborg_states[s][0])
self.cyborg_challenge_env.env.env.env.env.scanned_ips = deepcopy(self.visited_cyborg_states[s][1])
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.agent_interfaces \
= self.visited_cyborg_states[s][2]
= deepcopy(self.visited_cyborg_states[s][2])
for k, v in self.cyborg_challenge_env.env.env.env.env.env.environment_controller.agent_interfaces.items():
v.action_space.create_action_params()
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.done = self.visited_cyborg_states[s][3]
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.done = (
deepcopy(self.visited_cyborg_states[s][3]))
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.reward = \
self.visited_cyborg_states[s][4]
deepcopy(self.visited_cyborg_states[s][4])
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.actions = \
self.visited_cyborg_states[s][5]
deepcopy(self.visited_cyborg_states[s][5])
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.step = \
self.visited_cyborg_states[s][6]
deepcopy(self.visited_cyborg_states[s][6])
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.hostname_ip_map = \
self.visited_cyborg_states[s][7]
deepcopy(self.visited_cyborg_states[s][7])
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.subnet_cidr_map = \
self.visited_cyborg_states[s][8]
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.observation = \
self.visited_cyborg_states[s][9]
self.cyborg_challenge_env.env.env.env.env.step_counter = self.visited_cyborg_states[s][10]
# self.cyborg_challenge_env.env.env.env.env.observation_change(
# self.cyborg_challenge_env.env.env.env.env.env.environment_controller.observation)
# self.cyborg_challenge_env.env.env.env.observation_change(
# self.cyborg_challenge_env.env.env.env.env.env.environment_controller.observation)
self.decoy_state = self.visited_decoy_states[s]
self.scan_state = self.visited_scanned_states[s]
deepcopy(self.visited_cyborg_states[s][8])
obs = deepcopy(self.visited_cyborg_states[s][9])
obs["Blue"].data["success"] = self.visited_cyborg_states[s][11]
self.cyborg_challenge_env.env.env.env.env.env.environment_controller.observation = obs
self.cyborg_challenge_env.env.env.env.env.step_counter = deepcopy(self.visited_cyborg_states[s][10])
self.cyborg_challenge_env.env.env.env.baseline = deepcopy(self.visited_cyborg_states[s][12])
self.cyborg_challenge_env.env.env.env.info = deepcopy(self.visited_cyborg_states[s][13])
self.cyborg_challenge_env.env.env.env.blue_info = deepcopy(self.visited_cyborg_states[s][14])
self.decoy_state = deepcopy(self.visited_decoy_states[s])
self.scan_state = deepcopy(self.visited_scanned_states[s])
self.cyborg_challenge_env.env.env.env.env.observation_change(obs)
self.cyborg_challenge_env.env.env.env.observation_change(obs["Blue"])
else:
raise NotImplementedError(f"Unknown state: {s}")

Expand Down Expand Up @@ -469,3 +476,21 @@ def manual_play(self) -> None:
:return: None
"""
return None

def get_observation_from_id(self, obs_id: int) -> List[List[int]]:
    """
    Decodes an observation id into its corresponding observation vector.

    :param obs_id: the observation id to decode
    :return: the decoded observation vector
    """
    observation_vector = CyborgEnvUtil.state_id_to_state_vector(state_id=obs_id, observation=True)
    return observation_vector

def get_state_from_id(self, state_id: int) -> List[List[int]]:
    """
    Converts a state id to a state vector

    :param state_id: the id to convert
    :return: the state vector
    """
    return CyborgEnvUtil.state_id_to_state_vector(state_id=state_id, observation=False)
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ def state_to_vector(state: List[List[Any]], decoy_state: List[List[BlueAgentActi
host_scanned = scan_state[host_id]
activity = ActivityType.from_str(state[host_id][3]).value
host_access = state[host_id][4]
if host_access == "None":
if host_access == "No" or host_access == "None":
host_access = 0
elif host_access == "User":
if host_access == "User":
host_access = 1
else:
if host_access == "Privileged":
host_access = 2
host_decoy_state = len(decoy_state[host_id])
if not observation:
Expand Down

0 comments on commit b6aa05c

Please sign in to comment.