diff --git a/smartsim/_core/entrypoints/dragon.py b/smartsim/_core/entrypoints/dragon.py index 2d1fd1c52..3c3b44104 100644 --- a/smartsim/_core/entrypoints/dragon.py +++ b/smartsim/_core/entrypoints/dragon.py @@ -55,7 +55,7 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: if not signo: logger.info("Received signal with no signo") else: - logger.info(f"Received {signo}") + logger.info(f"Received signal {signo}") cleanup() @@ -63,8 +63,6 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None: Dragon server entrypoint script """ -DBPID: t.Optional[int] = None - def print_summary(network_interface: str, ip_address: str) -> None: zmq_config = {"interface": network_interface, "address": ip_address} @@ -86,10 +84,9 @@ def print_summary(network_interface: str, ip_address: str) -> None: def run( + zmq_context: "zmq.Context[t.Any]", dragon_head_address: str, dragon_pid: int, - zmq_context: zmq.Context[t.Any], - zmq_authenticator: zmq.auth.thread.ThreadAuthenticator, ) -> None: logger.debug(f"Opening socket {dragon_head_address}") @@ -98,9 +95,7 @@ def run( zmq_context.setsockopt(zmq.REQ_CORRELATE, 1) zmq_context.setsockopt(zmq.REQ_RELAXED, 1) - dragon_head_socket, zmq_authenticator = dragonSockets.get_secure_socket( - context, zmq.REP, True, zmq_authenticator - ) + dragon_head_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REP, True) dragon_head_socket.bind(dragon_head_address) dragon_backend = DragonBackend(pid=dragon_pid) @@ -133,7 +128,7 @@ def run( break -def main(args: argparse.Namespace, zmq_context: zmq.Context[t.Any]) -> int: +def main(args: argparse.Namespace) -> int: if_config = get_best_interface_and_address() interface = if_config.interface address = if_config.address @@ -142,39 +137,50 @@ def main(args: argparse.Namespace, zmq_context: zmq.Context[t.Any]) -> int: dragon_head_address = f"tcp://{address}" if args.launching_address: + zmq_context = zmq.Context() + if str(args.launching_address).split(":", maxsplit=1)[0] == dragon_head_address: address = "localhost" dragon_head_address = "tcp://localhost:5555" else: dragon_head_address += ":5555" - launcher_socket, authenticator = dragonSockets.get_secure_socket( - context, zmq.REQ, False - ) + zmq_authenticator = dragonSockets.get_authenticator(zmq_context) + + logger.debug("Getting launcher socket") + launcher_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REQ, False) + + logger.debug(f"Connecting launcher socket to: {args.launching_address}") launcher_socket.connect(args.launching_address) client = dragonSockets.as_client(launcher_socket) + logger.debug( + f"Sending bootstrap request to launcher_socket with {dragon_head_address}" + ) client.send(DragonBootstrapRequest(address=dragon_head_address)) response = client.recv() + + logger.debug(f"Received bootstrap response: {response}") if not isinstance(response, DragonBootstrapResponse): raise ValueError( "Could not receive connection confirmation from launcher. Aborting." ) print_summary(interface, dragon_head_address) + try: + logger.debug("Executing event loop") run( + zmq_context=zmq_context, dragon_head_address=dragon_head_address, dragon_pid=response.dragon_pid, - zmq_context=zmq_context, - zmq_authenticator=authenticator, ) except Exception as e: logger.error(f"Dragon server failed with {e}", exc_info=True) return os.EX_SOFTWARE finally: - if authenticator.is_alive(): - authenticator.stop() + if zmq_authenticator is not None and zmq_authenticator.is_alive(): + zmq_authenticator.stop() logger.info("Shutting down! Bye bye!") return 0 @@ -210,6 +216,4 @@ def cleanup() -> None: for sig in SIGNALS: signal.signal(sig, handle_signal) - context = zmq.Context() - - sys.exit(main(args_, context)) + sys.exit(main(args_)) diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index f04f9e39e..278e3c9be 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -23,7 +23,6 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - import collections import functools import typing as t @@ -33,12 +32,10 @@ # pylint: disable=import-error # isort: off from dragon.infrastructure.policy import Policy -from dragon.native.process import Process, TemplateProcess +from dragon.native.process import Process, ProcessTemplate from dragon.native.process_group import ( ProcessGroup, DragonProcessGroupError, - Error, - Running, ) from dragon.native.machine import System, Node @@ -61,8 +58,8 @@ from smartsim._core.utils.helpers import create_short_id_str from smartsim.status import TERMINAL_STATUSES, SmartSimStatus -DRG_ERROR_STATUS = str(Error()) -DRG_RUNNING_STATUS = str(Running()) +DRG_ERROR_STATUS = "Error" +DRG_RUNNING_STATUS = "Running" @dataclass @@ -206,7 +203,7 @@ def update(self) -> None: local_policy = Policy( placement=Policy.Placement.HOST_NAME, host_name=node_name ) - tmp_proc = TemplateProcess( + tmp_proc = ProcessTemplate( target=request.exe, args=request.exe_args, cwd=request.path, diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index d6f2d5901..f84c2f593 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -76,7 +76,6 @@ _SchemaT = t.TypeVar("_SchemaT", bound=t.Union[DragonRequest, DragonResponse]) DRG_LOCK = RLock() -DRG_CTX = zmq.Context() class DragonLauncher(WLMLauncher): @@ -92,12 +91,10 @@ class DragonLauncher(WLMLauncher): def __init__(self) -> None: super().__init__() - self._context = DRG_CTX + self._context = zmq.Context() self._timeout = CONFIG.dragon_server_timeout self._reconnect_timeout = CONFIG.dragon_server_reconnect_timeout self._startup_timeout = CONFIG.dragon_server_startup_timeout - self._context.setsockopt(zmq.SNDTIMEO, value=self._timeout) - self._context.setsockopt(zmq.RCVTIMEO, value=self._timeout) self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None # Returned by dragon head, useful if shutdown is to be requested @@ -105,13 +102,15 @@ def __init__(self) -> None: self._dragon_head_pid: t.Optional[int] = None self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None + self._set_timeout(self._timeout) + @property def is_connected(self) -> bool: return self._dragon_head_socket is not None def _handshake(self, address: str) -> None: - self._dragon_head_socket, self._authenticator = dragonSockets.get_secure_socket( - self._context, zmq.REQ, False, self._authenticator + self._dragon_head_socket = dragonSockets.get_secure_socket( + self._context, zmq.REQ, False ) self._dragon_head_socket.connect(address) try: @@ -149,6 +148,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None: path = _resolve_dragon_path(path) dragon_config_log = path / CONFIG.dragon_log_filename + self._authenticator = dragonSockets.get_authenticator(self._context) + if dragon_config_log.is_file(): dragon_confs = self._parse_launched_dragon_server_info_from_files( [dragon_config_log] @@ -180,14 +181,19 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None: "smartsim._core.entrypoints.dragon", ] + # Optionally configure the dragon host for debug output + if CONFIG.log_level in ["debug", "developer"]: + cmd.insert(1, "DEBUG") + cmd.insert(1, "-l") + address = get_best_interface_and_address().address socket_addr = "" launcher_socket: t.Optional[zmq.Socket[t.Any]] = None if address is not None: self._set_timeout(self._startup_timeout) - launcher_socket, self._authenticator = dragonSockets.get_secure_socket( - self._context, zmq.REP, True, self._authenticator + launcher_socket = dragonSockets.get_secure_socket( + self._context, zmq.REP, True ) # find first available port >= 5995 @@ -206,6 +212,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None: ) as dragon_err: current_env = os.environ.copy() current_env.update({"PYTHONUNBUFFERED": "1"}) + logger.debug(f"Starting Dragon environment: {' '.join(cmd)}") + # pylint: disable-next=consider-using-with self._dragon_head_process = subprocess.Popen( args=cmd, @@ -258,6 +266,7 @@ def log_dragon_outputs() -> None: server_socket = self._dragon_head_socket server_process_pid = self._dragon_head_process.pid + # avoid registering the same cleanup more than once if server_socket is not None and self._dragon_head_process is not None: atexit.register( _dragon_cleanup, @@ -271,11 +280,15 @@ def log_dragon_outputs() -> None: raise LauncherError("Could not receive address of Dragon head process") def cleanup(self) -> None: + logger.debug("Starting Dragon launcher cleanup") _dragon_cleanup( server_socket=self._dragon_head_socket, server_process_pid=self._dragon_head_pid, server_authenticator=self._authenticator, ) + self._dragon_head_socket = None + self._dragon_head_pid = 0 + self._authenticator = None # RunSettings types supported by this launcher @property @@ -492,27 +505,34 @@ def _dragon_cleanup( try: if server_socket is not None: + print("Sending shutdown request to dragon environment") DragonLauncher.send_req_with_socket(server_socket, DragonShutdownRequest()) except zmq.error.ZMQError as e: # Can't use the logger as I/O file may be closed print("Could not send shutdown request to dragon server") print(f"ZMQ error: {e}", flush=True) finally: + print("Sending shutdown request is complete") time.sleep(1) try: if server_process_pid and psutil.pid_exists(server_process_pid): + print("Sending SIGINT to dragon server") os.kill(server_process_pid, signal.SIGINT) - print("Sent SIGINT to dragon server") except ProcessLookupError: # Can't use the logger as I/O file may be closed print("Dragon server is not running.", flush=True) + finally: + print("Dragon server process shutdown is complete") try: - if server_authenticator is not None: + if server_authenticator is not None and server_authenticator.is_alive(): + print("Shutting down ZMQ authenticator") server_authenticator.stop() except Exception: print("Authenticator shutdown error") + finally: + print("Authenticator shutdown is complete") def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path: diff --git a/smartsim/_core/launcher/dragon/dragonSockets.py b/smartsim/_core/launcher/dragon/dragonSockets.py index 83023b917..ca693428d 100644 --- a/smartsim/_core/launcher/dragon/dragonSockets.py +++ b/smartsim/_core/launcher/dragon/dragonSockets.py @@ -33,12 +33,16 @@ from smartsim._core.schemas import dragonResponses as _dragonResponses from smartsim._core.schemas import utils as _utils from smartsim._core.utils.security import KeyManager +from smartsim.log import get_logger if t.TYPE_CHECKING: from zmq import Context from zmq.sugar.socket import Socket +logger = get_logger(__name__) + + def as_server( socket: "Socket[t.Any]", ) -> _utils.SocketSchemaTranslator[ @@ -62,11 +66,10 @@ def as_client( def get_secure_socket( - context: "Context[t.Any]", + context: "zmq.Context[t.Any]", socket_type: int, is_server: bool, - authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None, -) -> "t.Tuple[Socket[t.Any], zmq.auth.thread.ThreadAuthenticator]": +) -> "Socket[t.Any]": """Create secured socket that consumes & produces encrypted messages :param context: ZMQ context object @@ -76,29 +79,18 @@ def get_secure_socket( :param is_server: Pass `True` to secure the socket as server. Pass `False` to secure the socket as a client. :type is_server: bool - :param authenticator: (optional) An existing authenticator that will be used - to authenticate secure communications. - :type authenticator: Optional[zmq.auth.thread.ThreadAuthenticator] - :returns: the secured socket prepared for sending encrypted messages and - an active authenticator if one was not supplied as a parameter. - :rtype: Tuple[zmq.Socket, zmq.auth.thread.ThreadAuthenticator]""" + :returns: the secured socket prepared for sending encrypted messages + :rtype: zmq.Socket""" config = get_config() - socket = context.socket(socket_type) + socket: "Socket[t.Any]" = context.socket(socket_type) key_manager = KeyManager(config, as_server=is_server, as_client=not is_server) server_keys, client_keys = key_manager.get_keys() - - # start an auth thread to provide encryption services on the socket - if authenticator is None: - authenticator = zmq.auth.thread.ThreadAuthenticator(context) - - # allow all keys in the client key directory to connect - authenticator.configure_curve(domain="*", location=key_manager.client_keys_dir) - - if not authenticator.is_alive(): - authenticator.start() + logger.debug(f"Applying keys to socket: {server_keys}, {client_keys}") if is_server: + logger.debug("Configuring socket as server") + # configure the server keys on the socket socket.curve_secretkey = server_keys.private socket.curve_publickey = client_keys.public @@ -110,5 +102,32 @@ def get_secure_socket( # set the server public key for decrypting incoming messages socket.curve_serverkey = server_keys.public + return socket + + +def get_authenticator( + context: "zmq.Context[t.Any]", +) -> "zmq.auth.thread.ThreadAuthenticator": + """Create an authenticator to handle encryption of ZMQ communications + + :param context: ZMQ context object + :type context: zmq.Context + :returns: the activated `Authenticator` + :rtype: zmq.auth.thread.ThreadAuthenticator""" + config = get_config() + + key_manager = KeyManager(config, as_client=True) + server_keys, client_keys = key_manager.get_keys() + logger.debug(f"Applying keys to authenticator: {server_keys}, {client_keys}") + + authenticator = zmq.auth.thread.ThreadAuthenticator(context) + + # allow all keys in the client key directory to connect + logger.debug(f"Securing with client keys in {key_manager.client_keys_dir}") + authenticator.configure_curve(domain="*", location=key_manager.client_keys_dir) + + if not authenticator.is_alive(): + logger.debug("Starting authenticator") + authenticator.start() - return socket, authenticator + return authenticator diff --git a/smartsim/_core/utils/security.py b/smartsim/_core/utils/security.py index 91a5bd1ab..f6607a2e5 100644 --- a/smartsim/_core/utils/security.py +++ b/smartsim/_core/utils/security.py @@ -35,17 +35,33 @@ import zmq.auth from smartsim._core.config.config import Config +from smartsim.log import get_logger + +logger = get_logger(__name__) class _KeyPermissions(IntEnum): """Permissions used by KeyManager""" - OWNER_RW = stat.S_IRUSR | stat.S_IWUSR - """Permissions allowing owner to r/w""" - OWNER_FULL = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR - """permissions allowing owner to r/w/x""" - WORLD_R = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | stat.S_IROTH | stat.S_IRGRP - """permissions allowing world to read""" + PRIVATE_KEY = stat.S_IRUSR | stat.S_IWUSR + """Permissions only allowing an owner to read and write the file""" + PUBLIC_KEY = stat.S_IRUSR | stat.S_IWUSR | stat.S_IROTH | stat.S_IRGRP + """Permissions allowing an owner, others, and the group to read a file""" + + PRIVATE_DIR = ( + stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR | stat.S_IXOTH | stat.S_IXGRP + ) + """Permissions allowing only owners to read, write and traverse a directory""" + PUBLIC_DIR = ( + stat.S_IRUSR + | stat.S_IWUSR + | stat.S_IXUSR + | stat.S_IROTH + | stat.S_IXOTH + | stat.S_IRGRP + | stat.S_IXGRP + ) + """Permissions allowing non-owners to traverse a directory""" @dataclasses.dataclass(frozen=True) @@ -183,10 +199,14 @@ def create_directories(self) -> None: the public and private key pairs for servers & clients""" for locator in [self._server_locator, self._client_locator]: if not locator.public_dir.exists(): - locator.public_dir.mkdir(parents=True, mode=_KeyPermissions.WORLD_R) + permission = _KeyPermissions.PUBLIC_DIR + logger.debug(f"Creating key dir: {locator.public_dir}, {permission}") + locator.public_dir.mkdir(parents=True, mode=permission) if not locator.private_dir.exists(): - locator.private_dir.mkdir(parents=True, mode=_KeyPermissions.OWNER_FULL) + permission = _KeyPermissions.PRIVATE_DIR + logger.debug(f"Creating key dir: {locator.private_dir}, {permission}") + locator.private_dir.mkdir(parents=True, mode=permission) @classmethod def _load_keypair(cls, locator: _KeyLocator, in_context: bool) -> KeyPair: @@ -203,7 +223,15 @@ def _load_keypair(cls, locator: _KeyLocator, in_context: bool) -> KeyPair: """ # private keys contain public & private key parts key_path = locator.private if in_context else locator.public - pub_key, priv_key = zmq.auth.load_certificate(key_path) + + pub_key: bytes = b"" + priv_key: t.Optional[bytes] = b"" + + if key_path.exists(): + logger.debug(f"Existing key files located at {key_path}") + pub_key, priv_key = zmq.auth.load_certificate(key_path) + else: + logger.debug(f"No key files found at {key_path}") # avoid a `None` value in the private key when it isn't loaded return KeyPair(pub_key, priv_key or b"") @@ -221,7 +249,7 @@ def _load_keys(self) -> t.Tuple[KeyPair, KeyPair]: return server_keys, client_keys except (ValueError, OSError): # expected if no keys could be loaded from disk - ... + logger.warning("Loading key pairs failed.", exc_info=True) return KeyPair(), KeyPair() @@ -235,6 +263,7 @@ def _move_public_key(cls, locator: _KeyLocator) -> None: :type locator: KeyLocator""" new_path = locator.private.with_suffix(locator.public.suffix) if new_path != locator.public: + logger.debug(f"Moving key file from {locator.public} to {new_path}") new_path.rename(locator.public) def _create_keys(self) -> None: @@ -247,8 +276,8 @@ def _create_keys(self) -> None: self._move_public_key(locator) # and ensure correct r/w/x permissions on each file. - locator.private.chmod(_KeyPermissions.OWNER_RW) - locator.public.chmod(_KeyPermissions.WORLD_R) + locator.private.chmod(_KeyPermissions.PRIVATE_KEY) + locator.public.chmod(_KeyPermissions.PUBLIC_KEY) def get_keys(self, create: bool = True) -> t.Tuple[KeyPair, KeyPair]: """Use ZMQ auth to generate a public/private key pair for the server @@ -259,6 +288,7 @@ def get_keys(self, create: bool = True) -> t.Tuple[KeyPair, KeyPair]: :returns: 2-tuple of `KeyPair` (server_keypair, client_keypair) :rtype: Tuple[KeyPair, KeyPair] """ + logger.debug(f"Loading keys, creation {'is' if create else 'not'} allowed") server_keys, client_keys = self._load_keys() # check if we received "empty keys" @@ -267,6 +297,7 @@ def get_keys(self, create: bool = True) -> t.Tuple[KeyPair, KeyPair]: if not create: # if directed not to create new keys, return "empty keys" + logger.debug("Returning empty key pairs") return KeyPair(), KeyPair() self.create_directories() diff --git a/tests/test_dragon_launcher.py b/tests/test_dragon_launcher.py index cd757e302..22191351a 100644 --- a/tests/test_dragon_launcher.py +++ b/tests/test_dragon_launcher.py @@ -24,8 +24,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import logging import multiprocessing as mp import os +import sys import typing as t import pytest @@ -33,7 +35,10 @@ from smartsim._core.config.config import get_config from smartsim._core.launcher.dragon.dragonLauncher import DragonLauncher -from smartsim._core.launcher.dragon.dragonSockets import get_secure_socket +from smartsim._core.launcher.dragon.dragonSockets import ( + get_authenticator, + get_secure_socket, +) from smartsim._core.schemas.dragonRequests import DragonBootstrapRequest from smartsim._core.schemas.dragonResponses import DragonHandshakeResponse from smartsim._core.utils.network import IFConfig, find_free_port @@ -43,6 +48,9 @@ pytestmark = pytest.mark.group_a +is_mac = sys.platform == "darwin" + + class MockPopen: def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: ... @@ -108,19 +116,18 @@ def is_alive(self) -> bool: def mock_dragon_env(test_dir, *args, **kwargs): """Create a mock dragon environment that can talk to the launcher through ZMQ""" + logger = logging.getLogger(__name__) + logging.basicConfig(level=logging.DEBUG) + try: - context = zmq.Context() addr = "127.0.0.1" callback_port = kwargs["port"] head_port = find_free_port(start=callback_port + 1) + context = zmq.Context.instance() + authenticator = get_authenticator(context) - callback_socket, dragon_authenticator = get_secure_socket( - context, zmq.REQ, False - ) - - dragon_head_socket, dragon_authenticator = get_secure_socket( - context, zmq.REP, True, dragon_authenticator - ) + callback_socket = get_secure_socket(context, zmq.REQ, False) + dragon_head_socket = get_secure_socket(context, zmq.REP, True) full_addr = f"{addr}:{callback_port}" callback_socket.connect(f"tcp://{full_addr}") @@ -132,26 +139,42 @@ def mock_dragon_env(test_dir, *args, **kwargs): msg_sent = False while not msg_sent: + logger.info("Sending bootstrap request to callback socket") callback_socket.send_string("bootstrap|" + req.json()) # hold until bootstrap response is received + logger.info("Receiving bootstrap response from callback socket") _ = callback_socket.recv() msg_sent = True hand_shaken = False while not hand_shaken: # other side should set up a socket and push me a `HandshakeRequest` + logger.info("Receiving handshake request through dragon head socket") _ = dragon_head_socket.recv() # acknowledge handshake success w/DragonHandshakeResponse + logger.info("Sending handshake response through dragon head socket") handshake_ack = DragonHandshakeResponse(dragon_pid=os.getpid()) dragon_head_socket.send_string(f"handshake|{handshake_ack.json()}") hand_shaken = True + + shutting_down = False + while not shutting_down: + logger.info("Waiting for shutdown request through dragon head socket") + # any incoming request at this point in test is my shutdown... + try: + message = dragon_head_socket.recv() + logger.info(f"Received final message {message}") + finally: + shutting_down = True + try: + logger.info("Handshake complete. Shutting down mock dragon env.") + authenticator.stop() + finally: + logger.info("Dragon mock env exiting...") + except Exception as ex: - print(f"exception occurred while configuring mock handshaker: {ex}") - finally: - dragon_authenticator.stop() - callback_socket.close() - dragon_head_socket.close() + logger.info(f"exception occurred while configuring mock handshaker: {ex}") def test_dragon_connect_bind_address(monkeypatch: pytest.MonkeyPatch, test_dir: str): @@ -202,7 +225,6 @@ def test_secure_socket_authenticator_setup( ): """Ensure the authenticator created by the secure socket factory method is fully configured and started when returned to a client""" - context = zmq.Context() with monkeypatch.context() as ctx: # look at test dir for dragon config @@ -210,7 +232,7 @@ def test_secure_socket_authenticator_setup( # avoid starting a real authenticator thread ctx.setattr("zmq.auth.thread.ThreadAuthenticator", MockAuthenticator) - _, authenticator = get_secure_socket(context, socket_type, is_server=is_server) + authenticator = get_authenticator(zmq.Context.instance()) km = KeyManager(get_config(), as_server=is_server) @@ -220,7 +242,7 @@ def test_secure_socket_authenticator_setup( assert authenticator.num_configure_curves > 0 # ensure authenticator was started assert authenticator.num_starts > 0 - assert authenticator.context == context + assert authenticator.context == zmq.Context.instance() # ensure authenticator will accept any secured connection assert authenticator.cfg_kwargs.get("domain", "") == "*" # ensure authenticator is using the expected set of keys @@ -239,7 +261,6 @@ def test_secure_socket_setup( ): """Ensure the authenticator created by the secure socket factory method is fully configured and started when returned to a client""" - context = zmq.Context() with monkeypatch.context() as ctx: # look at test dir for dragon config @@ -247,7 +268,9 @@ def test_secure_socket_setup( # avoid starting a real authenticator thread ctx.setattr("zmq.auth.thread.ThreadAuthenticator", MockAuthenticator) - socket, _ = get_secure_socket(context, zmq.REP, as_server) + context = zmq.Context.instance() + + socket = get_secure_socket(context, zmq.REP, as_server) # verify the socket is correctly configured to use curve authentication assert bool(socket.CURVE_SERVER) == as_server @@ -263,18 +286,16 @@ def test_secure_socket(test_dir: str, monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as ctx: # make sure we don't touch "real keys" during a test ctx.setenv("SMARTSIM_KEY_PATH", test_dir) - - context = zmq.Context() - server, authenticator = get_secure_socket(context, zmq.REP, True) + context = zmq.Context.instance() + authenticator = get_authenticator(context) + server = get_secure_socket(context, zmq.REP, True) ip, port = "127.0.0.1", find_free_port(start=9999) try: server.bind(f"tcp://*:{port}") - client, authenticator = get_secure_socket( - context, zmq.REQ, False, authenticator - ) + client = get_secure_socket(context, zmq.REQ, False) client.connect(f"tcp://{ip}:{port}") @@ -293,43 +314,45 @@ def test_secure_socket(test_dir: str, monkeypatch: pytest.MonkeyPatch): server.close() -# def test_dragon_launcher_handshake(monkeypatch: pytest.MonkeyPatch, test_dir: str): -# """Test that a real handshake between a launcher & dragon environment -# completes successfully using secure sockets""" -# context = zmq.Context() -# addr = "127.0.0.1" -# bootstrap_port = find_free_port(start=5995) - -# with monkeypatch.context() as ctx: -# # make sure we don't touch "real keys" during a test -# ctx.setenv("SMARTSIM_KEY_PATH", test_dir) - -# # look at test dir for dragon config -# ctx.setenv("SMARTSIM_DRAGON_SERVER_PATH", test_dir) -# # avoid finding real interface since we may not be on a super -# ctx.setattr( -# "smartsim._core.launcher.dragon.dragonLauncher.get_best_interface_and_address", -# lambda: IFConfig("faux_interface", addr), -# ) - -# # start up a faux dragon env that knows how to do the handshake process -# # but uses secure sockets for all communication. -# mock_dragon = mp.Process( -# target=mock_dragon_env, -# daemon=True, -# kwargs={"port": bootstrap_port, "test_dir": test_dir}, -# ) - -# def fn(*args, **kwargs): -# mock_dragon.start() -# return mock_dragon - -# ctx.setattr("subprocess.Popen", fn) - -# launcher = DragonLauncher() - -# try: -# # connect executes the complete handshake and raises an exception if comms fails -# launcher.connect_to_dragon(test_dir) -# finally: -# launcher.cleanup() +@pytest.mark.skipif(is_mac, reason="unsupported on MacOSX") +def test_dragon_launcher_handshake(monkeypatch: pytest.MonkeyPatch, test_dir: str): + """Test that a real handshake between a launcher & dragon environment + completes successfully using secure sockets""" + context = zmq.Context() + addr = "127.0.0.1" + bootstrap_port = find_free_port(start=5995) + + with monkeypatch.context() as ctx: + # make sure we don't touch "real keys" during a test + ctx.setenv("SMARTSIM_KEY_PATH", test_dir) + + # look at test dir for dragon config + ctx.setenv("SMARTSIM_DRAGON_SERVER_PATH", test_dir) + # avoid finding real interface since we may not be on a super + ctx.setattr( + "smartsim._core.launcher.dragon.dragonLauncher.get_best_interface_and_address", + lambda: IFConfig("faux_interface", addr), + ) + + # start up a faux dragon env that knows how to do the handshake process + # but uses secure sockets for all communication. + mock_dragon = mp.Process( + target=mock_dragon_env, + daemon=True, + kwargs={"port": bootstrap_port, "test_dir": test_dir}, + ) + + def fn(*args, **kwargs): + mock_dragon.start() + return mock_dragon + + ctx.setattr("subprocess.Popen", fn) + + launcher = DragonLauncher() + + try: + # connect executes the complete handshake and raises an exception if comms fails + launcher.connect_to_dragon(test_dir) + finally: + launcher.cleanup() + ... diff --git a/tests/utils/test_security.py b/tests/utils/test_security.py index 79170ab7d..1a7a9586b 100644 --- a/tests/utils/test_security.py +++ b/tests/utils/test_security.py @@ -209,26 +209,26 @@ def test_key_manager_applied_permissions( s_pub_stat = km._server_locator.public_dir.stat() c_pub_stat = km._client_locator.public_dir.stat() - assert stat.S_IMODE(s_pub_stat.st_mode) == _KeyPermissions.WORLD_R - assert stat.S_IMODE(c_pub_stat.st_mode) == _KeyPermissions.WORLD_R + assert stat.S_IMODE(s_pub_stat.st_mode) == _KeyPermissions.PUBLIC_DIR + assert stat.S_IMODE(c_pub_stat.st_mode) == _KeyPermissions.PUBLIC_DIR # ensure private dirs are open only to owner s_priv_stat = km._server_locator.private_dir.stat() c_priv_stat = km._client_locator.private_dir.stat() - assert stat.S_IMODE(s_priv_stat.st_mode) == _KeyPermissions.OWNER_FULL - assert stat.S_IMODE(c_priv_stat.st_mode) == _KeyPermissions.OWNER_FULL + assert stat.S_IMODE(s_priv_stat.st_mode) == _KeyPermissions.PRIVATE_DIR + assert stat.S_IMODE(c_priv_stat.st_mode) == _KeyPermissions.PRIVATE_DIR # ensure public files are open for reading by others s_pub_stat = km._server_locator.public.stat() c_pub_stat = km._client_locator.public.stat() - assert stat.S_IMODE(s_pub_stat.st_mode) == _KeyPermissions.WORLD_R - assert stat.S_IMODE(c_pub_stat.st_mode) == _KeyPermissions.WORLD_R + assert stat.S_IMODE(s_pub_stat.st_mode) == _KeyPermissions.PUBLIC_KEY + assert stat.S_IMODE(c_pub_stat.st_mode) == _KeyPermissions.PUBLIC_KEY # ensure private files are read-only for owner s_priv_stat = km._server_locator.private.stat() c_priv_stat = km._client_locator.private.stat() - assert stat.S_IMODE(s_priv_stat.st_mode) == _KeyPermissions.OWNER_RW - assert stat.S_IMODE(c_priv_stat.st_mode) == _KeyPermissions.OWNER_RW + assert stat.S_IMODE(s_priv_stat.st_mode) == _KeyPermissions.PRIVATE_KEY + assert stat.S_IMODE(c_priv_stat.st_mode) == _KeyPermissions.PRIVATE_KEY