From 90e023261925c1d3577d05faacd109eba71e5186 Mon Sep 17 00:00:00 2001 From: Kim Hammar Date: Tue, 28 May 2024 18:25:28 +0200 Subject: [PATCH] fix linter --- .../test_intrusion_recovery_pomdp_config.py | 1 - .../test_intrusion_recovery_pomdp_util.py | 184 ++++-------------- .../envs/stopping_game_env.py | 5 +- .../tests/test_stopping_game_env.py | 140 +++++-------- .../test_stopping_game_mdp_attacker_env.py | 4 +- .../test_stopping_game_pomdp_defender_env.py | 62 +++--- 6 files changed, 126 insertions(+), 270 deletions(-) diff --git a/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_config.py b/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_config.py index dd18e3dad..bcfeb3f49 100644 --- a/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_config.py +++ b/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_config.py @@ -3,7 +3,6 @@ ) from csle_tolerance.util.intrusion_recovery_pomdp_util import IntrusionRecoveryPomdpUtil import pytest_mock -import numpy as np class TestIntrusionRecoveryPomdpConfigSuite: diff --git a/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_util.py b/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_util.py index 08ce57ba2..71a649214 100644 --- a/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_util.py +++ b/simulation-system/libs/csle-tolerance/tests/test_intrusion_recovery_pomdp_util.py @@ -20,9 +20,7 @@ def test__state_space(self) -> None: :return: None """ - assert ( - isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.state_space() - ) + assert (isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.state_space()) assert IntrusionRecoveryPomdpUtil.state_space() is not None assert IntrusionRecoveryPomdpUtil.state_space() == [0, 1, 2] @@ -40,9 +38,7 @@ def test_action_space(self) -> None: :return: None """ - assert ( - isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.action_space() - ) + assert (isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.action_space()) assert IntrusionRecoveryPomdpUtil.action_space() is not None assert IntrusionRecoveryPomdpUtil.action_space() == [0, 1] @@ -79,10 +75,7 @@ def test_cost_tensor(self) -> None: actions = [0] negate = False expected = [[0, 0.5]] - assert ( - IntrusionRecoveryPomdpUtil.cost_tensor(eta, states, actions, negate) - == expected - ) + assert IntrusionRecoveryPomdpUtil.cost_tensor(eta, states, actions, negate) == expected def test_observation_function(self) -> None: """ @@ -93,9 +86,7 @@ def test_observation_function(self) -> None: s = 1 o = 1 num_observations = 2 - assert round( - IntrusionRecoveryPomdpUtil.observation_function(s, o, num_observations), 1 - ) + assert round(IntrusionRecoveryPomdpUtil.observation_function(s, o, num_observations), 1) def test_observation_tensor(self) -> None: """ @@ -126,15 +117,7 @@ def test_transition_function(self) -> None: p_c_1 = 0.1 p_c_2 = 0.2 p_u = 0.5 - assert ( - round( - IntrusionRecoveryPomdpUtil.transition_function( - s, s_prime, a, p_a, p_c_1, p_c_2, p_u - ), - 1, - ) - == 0.2 - ) + assert (round(IntrusionRecoveryPomdpUtil.transition_function(s, s_prime, a, p_a, p_c_1, p_c_2, p_u), 1) == 0.2) def test_transition_function_game(self) -> None: """ @@ -148,15 +131,7 @@ def test_transition_function_game(self) -> None: a2 = 1 p_a = 0.2 p_c_1 = 0.1 - assert ( - round( - IntrusionRecoveryPomdpUtil.transition_function_game( - s, s_prime, a1, a2, p_a, p_c_1 - ), - 2, - ) - == 0.18 - ) + assert (round(IntrusionRecoveryPomdpUtil.transition_function_game(s, s_prime, a1, a2, p_a, p_c_1), 2) == 0.18) def test_transition_tensor(self) -> None: """ @@ -171,9 +146,7 @@ def test_transition_tensor(self) -> None: p_c_2 = 0.2 p_u = 0.5 expected = [[[0.7, 0.2, 0.1], [0.4, 0.4, 0.2], [0, 0, 1.0]]] - transition_tensor = IntrusionRecoveryPomdpUtil.transition_tensor( - states, actions, p_a, p_c_1, p_c_2, p_u - ) + transition_tensor = IntrusionRecoveryPomdpUtil.transition_tensor(states, actions, p_a, p_c_1, p_c_2, p_u) for i in range(len(transition_tensor)): for j in range(len(transition_tensor[i])): for k in range(len(transition_tensor[i][j])): @@ -181,9 +154,7 @@ def test_transition_tensor(self) -> None: assert transition_tensor == expected states = [0, 1] with pytest.raises(AssertionError): - transition_tensor = IntrusionRecoveryPomdpUtil.transition_tensor( - states, actions, p_a, p_c_1, p_c_2, p_u - ) + IntrusionRecoveryPomdpUtil.transition_tensor(states, actions, p_a, p_c_1, p_c_2, p_u) def test_transition_tensor_game(self) -> None: """ @@ -196,14 +167,12 @@ def test_transition_tensor_game(self) -> None: attacker_actions = [0, 1] p_a = 0.5 p_c_1 = 0.3 - result = IntrusionRecoveryPomdpUtil.transition_tensor_game( - states, defender_actions, attacker_actions, p_a, p_c_1 - ) + result = IntrusionRecoveryPomdpUtil.transition_tensor_game(states, defender_actions, attacker_actions, p_a, + p_c_1) assert len(result) == len(defender_actions) assert all(len(a1) == len(attacker_actions) for a1 in result) assert all(len(a2) == len(states) for a1 in result for a2 in a1) assert all(len(s) == len(states) for a1 in result for a2 in a1 for s in a2) - assert result[0][1][0][0] == (1 - p_a) * (1 - p_c_1) assert result[1][0][1][1] == 0 assert result[1][1][2][2] == 1.0 @@ -234,12 +203,8 @@ def test_sampe_next_observation(self) -> None: observation_tensor = [[0.8, 0.2], [0.4, 0.6]] s_prime = 1 observations = [0, 1] - assert isinstance( - IntrusionRecoveryPomdpUtil.sample_next_observation( - observation_tensor, s_prime, observations - ), - int, - ) + assert isinstance(IntrusionRecoveryPomdpUtil.sample_next_observation(observation_tensor, s_prime, observations), + int) def test_bayes_filter(self) -> None: """ @@ -256,22 +221,9 @@ def test_bayes_filter(self) -> None: observation_tensor = [[0.8, 0.2], [0.4, 0.6]] transition_tensor = [[[0.6, 0.4], [0.1, 0.9]]] b_prime_s_prime = 0.7 - assert ( - round( - IntrusionRecoveryPomdpUtil.bayes_filter( - s_prime, - o, - a, - b, - states, - observations, - observation_tensor, - transition_tensor, - ), - 1, - ) - == b_prime_s_prime - ) + assert (round(IntrusionRecoveryPomdpUtil.bayes_filter(s_prime, o, a, b, states, observations, + observation_tensor, transition_tensor), 1) + == b_prime_s_prime) def test_p_o_given_b_a1_a2(self) -> None: """ @@ -286,15 +238,8 @@ def test_p_o_given_b_a1_a2(self) -> None: observation_tensor = [[0.8, 0.2], [0.4, 0.6]] transition_tensor = [[[0.6, 0.4], [0.1, 0.9]]] expected = 0.5 - assert ( - round( - IntrusionRecoveryPomdpUtil.p_o_given_b_a1_a2( - o, b, a, states, transition_tensor, observation_tensor - ), - 1, - ) - == expected - ) + assert (round(IntrusionRecoveryPomdpUtil.p_o_given_b_a1_a2(o, b, a, states, transition_tensor, + observation_tensor), 1) == expected) def test_next_belief(self) -> None: """ @@ -309,23 +254,8 @@ def test_next_belief(self) -> None: observations = [0, 1] observation_tensor = [[0.8, 0.2], [0.4, 0.6]] transition_tensor = [[[0.3, 0.7], [0.6, 0.4]]] - assert ( - round( - sum( - IntrusionRecoveryPomdpUtil.next_belief( - o, - a, - b, - states, - observations, - observation_tensor, - transition_tensor, - ) - ), - 1, - ) - == 1 - ) + assert (round(sum(IntrusionRecoveryPomdpUtil.next_belief(o, a, b, states, observations, observation_tensor, + transition_tensor)), 1) == 1) def test_pomdp_solver_file(self) -> None: """ @@ -334,33 +264,14 @@ def test_pomdp_solver_file(self) -> None: :return: None """ - assert ( - IntrusionRecoveryPomdpUtil.pomdp_solver_file( - IntrusionRecoveryPomdpConfig( - eta=0.1, - p_a=0.2, - p_c_1=0.2, - p_c_2=0.3, - p_u=0.3, - BTR=1, - negate_costs=True, - seed=1, - discount_factor=0.5, - states=[0, 1], - actions=[0], - observations=[0, 1], - cost_tensor=[[0.1, 0.5], [0.5, 0.6]], - observation_tensor=[[0.8, 0.2], [0.4, 0.6]], - transition_tensor=[[[0.8, 0.2], [0.6, 0.4]]], - b1=[0.3, 0.7], - T=3, - simulation_env_name="env", - gym_env_name="gym", - max_horizon=np.inf, - ) - ) - is not None - ) + assert (IntrusionRecoveryPomdpUtil.pomdp_solver_file( + IntrusionRecoveryPomdpConfig(eta=0.1, p_a=0.2, p_c_1=0.2, p_c_2=0.3, p_u=0.3, BTR=1, negate_costs=True, + seed=1, discount_factor=0.5, states=[0, 1], actions=[0], observations=[0, 1], + cost_tensor=[[0.1, 0.5], [0.5, 0.6]], + observation_tensor=[[0.8, 0.2], [0.4, 0.6]], + transition_tensor=[[[0.8, 0.2], [0.6, 0.4]]], b1=[0.3, 0.7], T=3, + simulation_env_name="env", gym_env_name="gym", max_horizon=np.inf)) + is not None) def test_sample_next_state_game(self) -> None: """ @@ -444,9 +355,7 @@ def test_generate_transitions(self) -> None: gym_env_name="gym_env", max_horizon=1000, ) - assert ( - IntrusionRecoveryPomdpUtil.generate_transitions(dto)[0] == "0 0 0 0 0 0.06" - ) + assert IntrusionRecoveryPomdpUtil.generate_transitions(dto)[0] == "0 0 0 0 0 0.06" def test_generate_rewards(self) -> None: """ @@ -502,7 +411,11 @@ def test_generate_rewards(self) -> None: assert IntrusionRecoveryPomdpUtil.generate_rewards(dto)[0] == "0 0 0 -1" def test_generate_os_posg_game_file(self) -> None: - """ """ + """ + Tests the generate_os_posg_game function + + :return: None + """ states = [0, 1, 2] actions = [0, 1] @@ -580,24 +493,13 @@ def test_generate_os_posg_game_file(self) -> None: output_lines = game_file_str.split("\n") - assert ( - output_lines[0] == expected_game_description - ), f"Game description mismatch: {output_lines[0]}" - assert ( - output_lines[1:4] == expected_state_descriptions - ), f"State descriptions mismatch: {output_lines[1:4]}" - assert ( - output_lines[4:6] == expected_player_1_actions - ), f"Player 1 actions mismatch: {output_lines[4:6]}" - assert ( - output_lines[6:8] == expected_player_2_actions - ), f"Player 2 actions mismatch: {output_lines[6:8]}" - assert ( - output_lines[8:10] == expected_obs_descriptions - ), f"Observation descriptions mismatch: {output_lines[8:10]}" - assert ( - output_lines[10:13] == expected_player_2_legal_actions - ), f"Player 2 legal actions mismatch: {output_lines[10:13]}" - assert ( - output_lines[13:14] == expected_player_1_legal_actions - ), f"Player 1 legal actions mismatch: {output_lines[13:14]}" + assert (output_lines[0] == expected_game_description), f"Game description mismatch: {output_lines[0]}" + assert (output_lines[1:4] == expected_state_descriptions), f"State descriptions mismatch: {output_lines[1:4]}" + assert (output_lines[4:6] == expected_player_1_actions), f"Player 1 actions mismatch: {output_lines[4:6]}" + assert (output_lines[6:8] == expected_player_2_actions), f"Player 2 actions mismatch: {output_lines[6:8]}" + assert (output_lines[8:10] == expected_obs_descriptions), \ + f"Observation descriptions mismatch: {output_lines[8:10]}" + assert (output_lines[10:13] == expected_player_2_legal_actions), \ + f"Player 2 legal actions mismatch: {output_lines[10:13]}" + assert (output_lines[13:14] == expected_player_1_legal_actions), \ + f"Player 1 legal actions mismatch: {output_lines[13:14]}" diff --git a/simulation-system/libs/gym-csle-stopping-game/src/gym_csle_stopping_game/envs/stopping_game_env.py b/simulation-system/libs/gym-csle-stopping-game/src/gym_csle_stopping_game/envs/stopping_game_env.py index 53a757972..f7bc6007b 100644 --- a/simulation-system/libs/gym-csle-stopping-game/src/gym_csle_stopping_game/envs/stopping_game_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/src/gym_csle_stopping_game/envs/stopping_game_env.py @@ -72,7 +72,7 @@ def step(self, action_profile: Tuple[int, Tuple[npt.NDArray[Any], int]]) \ a1, a2_profile = action_profile pi2, a2 = a2_profile assert pi2.shape[0] == len(self.config.S) - assert pi2.shape[1] == len(self.config.A1) + assert pi2.shape[1] == len(self.config.A2) done = False info: Dict[str, Any] = {} @@ -83,8 +83,7 @@ def step(self, action_profile: Tuple[int, Tuple[npt.NDArray[Any], int]]) \ else: # Compute r, s', b',o' r = self.config.R[self.state.l - 1][a1][a2][self.state.s] - self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, - T=self.config.T, + self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T, S=self.config.S, s=self.state.s) o = StoppingGameUtil.sample_next_observation(Z=self.config.Z, O=self.config.O, s_prime=self.state.s) diff --git a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py index 12b457457..eef079c70 100644 --- a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py @@ -1,11 +1,13 @@ +from typing import Dict, Any +import pytest +from unittest.mock import patch, MagicMock +from gym.spaces import Box, Discrete +import numpy as np from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState +import gym_csle_stopping_game.constants.constants as env_constants from csle_common.constants import constants -from unittest.mock import patch, MagicMock -from gym.spaces import Box, Discrete -import pytest -import numpy as np class TestStoppingGameEnvSuite: @@ -93,39 +95,24 @@ def test_stopping_game_init_(self) -> None: env = StoppingGameEnv(self.config) assert env.config == self.config - assert ( - env.attacker_observation_space.low.any() - == attacker_observation_space.low.any() - ) - assert ( - env.defender_observation_space.low.any() - == defender_observation_space.low.any() - ) + assert env.attacker_observation_space.low.any() == attacker_observation_space.low.any() + assert env.defender_observation_space.low.any() == defender_observation_space.low.any() assert env.attacker_action_space.n == attacker_action_space.n assert env.defender_action_space.n == defender_action_space.n assert env.traces == [] - with patch( - "gym_csle_stopping_game.dao.stopping_game_state.StoppingGameState" - ) as MockStoppingGameState: + with patch("gym_csle_stopping_game.dao.stopping_game_state.StoppingGameState") as MockStoppingGameState: MockStoppingGameState(b1=self.config.b1, L=self.config.L) - with patch( - "gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_initial_state" - ) as MockSampleInitialState: + with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_initial_state" + ) as MockSampleInitialState: MockSampleInitialState.return_value = 0 - env = StoppingGameEnv(self.config) + StoppingGameEnv(self.config) MockSampleInitialState.assert_called() - MockStoppingGameState.assert_called_once_with( - b1=self.config.b1, L=self.config.L - ) + MockStoppingGameState.assert_called_once_with(b1=self.config.b1, L=self.config.L) - with patch( - "csle_common.dao.simulation_config.simulation_trace.SimulationTrace" - ) as MockSimulationTrace: - mock_trace = MockSimulationTrace(self.config.env_name).return_value - print(mock_trace) - env = StoppingGameEnv(self.config) - print(env.trace) + with patch("csle_common.dao.simulation_config.simulation_trace.SimulationTrace") as MockSimulationTrace: + MockSimulationTrace(self.config.env_name).return_value + StoppingGameEnv(self.config) MockSimulationTrace.assert_called_once_with(self.config.env_name) def test_mean(self) -> None: @@ -148,21 +135,15 @@ def test_weighted_intrusion_prediction_distance(self) -> None: Tests the function of computing the weighed intrusion start time prediction distance """ # Test case when first_stop is before intrusion_start - result1 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance( - 5, 3 - ) + result1 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(5, 3) assert result1 == 0 # Test case when first_stop is after intrusion_start - result2 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance( - 3, 5 - ) + result2 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 5) assert result2 == 0.95 # Test case when first_stop is equal to intrusion_start - result3 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance( - 3, 3 - ) + result3 = StoppingGameEnv(self.config).weighted_intrusion_prediction_distance(3, 3) assert result3 == 0 def test_reset(self) -> None: @@ -186,25 +167,13 @@ def test_reset(self) -> None: observation, info = env.reset() # Assertions assert env.state.reset.called, "State's reset method was not called." - assert ( - env.trace.simulation_env == self.config.env_name - ), "Trace was not initialized correctly." - assert ( - observation[0].all() == np.array([4, 5, 6]).all() - ), "Observation does not match expected values." - - assert ( - info[env_constants.ENV_METRICS.STOPS_REMAINING] == env.state.l - ), "Stops remaining does not match expected value." - assert ( - info[env_constants.ENV_METRICS.STATE] == env.state.s - ), "State info does not match expected value." - assert ( - info[env_constants.ENV_METRICS.OBSERVATION] == 0 - ), "Observation info does not match expected value." - assert ( - info[env_constants.ENV_METRICS.TIME_STEP] == env.state.t - ), "Time step info does not match expected value." + assert env.trace.simulation_env == self.config.env_name, "Trace was not initialized correctly." + assert observation[0].all() == np.array([4, 5, 6]).all(), "Observation does not match expected values." + assert info[env_constants.ENV_METRICS.STOPS_REMAINING] == env.state.l, \ + "Stops remaining does not match expected value." + assert info[env_constants.ENV_METRICS.STATE] == env.state.s, "State info does not match expected value." + assert info[env_constants.ENV_METRICS.OBSERVATION] == 0, "Observation info does not match expected value." + assert info[env_constants.ENV_METRICS.TIME_STEP] == env.state.t, "Time step info does not match expected value." # Check if trace was appended correctly if len(env.trace.attacker_rewards) > 0: @@ -241,10 +210,7 @@ def test_get_traces(self) -> None: :return: None """ - assert ( - StoppingGameEnv(self.config).get_traces() - == StoppingGameEnv(self.config).traces - ) + assert StoppingGameEnv(self.config).get_traces() == StoppingGameEnv(self.config).traces def test_reset_traces(self) -> None: """ @@ -267,7 +233,7 @@ def test_checkpoint_traces(self) -> None: fixed_timestamp = 123 with patch("time.time", return_value=fixed_timestamp): with patch( - "csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces" + "csle_common.dao.simulation_config.simulation_trace.SimulationTrace.save_traces" ) as mock_save_traces: env.traces = ["trace1", "trace2"] env._StoppingGameEnv__checkpoint_traces() @@ -312,7 +278,7 @@ def test_set_state(self) -> None: assert env.state.l == state_tuple[1] with pytest.raises(ValueError): - env.set_state([1, 2, 3]) + env.set_state([1, 2, 3]) # type: ignore def test_is_state_terminal(self) -> None: """ @@ -338,7 +304,7 @@ def test_is_state_terminal(self) -> None: assert not env.is_state_terminal(state_tuple) with pytest.raises(ValueError): - env.is_state_terminal([1, 2, 3]) + env.is_state_terminal([1, 2, 3]) # type: ignore def test_get_observation_from_history(self) -> None: """ @@ -400,18 +366,12 @@ def test_step(self) -> None: env.trace.attacker_observations = [] env.trace.defender_observations = [] - with patch( - "gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state", - return_value=2, - ): - with patch( - "gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation", - return_value=1, - ): - with patch( - "gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.next_belief", - return_value=np.array([0.3, 0.7, 0.0]), - ): + with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state", + return_value=2): + with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation", + return_value=1): + with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.next_belief", + return_value=np.array([0.3, 0.7, 0.0])): action_profile = ( 1, ( @@ -425,12 +385,8 @@ def test_step(self) -> None: action_profile ) - assert ( - observations[0] == np.array([4, 5, 6]) - ).all(), "Incorrect defender observations" - assert ( - observations[1] == np.array([1, 2, 3]) - ).all(), "Incorrect attacker observations" + assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations" + assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations" assert rewards == (0, 0) assert not terminated assert not truncated @@ -443,13 +399,9 @@ def test_step(self) -> None: print(env.trace.beliefs) assert env.trace.beliefs[-1] == 0.7 assert env.trace.infrastructure_metrics[-1] == 1 - assert ( - env.trace.attacker_observations[-1] == np.array([1, 2, 3]) - ).all() - assert ( - env.trace.defender_observations[-1] == np.array([4, 5, 6]) - ).all() - + assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all() + assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all() + def test_info(self) -> None: """ Tests the function of adding the cumulative reward and episode length to the info dict @@ -463,18 +415,14 @@ def test_info(self) -> None: env.trace.defender_actions = [0, 1] env.trace.states = [0, 1] env.trace.infrastructure_metrics = [0, 1] - - info = {} + info: Dict[str, Any] = {} updated_info = env._info(info) - print(updated_info) assert updated_info[env_constants.ENV_METRICS.RETURN] == sum(env.trace.defender_rewards) - + def test_emulation_evaluation(self) -> None: """ Tests the function for evaluating a strategy profile in the emulation environment :return: None """ - env = StoppingGameEnv(self.config) - env.state.b1 = [0.5, 0.5] - pass \ No newline at end of file + StoppingGameEnv(self.config) diff --git a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py index 1eb438281..df461c511 100644 --- a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py @@ -108,7 +108,7 @@ def test_reset(self) -> None: env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config) attacker_obs, info = env.reset() - assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() + assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore assert info == {} def test_set_model(self) -> None: @@ -144,7 +144,7 @@ def test_set_state(self) -> None: ) env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config) - assert not env.set_state(1) + assert not env.set_state(1) # type: ignore def test_calculate_stage_policy(self) -> None: """ diff --git a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py index fbc04c500..c3c8da4c1 100644 --- a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py @@ -1,12 +1,11 @@ -from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import ( - StoppingGamePomdpDefenderEnv, -) +from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig -from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import ( - StoppingGameDefenderPomdpConfig, -) +from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv +from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil from csle_common.dao.training.policy import Policy +from csle_common.dao.training.random_policy import RandomPolicy +from csle_common.dao.training.player_type import PlayerType import pytest from unittest.mock import MagicMock import numpy as np @@ -25,19 +24,19 @@ def setup_env(self) -> None: :return: None """ env_name = "test_env" - T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]]) - O = np.array([0, 1]) - Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]]) + T = StoppingGameUtil.transition_tensor(L=3, p=0) + O = StoppingGameUtil.observation_space(n=100) + Z = StoppingGameUtil.observation_tensor(n=100) R = np.zeros((2, 3, 3, 3)) - S = np.array([0, 1, 2]) - A1 = np.array([0, 1, 2]) - A2 = np.array([0, 1, 2]) + S = StoppingGameUtil.state_space() + A1 = StoppingGameUtil.defender_actions() + A2 = StoppingGameUtil.attacker_actions() L = 2 R_INT = 1 R_COST = 2 R_SLA = 3 R_ST = 4 - b1 = np.array([0.6, 0.4]) + b1 = StoppingGameUtil.b1() save_dir = "save_directory" checkpoint_traces_freq = 100 gamma = 0.9 @@ -220,11 +219,13 @@ def test_set_state(self) -> None: stopping_game_name="csle-stopping-game-v1", ) env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config) - assert not env.set_state(1) + assert env.set_state(1) is None # type: ignore def test_get_observation_from_history(self) -> None: """ - Tests the function for getting a defender observatin (belief) from a history + Tests the function for getting a defender observation (belief) from a history + + :return: None """ attacker_strategy = MagicMock(spec=Policy) defender_pomdp_config = StoppingGameDefenderPomdpConfig( @@ -299,11 +300,8 @@ def test_get_actions_from_particles(self) -> None: particles = [1, 2, 3] t = 0 observation = 0 - expected_actions = [0, 1, 2] - assert ( - env.get_actions_from_particles(particles, t, observation) - == expected_actions - ) + expected_actions = [0, 1] + assert env.get_actions_from_particles(particles, t, observation) == expected_actions def test_step(self) -> None: """ @@ -311,7 +309,14 @@ def test_step(self) -> None: :return: None """ - attacker_strategy = MagicMock(spec=Policy) + attacker_stage_strategy = np.zeros((3, 2)) + attacker_stage_strategy[0][0] = 0.9 + attacker_stage_strategy[0][1] = 0.1 + attacker_stage_strategy[1][0] = 0.9 + attacker_stage_strategy[1][1] = 0.1 + attacker_stage_strategy[2] = attacker_stage_strategy[1] + attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER, + stage_policy_tensor=list(attacker_stage_strategy)) defender_pomdp_config = StoppingGameDefenderPomdpConfig( env_name="test_env", stopping_game_config=self.config, @@ -319,10 +324,13 @@ def test_step(self) -> None: stopping_game_name="csle-stopping-game-v1", ) env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config) - a1 = 2 + a1 = 1 + env.reset() defender_obs, reward, terminated, truncated, info = env.step(a1) - assert isinstance(defender_obs, int) - assert isinstance(reward, int) - assert isinstance(terminated, bool) - assert isinstance(truncated, bool) - assert isinstance(info, dict) + assert len(defender_obs) == 2 + assert isinstance(defender_obs[0], float) # type: ignore + assert isinstance(defender_obs[1], float) # type: ignore + assert isinstance(reward, float) # type: ignore + assert isinstance(terminated, bool) # type: ignore + assert isinstance(truncated, bool) # type: ignore + assert isinstance(info, dict) # type: ignore