Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3119 heart #3160

Merged
merged 5 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions monkey/infection_monkey/heart.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from common.common_consts import HEARTBEAT_INTERVAL
from common.utils.code_utils import PeriodicCaller
from infection_monkey.island_api_client import IIslandAPIClient
from infection_monkey.utils.ids import get_agent_id

logger = logging.getLogger(__name__)

Expand All @@ -15,14 +14,13 @@ def __init__(self, island_api_client: IIslandAPIClient):
self._periodic_caller = PeriodicCaller(
self._send_heartbeats, HEARTBEAT_INTERVAL, "AgentHeart"
)
self._agent_id = get_agent_id()

def start(self):
logger.info("Starting the Agent's heart")
self._periodic_caller.start()

def _send_heartbeats(self):
self._island_api_client.send_heartbeat(self._agent_id, time.time())
self._island_api_client.send_heartbeat(time.time())

def stop(self):
logger.info("Stopping the Agent's heart")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from common.agent_events import AbstractAgentEvent
from common.agent_plugins import AgentPlugin, AgentPluginManifest, AgentPluginType
from common.credentials import Credentials
from common.types import AgentID

from . import IIslandAPIClient, IslandAPIError

Expand Down Expand Up @@ -46,8 +45,8 @@ def get_agent_plugin_manifest(
) -> AgentPluginManifest:
return self._island_api_client.get_agent_plugin_manifest(plugin_type, plugin_name)

def get_agent_signals(self, agent_id: AgentID) -> AgentSignals:
return self._island_api_client.get_agent_signals(agent_id)
def get_agent_signals(self) -> AgentSignals:
return self._island_api_client.get_agent_signals()

def get_agent_configuration_schema(self) -> Dict[str, Any]:
return self._island_api_client.get_agent_configuration_schema()
Expand All @@ -72,8 +71,8 @@ def register_agent(self, agent_registration_data: AgentRegistrationData):
def send_events(self, events: Sequence[AbstractAgentEvent]):
return self._island_api_client.send_events(events)

def send_heartbeat(self, agent_id: AgentID, timestamp: float):
return self._island_api_client.send_heartbeat(agent_id, timestamp)
def send_heartbeat(self, timestamp: float):
return self._island_api_client.send_heartbeat(timestamp)

def send_log(self, agent_id: AgentID, log_contents: str):
return self._island_api_client.send_log(agent_id, log_contents)
def send_log(self, log_contents: str):
return self._island_api_client.send_log(log_contents)
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ def __init__(
self,
agent_event_serializer_registry: AgentEventSerializerRegistry,
http_client: HTTPClient,
agent_id: AgentID,
):
self._agent_event_serializer_registry = agent_event_serializer_registry
self._http_client = http_client
self._agent_id = agent_id

@handle_response_parsing_errors
def login(self, otp: OTP):
Expand Down Expand Up @@ -101,9 +103,9 @@ def get_agent_plugin_manifest(
return AgentPluginManifest(**response.json())

@handle_response_parsing_errors
def get_agent_signals(self, agent_id: str) -> AgentSignals:
def get_agent_signals(self) -> AgentSignals:
response = self._http_client.get(
f"/agent-signals/{agent_id}", timeout=SHORT_REQUEST_TIMEOUT
f"/agent-signals/{self._agent_id}", timeout=SHORT_REQUEST_TIMEOUT
)

return AgentSignals(**response.json())
Expand Down Expand Up @@ -154,12 +156,12 @@ def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSeriali

return serialized_events

def send_heartbeat(self, agent_id: AgentID, timestamp: float):
def send_heartbeat(self, timestamp: float):
data = AgentHeartbeat(timestamp=timestamp).dict(simplify=True)
self._http_client.post(f"/agent/{agent_id}/heartbeat", data)
self._http_client.post(f"/agent/{self._agent_id}/heartbeat", data)

def send_log(self, agent_id: AgentID, log_contents: str):
def send_log(self, log_contents: str):
self._http_client.put(
f"/agent-logs/{agent_id}",
f"/agent-logs/{self._agent_id}",
log_contents,
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from common.agent_event_serializers import AgentEventSerializerRegistry
from common.types import SocketAddress
from common.types import AgentID, SocketAddress

from . import (
AbstractIslandAPIClientFactory,
Expand All @@ -11,12 +11,17 @@


class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
def __init__(self, agent_event_serializer_registry: AgentEventSerializerRegistry):
def __init__(
self, agent_event_serializer_registry: AgentEventSerializerRegistry, agent_id: AgentID
):
self._agent_event_serializer_registry = agent_event_serializer_registry
self._agent_id = agent_id

def create_island_api_client(self, server: SocketAddress) -> IIslandAPIClient:
return ConfigurationValidatorDecorator(
HTTPIslandAPIClient(
self._agent_event_serializer_registry, HTTPClient(f"https://{server}/api")
self._agent_event_serializer_registry,
HTTPClient(f"https://{server}/api"),
self._agent_id,
)
)
12 changes: 4 additions & 8 deletions monkey/infection_monkey/island_api_client/i_island_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from common.agent_events import AbstractAgentEvent
from common.agent_plugins import AgentPlugin, AgentPluginManifest, AgentPluginType
from common.credentials import Credentials
from common.types import AgentID


class IIslandAPIClient(ABC):
Expand Down Expand Up @@ -105,11 +104,10 @@ def get_agent_plugin_manifest(
"""

@abstractmethod
def get_agent_signals(self, agent_id: AgentID) -> AgentSignals:
def get_agent_signals(self) -> AgentSignals:
"""
Gets an agent's signals from the island
Gets the agent's signals from the island

:param agent_id: ID of the agent whose signals should be retrieved
:raises IslandAPIAuthenticationError: If the client is not authorized to access this
endpoint
:raises IslandAPIConnectionError: If the client could not connect to the island
Expand Down Expand Up @@ -197,11 +195,10 @@ def send_events(self, events: Sequence[AbstractAgentEvent]):
"""

@abstractmethod
def send_heartbeat(self, agent_id: AgentID, timestamp: float):
def send_heartbeat(self, timestamp: float):
"""
Send a "heartbeat" to the Island to indicate that the agent is still alive

:param agent_id: The ID of the agent who is sending a heartbeat
:param timestamp: The timestamp of the agent's heartbeat
:raises IslandAPIAuthenticationError: If the client is not authorized to access this
endpoint
Expand All @@ -216,11 +213,10 @@ def send_heartbeat(self, agent_id: AgentID, timestamp: float):
"""

@abstractmethod
def send_log(self, agent_id: AgentID, log_contents: str):
def send_log(self, log_contents: str):
"""
Send the contents of the agent's log to the island

:param agent_id: The ID of the agent whose logs are being sent
:param log_contents: The contents of the agent's log
:raises IslandAPIAuthenticationError: If the client is not authorized to access this
endpoint
Expand Down
6 changes: 2 additions & 4 deletions monkey/infection_monkey/master/control_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from common.agent_configuration import AgentConfiguration
from common.credentials import Credentials
from common.types import AgentID
from infection_monkey.i_control_channel import IControlChannel, IslandCommunicationError
from infection_monkey.island_api_client import IIslandAPIClient, IslandAPIError

Expand All @@ -27,8 +26,7 @@ def wrapper(*args, **kwargs):


class ControlChannel(IControlChannel):
def __init__(self, server: str, agent_id: AgentID, api_client: IIslandAPIClient):
self._agent_id = agent_id
def __init__(self, server: str, api_client: IIslandAPIClient):
self._control_channel_server = server
self._island_api_client = api_client

Expand All @@ -37,7 +35,7 @@ def should_agent_stop(self) -> bool:
if not self._control_channel_server:
logger.error("Agent should stop because it can't connect to the C&C server.")
return True
agent_signals = self._island_api_client.get_agent_signals(self._agent_id)
agent_signals = self._island_api_client.get_agent_signals()
return agent_signals.terminate is not None

@handle_island_api_errors
Expand Down
8 changes: 3 additions & 5 deletions monkey/infection_monkey/monkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,7 @@ def __init__(self, args, ipc_logger_queue: multiprocessing.Queue, log_path: Path
self._island_address, self._island_api_client = self._connect_to_island_api()
self._register_agent()

self._control_channel = ControlChannel(
str(self._island_address), self._agent_id, self._island_api_client
)
self._control_channel = ControlChannel(str(self._island_address), self._island_api_client)
self._legacy_propagation_credentials_repository = (
AggregatingPropagationCredentialsRepository(self._control_channel)
)
Expand Down Expand Up @@ -202,7 +200,7 @@ def _connect_to_island_api(self) -> Tuple[SocketAddress, IIslandAPIClient]:
)

http_island_api_client_factory = HTTPIslandAPIClientFactory(
self._agent_event_serializer_registry
self._agent_event_serializer_registry, self._agent_id
)

server, island_api_client = self._select_server(
Expand Down Expand Up @@ -553,7 +551,7 @@ def _send_log(self):
except FileNotFoundError:
logger.exception(f"Log file {self._log_path} is not found.")

self._island_api_client.send_log(self._agent_id, log_contents)
self._island_api_client.send_log(log_contents)

def _delete_plugin_dir(self):
if not self._plugin_dir.exists():
Expand Down
37 changes: 19 additions & 18 deletions monkey/tests/unit_tests/infection_monkey/base_island_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,49 @@
from common.agent_events import AbstractAgentEvent
from common.agent_plugins import AgentPlugin, AgentPluginManifest, AgentPluginType
from common.credentials import Credentials
from common.types import AgentID
from infection_monkey.island_api_client import IIslandAPIClient


class BaseIslandAPIClient(IIslandAPIClient):
def login(self, otp: str):
pass
return

def get_agent_binary(self, operating_system: OperatingSystem) -> bytes:
pass
return b""

def get_agent_plugin(self, plugin_type: AgentPluginType, plugin_name: str) -> AgentPlugin:
pass
def get_agent_plugin(
self, operating_system: OperatingSystem, plugin_type: AgentPluginType, plugin_name: str
) -> AgentPlugin:
return AgentPlugin()

def get_otp(self):
pass
return

def get_agent_plugin_manifest(
self, plugin_type: AgentPluginType, plugin_name: str
) -> AgentPluginManifest:
pass
return AgentPluginManifest()

def get_agent_signals(self, agent_id: str) -> AgentSignals:
pass
def get_agent_signals(self) -> AgentSignals:
return AgentSignals()

def get_agent_configuration_schema(self) -> Dict[str, Any]:
pass
return {}

def get_config(self) -> AgentConfiguration:
pass
return AgentConfiguration()

def get_credentials_for_propagation(self) -> Sequence[Credentials]:
pass
return []

def register_agent(self, agent_registration_data: AgentRegistrationData):
pass
return

def send_events(self, events: Sequence[AbstractAgentEvent]):
pass
return

def send_heartbeat(self, agent: AgentID, timestamp: float):
pass
def send_heartbeat(self, timestamp: float):
return

def send_log(self, agent_id: AgentID, log_contents: str):
pass
def send_log(self, log_contents: str):
return
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def agent_event_serializer_registry():


def build_api_client(http_client):
return HTTPIslandAPIClient(agent_event_serializer_registry(), http_client)
return HTTPIslandAPIClient(agent_event_serializer_registry(), http_client, AGENT_ID)


def _build_client_with_json_response(response):
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_island_api_client__unhandled_exceptions():
api_client = build_api_client(http_client_stub)

with pytest.raises(OSError):
api_client.get_agent_signals(agent_id=AGENT_ID)
api_client.get_agent_signals()


def test_island_api_client_get_otp():
Expand All @@ -229,7 +229,7 @@ def test_island_api_client__handled_exceptions():
api_client = build_api_client(http_client_stub)

with pytest.raises(IslandAPIResponseParsingError):
api_client.get_agent_signals(agent_id=AGENT_ID)
api_client.get_agent_signals()


def test_island_api_client_get_agent_plugin_manifest():
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_island_api_client_get_agent_signals(timestamp):
expected_agent_signals = AgentSignals(terminate=timestamp)
api_client = _build_client_with_json_response({"terminate": timestamp})

actual_agent_signals = api_client.get_agent_signals(agent_id=AGENT_ID)
actual_agent_signals = api_client.get_agent_signals()

assert actual_agent_signals == expected_agent_signals

Expand All @@ -269,7 +269,7 @@ def test_island_api_client_get_agent_signals__bad_json(timestamp):
api_client = _build_client_with_json_response({"terminate": timestamp, "discombobulate": 20})

with pytest.raises(IslandAPIResponseParsingError):
api_client.get_agent_signals(agent_id=AGENT_ID)
api_client.get_agent_signals()


def test_island_api_client_get_agent_configuration_schema():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def island_api_client() -> IIslandAPIClient:

@pytest.fixture
def control_channel(island_api_client) -> ControlChannel:
return ControlChannel(SERVER, AGENT_ID, island_api_client)
return ControlChannel(SERVER, island_api_client)


@pytest.mark.parametrize("signal_time,expected_should_stop", [(1663950115, True), (None, False)])
Expand Down