Showing 9 changed files with 726 additions and 73 deletions.
examples/eval/cyborg_scenario_two/cage2_aggregate_mdp.py (480 additions, 0 deletions)
Large diffs are not rendered by default.
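The 480-line cage2_aggregate_mdp.py is not rendered in this diff. Based only on how the class is called from eval_aggregate_mdp.py below, its public interface looks roughly like the following sketch; the type hints, docstrings, and method bodies are assumptions inferred from that usage, not the actual implementation.

# Hypothetical interface sketch of Cage2AggregateMDP, inferred from its usage in
# eval_aggregate_mdp.py below; the real 480-line implementation is not shown here.
from typing import Dict, List, Tuple

import numpy as np


class Cage2AggregateMDP:
    """Aggregate MDP over the CAGE-2 defender state (interface only)."""

    @staticmethod
    def X() -> Tuple[List[int], Dict, Dict]:
        """Returns the aggregate state space X together with the lookup tables
        state_to_id and id_to_state (exact key and value types are assumed)."""
        raise NotImplementedError

    @staticmethod
    def get_aggregate_state(s, state_to_id: Dict) -> int:
        """Maps a full CyborgWrapperState s to the id of its aggregate state."""
        raise NotImplementedError

    @staticmethod
    def get_aggregate_control(mu: np.ndarray, aggregate_state: int, id_to_state: Dict) -> int:
        """Looks up the base-policy control mu[aggregate_state] and translates it
        into an action id of the wrapper environment."""
        raise NotImplementedError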
examples/eval/cyborg_scenario_two/eval_aggregate_mdp.py (149 additions, 0 deletions)
@@ -0,0 +1,149 @@
import numpy as np
import copy
from gym_csle_cyborg.envs.cyborg_scenario_two_wrapper import CyborgScenarioTwoWrapper
from gym_csle_cyborg.dao.red_agent_type import RedAgentType
from gym_csle_cyborg.dao.csle_cyborg_wrapper_config import CSLECyborgWrapperConfig
from gym_csle_cyborg.util.cyborg_env_util import CyborgEnvUtil
from gym_csle_cyborg.dao.cyborg_wrapper_state import CyborgWrapperState
from cyborg_agg_mdp import Cage2AggregateMDP

# def monte_carlo_most_frequent(elements, num_samples):
#     if not elements:
#         raise ValueError("The input list is empty.")
#
#     # Perform random sampling
#     samples = [random.choice(elements) for _ in range(num_samples)]
#
#     # Count occurrences of sampled elements
#     counter = Counter(samples)
#
#     # Find the most common element
#     most_frequent_element = counter.most_common(1)[0][0]
#     return most_frequent_element


# def particle_filter(particles, max_num_particles, train_env, action, obs):
#     new_particles = []
#     while len(particles) < max_num_particles:
#         x = random.choice(particles)
#         train_env.set_state(state=x)
#         _, r, _, _, info = train_env.step(action)
#         s_prime = info["s"]
#         o = info["o"]
#         if o == obs:
#             new_particles.append(s_prime)
#     return new_particles

def restore_policy(s: CyborgWrapperState):
    """
    Hard-coded restore policy: maps the state field s.s[host][2] of selected hosts
    (Ent0, Ent1, Ent2, Opserver, User1-4) to a defender action id; returns -1 if no
    host needs attention.
    """
    a = -1
    if s.s[1][2] == 2:
        a = 0  # Ent0
    if s.s[2][2] == 2:
        a = 1  # Ent1
    if s.s[3][2] == 2:
        a = 2  # Ent2
    if s.s[7][2] == 2:
        a = 3  # Opserver

    if s.s[1][2] == 1:
        a = 8  # Ent0
    if s.s[2][2] == 1:
        a = 9  # Ent1
    if s.s[3][2] == 1:
        a = 10  # Ent2
    if s.s[7][2] == 1:
        a = 11  # Opserver (index 7 assumed; the original re-checked s.s[3][2] here, likely a typo)
    if s.s[9][2] == 1:
        a = 22  # User1
    if s.s[10][2] == 1:
        a = 23  # User2
    if s.s[11][2] == 1:
        a = 24  # User3
    if s.s[12][2] == 1:
        a = 25  # User4
    return a

def rollout(s: CyborgWrapperState, train_env: CyborgScenarioTwoWrapper, J, state_to_id, mu, l, gamma=0.99):
    """
    l-step lookahead rollout: simulates each candidate control from state s in train_env
    and scores it with the negated one-step reward plus the discounted cost-to-go J of the
    resulting aggregate state; returns the minimizing control and its Q-value.
    """
    # U = [0, 1, 2, 3, 8, 9, 10, 11, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 35]
    # U = [27, 28, 29, 30, 31, 32, 35]
    U = [27, 28, 29, 30, 31, 32]
    Q_n = []
    for u in U:
        train_env.set_state(s)  # reset the simulator to s before simulating (the original only did this in the else branch)
        u_r = restore_policy(s=s)
        if u_r != -1:
            o, c, done, _, info = train_env.step(action=u_r)
            s_prime = info["s"]
            aggregate_state = Cage2AggregateMDP.get_aggregate_state(s=s_prime, state_to_id=state_to_id)
            if l == 1:
                return u_r, J[aggregate_state]
            else:
                returns = []
                for i in range(2):
                    returns.append(rollout(copy.deepcopy(s_prime), train_env=train_env, J=J,
                                           state_to_id=state_to_id, mu=mu, l=l - 1)[1])
                cost_to_go = np.mean(returns)
        else:
            o, c, done, _, info = train_env.step(action=u)
            s_prime = info["s"]
            aggregate_state = Cage2AggregateMDP.get_aggregate_state(s=s_prime, state_to_id=state_to_id)
            if l == 1:
                cost_to_go = J[aggregate_state]
            else:
                returns = []
                for i in range(2):
                    returns.append(rollout(copy.deepcopy(s_prime), train_env=train_env, J=J,
                                           state_to_id=state_to_id, mu=mu, l=l - 1)[1])
                cost_to_go = np.mean(returns)
        Q_n.append(-c + gamma * cost_to_go)
    # print(Q_n)
    # print(U[int(np.argmin(Q_n))])
    u_star = int(np.argmin(Q_n))
    return U[u_star], Q_n[u_star]

if __name__ == '__main__':
    config = CSLECyborgWrapperConfig(maximum_steps=100, gym_env_name="",
                                     save_trace=False, reward_shaping=False, scenario=2,
                                     red_agent_type=RedAgentType.B_LINE_AGENT)
    env = CyborgScenarioTwoWrapper(config=config)
    train_env = CyborgScenarioTwoWrapper(config=config)
    action_id_to_type_and_host, type_and_host_to_action_id \
        = CyborgEnvUtil.get_action_dicts(scenario=2, reduced_action_space=True, decoy_state=True,
                                         decoy_optimization=False)
    N = 10000
    max_env_steps = 100
    mu = np.loadtxt("./mu1.txt")  # base policy of the aggregate MDP
    J = np.loadtxt("./J1.txt")  # cost-to-go of the aggregate MDP
    X, state_to_id, id_to_state = Cage2AggregateMDP.X()
    gamma = 0.99
    l = 3  # lookahead depth
    returns = []
    for i in range(N):
        print(f"{i}/{N}")
        done = False
        _, info = env.reset()
        s = info["s"]
        t = 1
        R = 0
        particles = env.initial_particles
        while not done and t < max_env_steps:
            # monte_carlo_state = monte_carlo_most_frequent(elements=particles, num_samples=100)
            aggregate_state = Cage2AggregateMDP.get_aggregate_state(s=s, state_to_id=state_to_id)
            a = restore_policy(s=s)

            if t <= 1:
                a = 31
            if a == -1:
                a = Cage2AggregateMDP.get_aggregate_control(mu=mu, aggregate_state=aggregate_state,
                                                            id_to_state=id_to_state)
            # print(f"base: {a}")
            a = rollout(s=s, state_to_id=state_to_id, train_env=train_env, J=J, mu=mu, gamma=gamma, l=l)[0]
            # print(f"rollout: {a}")
            o, r, done, _, info = env.step(a)
            # particles = particle_filter(particles=particles, max_num_particles=1000,
            #                             train_env=train_env, action=a, obs=o)
            s = info["s"]
            t += 1
            R += r
            # print(f"t:{t}, r: {r}, a: {action_id_to_type_and_host[a]}, R: {R}, aggstate: {id_to_state[aggregate_state]}")
        returns.append(R)
        print(np.mean(returns))
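The evaluation script above loads a precomputed base policy (mu1.txt) and cost-to-go vector (J1.txt) indexed by the aggregate state ids returned by Cage2AggregateMDP.X(). How those files are generated is not shown in this diff; purely as an illustration, a tabular value-iteration sweep over a finite aggregate MDP could produce them along the lines of the sketch below, where the transition tensor P, the cost matrix C, and the output paths are hypothetical placeholders rather than the repository's actual code.

# Hypothetical sketch only: tabular value iteration for a finite aggregate MDP.
# P[u, x, y] = P(y | x, u) and C[x, u] are placeholder inputs; the real aggregate
# model is defined in cage2_aggregate_mdp.py, which is not rendered in this diff.
import numpy as np


def value_iteration(P: np.ndarray, C: np.ndarray, gamma: float = 0.99, tol: float = 1e-6):
    num_states, num_controls = C.shape
    J = np.zeros(num_states)
    while True:
        # Q[x, u] = C[x, u] + gamma * sum_y P(y | x, u) * J[y]
        Q = C + gamma * np.einsum("uxy,y->xu", P, J)
        J_new = Q.min(axis=1)
        if np.max(np.abs(J_new - J)) < tol:
            break
        J = J_new
    mu = Q.argmin(axis=1)  # greedy base policy w.r.t. the converged cost-to-go
    return J, mu


# Example usage with placeholder model arrays:
# J, mu = value_iteration(P, C)
# np.savetxt("./J1.txt", J)
# np.savetxt("./mu1.txt", mu)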
@@ -1,94 +1,112 @@
from typing import List
import numpy as np
import torch
import random
import json
import io
from gym_csle_cyborg.dao.csle_cyborg_config import CSLECyborgConfig
from gym_csle_cyborg.dao.red_agent_type import RedAgentType
from gym_csle_cyborg.envs.cyborg_scenario_two_defender import CyborgScenarioTwoDefender
import time
from gym_csle_cyborg.envs.cyborg_scenario_two_wrapper import CyborgScenarioTwoWrapper
from gym_csle_cyborg.dao.csle_cyborg_wrapper_config import CSLECyborgWrapperConfig
from csle_agents.agents.pomcp.pomcp import POMCP
from csle_agents.agents.pomcp.pomcp_acquisition_function_type import POMCPAcquisitionFunctionType
import csle_agents.constants.constants as agents_constants
from csle_common.logging.log import Logger
from gym_csle_cyborg.util.cyborg_env_util import CyborgEnvUtil
from gym_csle_cyborg.dao.red_agent_type import RedAgentType

def heuristic_value(o: List[List[int]]) -> float:
    """
    A heuristic value function

    :param o: the observation vector
    :return: the value
    """
    host_costs = CyborgEnvUtil.get_host_compromised_costs()
    val = 0
    for i in range(len(o)):
        if o[i][2] > 0:
            val += host_costs[i]
    return val

if __name__ == '__main__':
    # ppo_policy = PPOPolicy(model=None, simulation_name="", save_path="")
    config = CSLECyborgConfig(
        gym_env_name="csle-cyborg-scenario-two-v1", scenario=2, baseline_red_agents=[RedAgentType.B_LINE_AGENT],
        maximum_steps=100, red_agent_distribution=[1.0], reduced_action_space=True, decoy_state=True,
        scanned_state=True, decoy_optimization=False, cache_visited_states=False)
    eval_env = CyborgScenarioTwoDefender(config=config)
    config = CSLECyborgWrapperConfig(maximum_steps=100, gym_env_name="",
                                     save_trace=False, reward_shaping=False, scenario=2)
    config = CSLECyborgWrapperConfig(
        gym_env_name="csle-cyborg-scenario-two-wrapper-v1", maximum_steps=100, save_trace=False, scenario=2,
        reward_shaping=True, red_agent_type=RedAgentType.B_LINE_AGENT)
    eval_env = CyborgScenarioTwoWrapper(config=config)
    train_env = CyborgScenarioTwoWrapper(config=config)
    action_id_to_type_and_host, type_and_host_to_action_id \
        = CyborgEnvUtil.get_action_dicts(scenario=2, reduced_action_space=True, decoy_state=True,
                                         decoy_optimization=False)

    num_evaluations = 10
    max_horizon = 100
    returns = []
    seed = 215125
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    N = 5000
    rollout_policy = lambda x, deterministic: 35
    value_function = heuristic_value
    A = train_env.get_action_space()
    gamma = 0.75
    c = 1
    print("Starting policy evaluation")
    for i in range(num_evaluations):
    gamma = 0.99
    reinvigoration = False
    reinvigorated_particles_ratio = 0.0
    initial_particles = train_env.initial_particles
    planning_time = 3.75
    prune_action_space = False
    max_particles = 1000
    max_planning_depth = 50
    max_rollout_depth = 4
    c = 0.5
    c2 = 15000
    use_rollout_policy = False
    prior_weight = 5
    prior_confidence = 0
    acquisition_function_type = POMCPAcquisitionFunctionType.UCB
    log_steps_frequency = 1
    max_negative_samples = 20
    default_node_value = 0
    verbose = False
    eval_batch_size = 100
    max_env_steps = 100
    prune_size = 3
    start = time.time()

    # Run N episodes
    returns = []
    for i in range(N):
        done = False
        action_sequence = []
        _, info = eval_env.reset()
        s = info[agents_constants.COMMON.STATE]
        train_env.reset()
        initial_particles = train_env.initial_particles
        max_particles = 1000
        planning_time = 60
        value_function = lambda x: 0
        reinvigoration = False
        rollout_policy = False
        verbose = False
        default_node_value = 0
        prior_weight = 1
        acquisition_function_type = POMCPAcquisitionFunctionType.UCB
        use_rollout_policy = False
        reinvigorated_particles_ratio = False
        prune_action_space = False
        prune_size = 3
        prior_confidence = 0
        pomcp = POMCP(A=A, gamma=gamma, env=train_env, c=c, initial_particles=initial_particles,
                      planning_time=planning_time, max_particles=max_particles, rollout_policy=rollout_policy,
                      value_function=value_function, reinvigoration=reinvigoration, verbose=verbose,
                      default_node_value=default_node_value, prior_weight=prior_weight,
                      acquisition_function_type=acquisition_function_type, c2=1500,
                      acquisition_function_type=acquisition_function_type, c2=c2,
                      use_rollout_policy=use_rollout_policy, prior_confidence=prior_confidence,
                      reinvigorated_particles_ratio=reinvigorated_particles_ratio,
                      prune_action_space=prune_action_space, prune_size=prune_size)
        rollout_depth = 4
        planning_depth = 100
        R = 0
        t = 0
        action_sequence = []
        while t < max_horizon:
            pomcp.solve(max_rollout_depth=rollout_depth, max_planning_depth=planning_depth)
        t = 1

        # Run episode
        while not done and t <= max_env_steps:
            rollout_depth = max_rollout_depth
            planning_depth = max_planning_depth
            pomcp.solve(max_rollout_depth=rollout_depth, max_planning_depth=planning_depth, t=t)
            action = pomcp.get_action()
            o, r, done, _, info = eval_env.step(action)
            o, _, done, _, info = eval_env.step(action)
            r = info[agents_constants.COMMON.REWARD]
            action_sequence.append(action)
            s_prime = info[agents_constants.COMMON.STATE]
            obs_id = info[agents_constants.COMMON.OBSERVATION]
            pomcp.update_tree_with_new_samples(action_sequence=action_sequence, observation=obs_id)
            print(eval_env.get_true_table())
            print(eval_env.get_table())
            pomcp.update_tree_with_new_samples(action_sequence=action_sequence, observation=obs_id, t=t)
            R += r
            t += 1
            Logger.__call__().get_logger().info(f"[POMCP] t: {t}, a: {action}, r: {r}, o: {obs_id}, "
                                                f"s_prime: {s_prime},"
                                                f", action sequence: {action_sequence}, R: {R}")
            if t % log_steps_frequency == 0:
                Logger.__call__().get_logger().info(f"[POMCP] t: {t}, a: {action_id_to_type_and_host[action]}, r: {r}, "
                                                    f"action sequence: {action_sequence}, R: {round(R, 2)}")

        # Logging
        returns.append(R)
        print(f"{i}/{num_evaluations}, avg R: {np.mean(returns)}, R: {R}")
        results = {}
        results["seed"] = seed
        results["training_time"] = 0
        results["returns"] = returns
        results["planning_time"] = planning_time
        json_str = json.dumps(results, indent=4, sort_keys=True)
        with io.open(f"/Users/kim/pomcp_{0}_60s.json", 'w', encoding='utf-8') as f:
            f.write(json_str)
        progress = round((i + 1) / N, 2)
        time_elapsed_minutes = round((time.time() - start) / 60, 3)
        Logger.__call__().get_logger().info(
            f"[POMCP] episode: {i}, J:{R}, "
            f"J_avg: {np.mean(returns)}, "
            f"progress: {round(progress * 100, 2)}%, "
            f"runtime: {time_elapsed_minutes} min")