diff --git a/monkey/infection_monkey/heart.py b/monkey/infection_monkey/heart.py index 35a5e1f771c..fbbca811e37 100644 --- a/monkey/infection_monkey/heart.py +++ b/monkey/infection_monkey/heart.py @@ -4,7 +4,6 @@ from common.common_consts import HEARTBEAT_INTERVAL from common.utils.code_utils import PeriodicCaller from infection_monkey.island_api_client import IIslandAPIClient -from infection_monkey.utils.ids import get_agent_id logger = logging.getLogger(__name__) @@ -15,14 +14,13 @@ def __init__(self, island_api_client: IIslandAPIClient): self._periodic_caller = PeriodicCaller( self._send_heartbeats, HEARTBEAT_INTERVAL, "AgentHeart" ) - self._agent_id = get_agent_id() def start(self): logger.info("Starting the Agent's heart") self._periodic_caller.start() def _send_heartbeats(self): - self._island_api_client.send_heartbeat(self._agent_id, time.time()) + self._island_api_client.send_heartbeat(time.time()) def stop(self): logger.info("Stopping the Agent's heart") diff --git a/monkey/infection_monkey/island_api_client/configuration_validator_decorator.py b/monkey/infection_monkey/island_api_client/configuration_validator_decorator.py index d7d9691edfa..0a1124a8c53 100644 --- a/monkey/infection_monkey/island_api_client/configuration_validator_decorator.py +++ b/monkey/infection_monkey/island_api_client/configuration_validator_decorator.py @@ -7,7 +7,6 @@ from common.agent_events import AbstractAgentEvent from common.agent_plugins import AgentPlugin, AgentPluginManifest, AgentPluginType from common.credentials import Credentials -from common.types import AgentID from . import IIslandAPIClient, IslandAPIError @@ -46,8 +45,8 @@ def get_agent_plugin_manifest( ) -> AgentPluginManifest: return self._island_api_client.get_agent_plugin_manifest(plugin_type, plugin_name) - def get_agent_signals(self, agent_id: AgentID) -> AgentSignals: - return self._island_api_client.get_agent_signals(agent_id) + def get_agent_signals(self) -> AgentSignals: + return self._island_api_client.get_agent_signals() def get_agent_configuration_schema(self) -> Dict[str, Any]: return self._island_api_client.get_agent_configuration_schema() @@ -72,8 +71,8 @@ def register_agent(self, agent_registration_data: AgentRegistrationData): def send_events(self, events: Sequence[AbstractAgentEvent]): return self._island_api_client.send_events(events) - def send_heartbeat(self, agent_id: AgentID, timestamp: float): - return self._island_api_client.send_heartbeat(agent_id, timestamp) + def send_heartbeat(self, timestamp: float): + return self._island_api_client.send_heartbeat(timestamp) - def send_log(self, agent_id: AgentID, log_contents: str): - return self._island_api_client.send_log(agent_id, log_contents) + def send_log(self, log_contents: str): + return self._island_api_client.send_log(log_contents) diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index 1f7dbce27da..40fe3e8fec2 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -51,9 +51,11 @@ def __init__( self, agent_event_serializer_registry: AgentEventSerializerRegistry, http_client: HTTPClient, + agent_id: AgentID, ): self._agent_event_serializer_registry = agent_event_serializer_registry self._http_client = http_client + self._agent_id = agent_id @handle_response_parsing_errors def login(self, otp: OTP): @@ -101,9 +103,9 @@ def get_agent_plugin_manifest( return AgentPluginManifest(**response.json()) @handle_response_parsing_errors - def get_agent_signals(self, agent_id: str) -> AgentSignals: + def get_agent_signals(self) -> AgentSignals: response = self._http_client.get( - f"/agent-signals/{agent_id}", timeout=SHORT_REQUEST_TIMEOUT + f"/agent-signals/{self._agent_id}", timeout=SHORT_REQUEST_TIMEOUT ) return AgentSignals(**response.json()) @@ -154,12 +156,12 @@ def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSeriali return serialized_events - def send_heartbeat(self, agent_id: AgentID, timestamp: float): + def send_heartbeat(self, timestamp: float): data = AgentHeartbeat(timestamp=timestamp).dict(simplify=True) - self._http_client.post(f"/agent/{agent_id}/heartbeat", data) + self._http_client.post(f"/agent/{self._agent_id}/heartbeat", data) - def send_log(self, agent_id: AgentID, log_contents: str): + def send_log(self, log_contents: str): self._http_client.put( - f"/agent-logs/{agent_id}", + f"/agent-logs/{self._agent_id}", log_contents, ) diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py b/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py index 5b1f37efe65..2f76db3a94c 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py @@ -1,5 +1,5 @@ from common.agent_event_serializers import AgentEventSerializerRegistry -from common.types import SocketAddress +from common.types import AgentID, SocketAddress from . import ( AbstractIslandAPIClientFactory, @@ -11,12 +11,17 @@ class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): - def __init__(self, agent_event_serializer_registry: AgentEventSerializerRegistry): + def __init__( + self, agent_event_serializer_registry: AgentEventSerializerRegistry, agent_id: AgentID + ): self._agent_event_serializer_registry = agent_event_serializer_registry + self._agent_id = agent_id def create_island_api_client(self, server: SocketAddress) -> IIslandAPIClient: return ConfigurationValidatorDecorator( HTTPIslandAPIClient( - self._agent_event_serializer_registry, HTTPClient(f"https://{server}/api") + self._agent_event_serializer_registry, + HTTPClient(f"https://{server}/api"), + self._agent_id, ) ) diff --git a/monkey/infection_monkey/island_api_client/i_island_api_client.py b/monkey/infection_monkey/island_api_client/i_island_api_client.py index 2b35348dda8..b6a66d4ed5b 100644 --- a/monkey/infection_monkey/island_api_client/i_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/i_island_api_client.py @@ -6,7 +6,6 @@ from common.agent_events import AbstractAgentEvent from common.agent_plugins import AgentPlugin, AgentPluginManifest, AgentPluginType from common.credentials import Credentials -from common.types import AgentID class IIslandAPIClient(ABC): @@ -105,11 +104,10 @@ def get_agent_plugin_manifest( """ @abstractmethod - def get_agent_signals(self, agent_id: AgentID) -> AgentSignals: + def get_agent_signals(self) -> AgentSignals: """ - Gets an agent's signals from the island + Gets the agent's signals from the island - :param agent_id: ID of the agent whose signals should be retrieved :raises IslandAPIAuthenticationError: If the client is not authorized to access this endpoint :raises IslandAPIConnectionError: If the client could not connect to the island @@ -197,11 +195,10 @@ def send_events(self, events: Sequence[AbstractAgentEvent]): """ @abstractmethod - def send_heartbeat(self, agent_id: AgentID, timestamp: float): + def send_heartbeat(self, timestamp: float): """ Send a "heartbeat" to the Island to indicate that the agent is still alive - :param agent_id: The ID of the agent who is sending a heartbeat :param timestamp: The timestamp of the agent's heartbeat :raises IslandAPIAuthenticationError: If the client is not authorized to access this endpoint @@ -216,11 +213,10 @@ def send_heartbeat(self, agent_id: AgentID, timestamp: float): """ @abstractmethod - def send_log(self, agent_id: AgentID, log_contents: str): + def send_log(self, log_contents: str): """ Send the contents of the agent's log to the island - :param agent_id: The ID of the agent whose logs are being sent :param log_contents: The contents of the agent's log :raises IslandAPIAuthenticationError: If the client is not authorized to access this endpoint diff --git a/monkey/infection_monkey/master/control_channel.py b/monkey/infection_monkey/master/control_channel.py index 0c2f7383028..8bafb00306c 100644 --- a/monkey/infection_monkey/master/control_channel.py +++ b/monkey/infection_monkey/master/control_channel.py @@ -6,7 +6,6 @@ from common.agent_configuration import AgentConfiguration from common.credentials import Credentials -from common.types import AgentID from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError @@ -27,8 +26,7 @@ def wrapper(*args, **kwargs): class ControlChannel(IControlChannel): - def __init__(self, server: str, agent_id: AgentID, api_client: IIslandAPIClient): - self._agent_id = agent_id + def __init__(self, server: str, api_client: IIslandAPIClient): self._control_channel_server = server self._island_api_client = api_client @@ -37,7 +35,7 @@ def should_agent_stop(self) -> bool: if not self._control_channel_server: logger.error("Agent should stop because it can't connect to the C&C server.") return True - agent_signals = self._island_api_client.get_agent_signals(self._agent_id) + agent_signals = self._island_api_client.get_agent_signals() return agent_signals.terminate is not None @handle_island_api_errors diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index 5fa0a861fa5..e80fe7c2b36 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -142,9 +142,7 @@ def __init__(self, args, ipc_logger_queue: multiprocessing.Queue, log_path: Path self._island_address, self._island_api_client = self._connect_to_island_api() self._register_agent() - self._control_channel = ControlChannel( - str(self._island_address), self._agent_id, self._island_api_client - ) + self._control_channel = ControlChannel(str(self._island_address), self._island_api_client) self._legacy_propagation_credentials_repository = ( AggregatingPropagationCredentialsRepository(self._control_channel) ) @@ -202,7 +200,7 @@ def _connect_to_island_api(self) -> Tuple[SocketAddress, IIslandAPIClient]: ) http_island_api_client_factory = HTTPIslandAPIClientFactory( - self._agent_event_serializer_registry + self._agent_event_serializer_registry, self._agent_id ) server, island_api_client = self._select_server( @@ -553,7 +551,7 @@ def _send_log(self): except FileNotFoundError: logger.exception(f"Log file {self._log_path} is not found.") - self._island_api_client.send_log(self._agent_id, log_contents) + self._island_api_client.send_log(log_contents) def _delete_plugin_dir(self): if not self._plugin_dir.exists(): diff --git a/monkey/tests/unit_tests/infection_monkey/base_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/base_island_api_client.py index c79800b2d73..57c75fc4b63 100644 --- a/monkey/tests/unit_tests/infection_monkey/base_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/base_island_api_client.py @@ -5,48 +5,49 @@ from common.agent_events import AbstractAgentEvent from common.agent_plugins import AgentPlugin, AgentPluginManifest, AgentPluginType from common.credentials import Credentials -from common.types import AgentID from infection_monkey.island_api_client import IIslandAPIClient class BaseIslandAPIClient(IIslandAPIClient): def login(self, otp: str): - pass + return def get_agent_binary(self, operating_system: OperatingSystem) -> bytes: - pass + return b"" - def get_agent_plugin(self, plugin_type: AgentPluginType, plugin_name: str) -> AgentPlugin: - pass + def get_agent_plugin( + self, operating_system: OperatingSystem, plugin_type: AgentPluginType, plugin_name: str + ) -> AgentPlugin: + return AgentPlugin() def get_otp(self): - pass + return def get_agent_plugin_manifest( self, plugin_type: AgentPluginType, plugin_name: str ) -> AgentPluginManifest: - pass + return AgentPluginManifest() - def get_agent_signals(self, agent_id: str) -> AgentSignals: - pass + def get_agent_signals(self) -> AgentSignals: + return AgentSignals() def get_agent_configuration_schema(self) -> Dict[str, Any]: - pass + return {} def get_config(self) -> AgentConfiguration: - pass + return AgentConfiguration() def get_credentials_for_propagation(self) -> Sequence[Credentials]: - pass + return [] def register_agent(self, agent_registration_data: AgentRegistrationData): - pass + return def send_events(self, events: Sequence[AbstractAgentEvent]): - pass + return - def send_heartbeat(self, agent: AgentID, timestamp: float): - pass + def send_heartbeat(self, timestamp: float): + return - def send_log(self, agent_id: AgentID, log_contents: str): - pass + def send_log(self, log_contents: str): + return diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index 98b5e036423..6dec3890326 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -88,7 +88,7 @@ def agent_event_serializer_registry(): def build_api_client(http_client): - return HTTPIslandAPIClient(agent_event_serializer_registry(), http_client) + return HTTPIslandAPIClient(agent_event_serializer_registry(), http_client, AGENT_ID) def _build_client_with_json_response(response): @@ -205,7 +205,7 @@ def test_island_api_client__unhandled_exceptions(): api_client = build_api_client(http_client_stub) with pytest.raises(OSError): - api_client.get_agent_signals(agent_id=AGENT_ID) + api_client.get_agent_signals() def test_island_api_client_get_otp(): @@ -229,7 +229,7 @@ def test_island_api_client__handled_exceptions(): api_client = build_api_client(http_client_stub) with pytest.raises(IslandAPIResponseParsingError): - api_client.get_agent_signals(agent_id=AGENT_ID) + api_client.get_agent_signals() def test_island_api_client_get_agent_plugin_manifest(): @@ -259,7 +259,7 @@ def test_island_api_client_get_agent_signals(timestamp): expected_agent_signals = AgentSignals(terminate=timestamp) api_client = _build_client_with_json_response({"terminate": timestamp}) - actual_agent_signals = api_client.get_agent_signals(agent_id=AGENT_ID) + actual_agent_signals = api_client.get_agent_signals() assert actual_agent_signals == expected_agent_signals @@ -269,7 +269,7 @@ def test_island_api_client_get_agent_signals__bad_json(timestamp): api_client = _build_client_with_json_response({"terminate": timestamp, "discombobulate": 20}) with pytest.raises(IslandAPIResponseParsingError): - api_client.get_agent_signals(agent_id=AGENT_ID) + api_client.get_agent_signals() def test_island_api_client_get_agent_configuration_schema(): diff --git a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py index efc52f79f3d..b29dbbc3443 100644 --- a/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py +++ b/monkey/tests/unit_tests/infection_monkey/master/test_control_channel.py @@ -32,7 +32,7 @@ def island_api_client() -> IIslandAPIClient: @pytest.fixture def control_channel(island_api_client) -> ControlChannel: - return ControlChannel(SERVER, AGENT_ID, island_api_client) + return ControlChannel(SERVER, island_api_client) @pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])