Skip to content

Commit

Permalink
add initial tests for dragon entrypoint methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ankona committed Apr 17, 2024
1 parent 472715c commit 175c2ca
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 19 deletions.
28 changes: 18 additions & 10 deletions smartsim/_core/entrypoints/dragon.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class DragonEntrypointArgs:
launching_address: str
interface: str

def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None:

def handle_signal(signo: int, _frame: t.Optional[FrameType] = None) -> None:
if not signo:
logger.info("Received signal with no signo")
else:
Expand Down Expand Up @@ -142,7 +142,7 @@ def run(
break


def execute_entrypoint(args: DragonEntrypointArgs) -> int:
def execute_entrypoint(args: DragonEntrypointArgs) -> int:
if_config = get_best_interface_and_address()
interface = if_config.interface
address = if_config.address
Expand Down Expand Up @@ -232,32 +232,40 @@ def parse_arguments(args: t.List[str]) -> DragonEntrypointArgs:
"+launching_address",
type=str,
help="Address of launching process if a ZMQ connection can be established",
required=False,
required=True,
)
parser.add_argument(
"+interface", type=str, help="Network Interface name", required=False
"+interface",
type=str,
help="Network Interface name",
required=False,
)
args_ = parser.parse_args(args)


if not args_.launching_address:
raise ValueError("Empty launching address supplied.")

return DragonEntrypointArgs(args_.launching_address, args_.interface)


def main(args_: t.List[str]):
"""Execute the dragon entrypoint as a module"""
os.environ["PYTHONUNBUFFERED"] = "1"
logger.info("Dragon server started")

args = parse_arguments(args_)
register_signal_handlers()

try:
return_code = execute_entrypoint(args)
return return_code
except:
logger.error("An unexpected error occurred in the Dragon entrypoint.", exc_info=True)
except Exception:
logger.error(
"An unexpected error occurred in the Dragon entrypoint.", exc_info=True
)
finally:
cleanup()

return -1


Expand Down
188 changes: 179 additions & 9 deletions tests/on_wlm/test_dragon_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import pathlib
import pytest
import typing as t

import pytest

import smartsim._core.entrypoints.dragon as drg
from smartsim._core.entrypoints.dragon import (
cleanup,
get_log_path,
handle_signal,
main,
parse_arguments,
print_summary,
register_signal_handlers,
remove_config_log,
)
Expand All @@ -42,8 +46,8 @@
def mock_argv() -> t.List[str]:
"""Fixture for returning valid arguments to the entrypoint"""
return ["+launching_address", "mock-addr", "+interface", "mock-interface"]


def test_file_removal(test_dir: str, monkeypatch: pytest.MonkeyPatch):
"""Verify that the log file is removed when expected"""
mock_file_name = "mocked_file_name.txt"
Expand Down Expand Up @@ -81,7 +85,9 @@ def test_file_removal_on_bad_path(test_dir: str, monkeypatch: pytest.MonkeyPatch
assert False


def test_dragon_failure(mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.MonkeyPatch):
def test_dragon_failure(
mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.MonkeyPatch
):
"""Verify that the expected cleanup actions are taken when the dragon
entrypoint exits"""
mock_file_name = "mocked_file_name.txt"
Expand All @@ -93,7 +99,7 @@ def test_dragon_failure(mock_argv: t.List[str], test_dir: str, monkeypatch: pyte
ctx.setattr(
"smartsim._core.entrypoints.dragon.get_log_path", lambda: str(expected_path)
)

def raiser(args_) -> int:
raise Exception("Something bad...")

Expand All @@ -102,11 +108,13 @@ def raiser(args_) -> int:

return_code = main(mock_argv)

# ensure our exception error code is returned
# ensure our exception error code is returned
assert return_code == -1


def test_dragon_main(mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.MonkeyPatch):
def test_dragon_main(
mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.MonkeyPatch
):
"""Verify that the expected startup & cleanup actions are taken when the dragon
entrypoint exits"""
mock_file_name = "mocked_file_name.txt"
Expand All @@ -119,13 +127,175 @@ def test_dragon_main(mock_argv: t.List[str], test_dir: str, monkeypatch: pytest.
"smartsim._core.entrypoints.dragon.get_log_path", lambda: str(expected_path)
)
# we don't need to execute the actual entrypoint...
ctx.setattr("smartsim._core.entrypoints.dragon.execute_entrypoint", lambda args_: 0)
ctx.setattr(
"smartsim._core.entrypoints.dragon.execute_entrypoint", lambda args_: 0
)

return_code = main(mock_argv)

# execute_entrypoint should return 0 from our mock
# execute_entrypoint should return 0 from our mock
assert return_code == 0
# the cleanup should remove our config file
assert not expected_path.exists(), "Dragon config file was not removed!"
# the environment should be set as expected
assert os.environ.get("PYTHONUNBUFFERED", None) == "1"


@pytest.mark.parametrize(
"signal_num",
[
pytest.param(0, id="non-truthy signal"),
pytest.param(-1, id="negative signal"),
pytest.param(1, id="positive signal"),
],
)
def test_signal_handler(signal_num: int, monkeypatch: pytest.MonkeyPatch):
"""Verify that the signal handler performs expected actions"""
counter: int = 0

def increment_counter(*args, **kwargs):
nonlocal counter
counter += 1

with monkeypatch.context() as ctx:
ctx.setattr("smartsim._core.entrypoints.dragon.cleanup", increment_counter)
ctx.setattr("smartsim._core.entrypoints.dragon.logger.info", increment_counter)

handle_signal(signal_num, None)

# show that we log informational message & do cleanup (take 2 actions)
assert counter == 2


def test_log_path(monkeypatch: pytest.MonkeyPatch):
"""Verify that the log path is loaded & returned as expected"""

with monkeypatch.context() as ctx:
expected_filename = "foo.log"
ctx.setattr(
"smartsim._core.config.config.Config.dragon_log_filename", expected_filename
)

log_path = get_log_path()

assert expected_filename in log_path


def test_summary(test_dir: str, monkeypatch: pytest.MonkeyPatch):
"""Verify that the summary is written to expected location w/expected information"""

with monkeypatch.context() as ctx:
expected_ip = "127.0.0.111"
expected_interface = "mock_int0"
summary_file = pathlib.Path(test_dir) / "foo.log"
expected_hostname = "mockhostname"

ctx.setattr(
"smartsim._core.config.config.Config.dragon_log_filename",
str(summary_file),
)
ctx.setattr(
"smartsim._core.entrypoints.dragon.socket.gethostname",
lambda: expected_hostname,
)

print_summary(expected_interface, expected_ip)

summary = summary_file.read_text()

assert expected_ip in summary
assert expected_interface in summary
assert expected_hostname in summary


def test_cleanup(monkeypatch: pytest.MonkeyPatch):
"""Verify that the cleanup function attempts to remove the log file"""
counter: int = 0

def increment_counter(*args, **kwargs):
nonlocal counter
counter += 1

with monkeypatch.context() as ctx:
ctx.setattr(
"smartsim._core.entrypoints.dragon.remove_config_log", increment_counter
)

# ensure shutdown isn't initially true
assert not drg.SHUTDOWN_INITIATED

cleanup()

# show that cleanup removes config
assert counter == 1
# show that cleanup alters the flag to enable shutdown
assert drg.SHUTDOWN_INITIATED


def test_signal_handler_registration(test_dir: str, monkeypatch: pytest.MonkeyPatch):
"""Verify that signal handlers are registered for all expected signals"""
sig_nums: t.List[int] = []

def track_args(*args, **kwargs):
nonlocal sig_nums
sig_nums.append(args[0])

with monkeypatch.context() as ctx:
ctx.setattr("smartsim._core.entrypoints.dragon.signal.signal", track_args)

# ensure valid start point
assert not sig_nums

register_signal_handlers()

# ensure all expected handlers are registered
assert set(sig_nums) == set(drg.SIGNALS)


def test_arg_parser__no_args():
"""Verify arg parser fails when no args are not supplied"""
args_list = []

with pytest.raises(SystemExit) as ex:
# ensure that parser complains about missing required arguments
parse_arguments(args_list)


def test_arg_parser__invalid_launch_addr():
"""Verify arg parser fails with empty launch_address"""
addr_flag = "+launching_address"
addr_value = ""

args_list = [addr_flag, addr_value]

with pytest.raises(ValueError) as ex:
args = parse_arguments(args_list)


def test_arg_parser__required_only():
"""Verify arg parser succeeds when optional args are omitted"""
addr_flag = "+launching_address"
addr_value = "mock-address"

args_list = [addr_flag, addr_value]

args = parse_arguments(args_list)

assert args.launching_address == addr_value
assert not args.interface


def test_arg_parser__with_optionals():
"""Verify arg parser succeeds when optional args are included"""
addr_flag = "+launching_address"
addr_value = "mock-address"

interface_flag = "+interface"
interface_value = "mock-int"

args_list = [interface_flag, interface_value, addr_flag, addr_value]

args = parse_arguments(args_list)

assert args.launching_address == addr_value
assert args.interface == interface_value

0 comments on commit 175c2ca

Please sign in to comment.