Skip to content

Commit

Permalink
Merge branch '3241-fix-masque-caching' into develop
Browse files Browse the repository at this point in the history
Issue #3241
PR #3255
  • Loading branch information
mssalvatore committed Apr 26, 2023
2 parents f5169e4 + b9efffd commit c3992ff
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 13 deletions.
1 change: 1 addition & 0 deletions monkey/common/utils/file_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
get_binary_io_sha256_hash,
get_text_file_contents,
InvalidPath,
make_fileobj_copy,
)
from .secure_directory import create_secure_directory
from .secure_file import open_new_securely_permissioned_file
23 changes: 23 additions & 0 deletions monkey/common/utils/file_utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import hashlib
import io
import logging
import os
import shutil
from pathlib import Path
from typing import BinaryIO, Iterable

Expand Down Expand Up @@ -43,3 +45,24 @@ def get_text_file_contents(file_path: Path) -> str:
with open(file_path, "rt") as f:
file_contents = f.read()
return file_contents


def make_fileobj_copy(src: BinaryIO) -> BinaryIO:
"""
Creates a file-like object that is a copy of the provided file-like object
The source file-like object is reset to position 0 and a copy is made. Both the source file and
the copy are reset to position 0 before returning.
:param src: A file-like object to copy
:return: A file-like object that is a copy of the provided file-like object
"""
dst = io.BytesIO()

src.seek(0)
shutil.copyfileobj(src, dst)

src.seek(0)
dst.seek(0)

return dst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import io
import re
import shutil
from functools import lru_cache
from typing import BinaryIO, Optional, Sequence

from common.utils.file_utils import make_fileobj_copy

from . import IFileRepository


Expand All @@ -29,13 +29,8 @@ def save_file(self, unsafe_file_name: str, file_contents: BinaryIO):

def open_file(self, unsafe_file_name: str) -> BinaryIO:
original_file = self._open_file(unsafe_file_name)
file_copy = io.BytesIO()

shutil.copyfileobj(original_file, file_copy)
original_file.seek(0)
file_copy.seek(0)

return file_copy
return make_fileobj_copy(original_file)

@lru_cache(maxsize=16)
def _open_file(self, unsafe_file_name: str) -> BinaryIO:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import BinaryIO, Mapping, Optional

from common import OperatingSystem
from common.utils.file_utils import make_fileobj_copy

from .i_agent_binary_repository import IAgentBinaryRepository

Expand All @@ -29,8 +30,13 @@ def __init__(
self._masques = masques
self._null_bytes = b"\x00" * null_bytes_length

@lru_cache()
def get_agent_binary(self, operating_system: OperatingSystem) -> BinaryIO:
original_file = self._get_agent_binary(operating_system)

return make_fileobj_copy(original_file)

@lru_cache()
def _get_agent_binary(self, operating_system: OperatingSystem) -> BinaryIO:
agent_binary = self._agent_binary_repository.get_agent_binary(operating_system)
return self._apply_masque(operating_system, agent_binary)

Expand Down
30 changes: 29 additions & 1 deletion monkey/tests/unit_tests/common/utils/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import io
import os
import stat

import pytest
from tests.monkey_island.utils import assert_linux_permissions, assert_windows_permissions

from common.utils.environment import is_windows_os
from common.utils.file_utils import create_secure_directory, open_new_securely_permissioned_file
from common.utils.file_utils import (
create_secure_directory,
make_fileobj_copy,
open_new_securely_permissioned_file,
)
from common.utils.file_utils.secure_directory import FailedDirectoryCreationError


Expand Down Expand Up @@ -131,3 +136,26 @@ def test_open_new_securely_permissioned_file__perm_windows(test_path):
pass

assert_windows_permissions(test_path)


def test_make_fileobj_copy():
TEST_STR = b"Hello World"
with io.BytesIO(TEST_STR) as src:
dst = make_fileobj_copy(src)

# Writing the assertion this way verifies that both src and dest file handles have had
# their positions reset to 0.
assert src.read() == TEST_STR
assert dst.read() == TEST_STR


def test_make_fileobj_copy_seek_src_to_0():
TEST_STR = b"Hello World"
with io.BytesIO(TEST_STR) as src:
src.seek(int(len(TEST_STR) / 2))
dst = make_fileobj_copy(src)

# Writing the assertion this way verifies that both src and dest file handles have had
# their positions reset to 0.
assert src.read() == TEST_STR
assert dst.read() == TEST_STR
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,29 @@ def test_get_agent_binary__cached(
mock_masquerade_agent_binary_repository: MasqueradeAgentBinaryRepositoryDecorator,
operating_system: OperatingSystem,
):
actual_linux_binary = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)
actual_binary = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)
in_memory_agent_binary_repository.agent_binaries[operating_system] = b"new_binary"
cached_linux_binary = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)
cached_binary = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)

assert actual_binary.read() == cached_binary.read()


def test_get_agent_binary__cached_multiple_calls(
in_memory_agent_binary_repository: InMemoryAgentBinaryRepository,
mock_masquerade_agent_binary_repository: MasqueradeAgentBinaryRepositoryDecorator,
):
operating_system = OperatingSystem.WINDOWS

assert actual_linux_binary == cached_linux_binary
cached_binary_1 = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)
in_memory_agent_binary_repository.agent_binaries[operating_system] = b"new_binary"
cached_binary_2 = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)
cached_binary_3 = mock_masquerade_agent_binary_repository.get_agent_binary(operating_system)

# Writing the assertion this way verifies that returned files have had their positions reset to
# the beginning (i.e. seek(0)).
assert cached_binary_1.read() == MASQUED_WINDOWS_AGENT_BINARY.getvalue()
assert cached_binary_2.read() == MASQUED_WINDOWS_AGENT_BINARY.getvalue()
assert cached_binary_3.read() == MASQUED_WINDOWS_AGENT_BINARY.getvalue()


@pytest.mark.parametrize(
Expand Down

0 comments on commit c3992ff

Please sign in to comment.