Skip to content

Commit

Permalink
Decouple authenticator and socket creation (#542)
Browse files Browse the repository at this point in the history
1. ZMQ authenticators appear to have clashing inproc addresses when
using the `zmq.Context.instance()` factory method. Replaced as needed.
2. Updated the underlying `Dragon` library version, which included a
breaking change that caused the swap from `TemplateProcess` to
`ProcessTemplate`
3. Fixed incomplete permission set on curve key files

[ committed by @ankona]
[ reviewed by @MattToast @al-rigazzi ]
  • Loading branch information
ankona authored Apr 15, 2024
1 parent 547e20a commit 7f6ecbe
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 142 deletions.
42 changes: 23 additions & 19 deletions smartsim/_core/entrypoints/dragon.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,14 @@ def handle_signal(signo: int, _frame: t.Optional[FrameType]) -> None:
if not signo:
logger.info("Received signal with no signo")
else:
logger.info(f"Received {signo}")
logger.info(f"Received signal {signo}")
cleanup()


"""
Dragon server entrypoint script
"""

DBPID: t.Optional[int] = None


def print_summary(network_interface: str, ip_address: str) -> None:
zmq_config = {"interface": network_interface, "address": ip_address}
Expand All @@ -86,10 +84,9 @@ def print_summary(network_interface: str, ip_address: str) -> None:


def run(
zmq_context: "zmq.Context[t.Any]",
dragon_head_address: str,
dragon_pid: int,
zmq_context: zmq.Context[t.Any],
zmq_authenticator: zmq.auth.thread.ThreadAuthenticator,
) -> None:
logger.debug(f"Opening socket {dragon_head_address}")

Expand All @@ -98,9 +95,7 @@ def run(
zmq_context.setsockopt(zmq.REQ_CORRELATE, 1)
zmq_context.setsockopt(zmq.REQ_RELAXED, 1)

dragon_head_socket, zmq_authenticator = dragonSockets.get_secure_socket(
context, zmq.REP, True, zmq_authenticator
)
dragon_head_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REP, True)
dragon_head_socket.bind(dragon_head_address)
dragon_backend = DragonBackend(pid=dragon_pid)

Expand Down Expand Up @@ -133,7 +128,7 @@ def run(
break


def main(args: argparse.Namespace, zmq_context: zmq.Context[t.Any]) -> int:
def main(args: argparse.Namespace) -> int:
if_config = get_best_interface_and_address()
interface = if_config.interface
address = if_config.address
Expand All @@ -142,39 +137,50 @@ def main(args: argparse.Namespace, zmq_context: zmq.Context[t.Any]) -> int:
dragon_head_address = f"tcp://{address}"

if args.launching_address:
zmq_context = zmq.Context()

if str(args.launching_address).split(":", maxsplit=1)[0] == dragon_head_address:
address = "localhost"
dragon_head_address = "tcp://localhost:5555"
else:
dragon_head_address += ":5555"

launcher_socket, authenticator = dragonSockets.get_secure_socket(
context, zmq.REQ, False
)
zmq_authenticator = dragonSockets.get_authenticator(zmq_context)

logger.debug("Getting launcher socket")
launcher_socket = dragonSockets.get_secure_socket(zmq_context, zmq.REQ, False)

logger.debug(f"Connecting launcher socket to: {args.launching_address}")
launcher_socket.connect(args.launching_address)
client = dragonSockets.as_client(launcher_socket)

logger.debug(
f"Sending bootstrap request to launcher_socket with {dragon_head_address}"
)
client.send(DragonBootstrapRequest(address=dragon_head_address))
response = client.recv()

logger.debug(f"Received bootstrap response: {response}")
if not isinstance(response, DragonBootstrapResponse):
raise ValueError(
"Could not receive connection confirmation from launcher. Aborting."
)

print_summary(interface, dragon_head_address)

try:
logger.debug("Executing event loop")
run(
zmq_context=zmq_context,
dragon_head_address=dragon_head_address,
dragon_pid=response.dragon_pid,
zmq_context=zmq_context,
zmq_authenticator=authenticator,
)
except Exception as e:
logger.error(f"Dragon server failed with {e}", exc_info=True)
return os.EX_SOFTWARE
finally:
if authenticator.is_alive():
authenticator.stop()
if zmq_authenticator is not None and zmq_authenticator.is_alive():
zmq_authenticator.stop()

logger.info("Shutting down! Bye bye!")
return 0
Expand Down Expand Up @@ -210,6 +216,4 @@ def cleanup() -> None:
for sig in SIGNALS:
signal.signal(sig, handle_signal)

context = zmq.Context()

sys.exit(main(args_, context))
sys.exit(main(args_))
11 changes: 4 additions & 7 deletions smartsim/_core/launcher/dragon/dragonBackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections
import functools
import typing as t
Expand All @@ -33,12 +32,10 @@
# pylint: disable=import-error
# isort: off
from dragon.infrastructure.policy import Policy
from dragon.native.process import Process, TemplateProcess
from dragon.native.process import Process, ProcessTemplate
from dragon.native.process_group import (
ProcessGroup,
DragonProcessGroupError,
Error,
Running,
)
from dragon.native.machine import System, Node

Expand All @@ -61,8 +58,8 @@
from smartsim._core.utils.helpers import create_short_id_str
from smartsim.status import TERMINAL_STATUSES, SmartSimStatus

DRG_ERROR_STATUS = str(Error())
DRG_RUNNING_STATUS = str(Running())
DRG_ERROR_STATUS = "Error"
DRG_RUNNING_STATUS = "Running"


@dataclass
Expand Down Expand Up @@ -206,7 +203,7 @@ def update(self) -> None:
local_policy = Policy(
placement=Policy.Placement.HOST_NAME, host_name=node_name
)
tmp_proc = TemplateProcess(
tmp_proc = ProcessTemplate(
target=request.exe,
args=request.exe_args,
cwd=request.path,
Expand Down
40 changes: 30 additions & 10 deletions smartsim/_core/launcher/dragon/dragonLauncher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
_SchemaT = t.TypeVar("_SchemaT", bound=t.Union[DragonRequest, DragonResponse])

DRG_LOCK = RLock()
DRG_CTX = zmq.Context()


class DragonLauncher(WLMLauncher):
Expand All @@ -92,26 +91,26 @@ class DragonLauncher(WLMLauncher):

def __init__(self) -> None:
super().__init__()
self._context = DRG_CTX
self._context = zmq.Context()
self._timeout = CONFIG.dragon_server_timeout
self._reconnect_timeout = CONFIG.dragon_server_reconnect_timeout
self._startup_timeout = CONFIG.dragon_server_startup_timeout
self._context.setsockopt(zmq.SNDTIMEO, value=self._timeout)
self._context.setsockopt(zmq.RCVTIMEO, value=self._timeout)
self._dragon_head_socket: t.Optional[zmq.Socket[t.Any]] = None
self._dragon_head_process: t.Optional[subprocess.Popen[bytes]] = None
# Returned by dragon head, useful if shutdown is to be requested
# but process was started by another launcher
self._dragon_head_pid: t.Optional[int] = None
self._authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None

self._set_timeout(self._timeout)

@property
def is_connected(self) -> bool:
return self._dragon_head_socket is not None

def _handshake(self, address: str) -> None:
self._dragon_head_socket, self._authenticator = dragonSockets.get_secure_socket(
self._context, zmq.REQ, False, self._authenticator
self._dragon_head_socket = dragonSockets.get_secure_socket(
self._context, zmq.REQ, False
)
self._dragon_head_socket.connect(address)
try:
Expand Down Expand Up @@ -149,6 +148,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
path = _resolve_dragon_path(path)
dragon_config_log = path / CONFIG.dragon_log_filename

self._authenticator = dragonSockets.get_authenticator(self._context)

if dragon_config_log.is_file():
dragon_confs = self._parse_launched_dragon_server_info_from_files(
[dragon_config_log]
Expand Down Expand Up @@ -180,14 +181,19 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
"smartsim._core.entrypoints.dragon",
]

# Optionally configure the dragon host for debug output
if CONFIG.log_level in ["debug", "developer"]:
cmd.insert(1, "DEBUG")
cmd.insert(1, "-l")

address = get_best_interface_and_address().address
socket_addr = ""
launcher_socket: t.Optional[zmq.Socket[t.Any]] = None
if address is not None:
self._set_timeout(self._startup_timeout)

launcher_socket, self._authenticator = dragonSockets.get_secure_socket(
self._context, zmq.REP, True, self._authenticator
launcher_socket = dragonSockets.get_secure_socket(
self._context, zmq.REP, True
)

# find first available port >= 5995
Expand All @@ -206,6 +212,8 @@ def _connect_to_dragon(self, path: t.Union[str, "os.PathLike[str]"]) -> None:
) as dragon_err:
current_env = os.environ.copy()
current_env.update({"PYTHONUNBUFFERED": "1"})
logger.debug(f"Starting Dragon environment: {' '.join(cmd)}")

# pylint: disable-next=consider-using-with
self._dragon_head_process = subprocess.Popen(
args=cmd,
Expand Down Expand Up @@ -258,6 +266,7 @@ def log_dragon_outputs() -> None:
server_socket = self._dragon_head_socket
server_process_pid = self._dragon_head_process.pid

# avoid registering the same cleanup more than once
if server_socket is not None and self._dragon_head_process is not None:
atexit.register(
_dragon_cleanup,
Expand All @@ -271,11 +280,15 @@ def log_dragon_outputs() -> None:
raise LauncherError("Could not receive address of Dragon head process")

def cleanup(self) -> None:
logger.debug("Starting Dragon launcher cleanup")
_dragon_cleanup(
server_socket=self._dragon_head_socket,
server_process_pid=self._dragon_head_pid,
server_authenticator=self._authenticator,
)
self._dragon_head_socket = None
self._dragon_head_pid = 0
self._authenticator = None

# RunSettings types supported by this launcher
@property
Expand Down Expand Up @@ -492,27 +505,34 @@ def _dragon_cleanup(

try:
if server_socket is not None:
print("Sending shutdown request to dragon environment")
DragonLauncher.send_req_with_socket(server_socket, DragonShutdownRequest())
except zmq.error.ZMQError as e:
# Can't use the logger as I/O file may be closed
print("Could not send shutdown request to dragon server")
print(f"ZMQ error: {e}", flush=True)
finally:
print("Sending shutdown request is complete")
time.sleep(1)

try:
if server_process_pid and psutil.pid_exists(server_process_pid):
print("Sending SIGINT to dragon server")
os.kill(server_process_pid, signal.SIGINT)
print("Sent SIGINT to dragon server")
except ProcessLookupError:
# Can't use the logger as I/O file may be closed
print("Dragon server is not running.", flush=True)
finally:
print("Dragon server process shutdown is complete")

try:
if server_authenticator is not None:
if server_authenticator is not None and server_authenticator.is_alive():
print("Shutting down ZMQ authenticator")
server_authenticator.stop()
except Exception:
print("Authenticator shutdown error")
finally:
print("Authenticator shutdown is complete")


def _resolve_dragon_path(fallback: t.Union[str, "os.PathLike[str]"]) -> Path:
Expand Down
61 changes: 40 additions & 21 deletions smartsim/_core/launcher/dragon/dragonSockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@
from smartsim._core.schemas import dragonResponses as _dragonResponses
from smartsim._core.schemas import utils as _utils
from smartsim._core.utils.security import KeyManager
from smartsim.log import get_logger

if t.TYPE_CHECKING:
from zmq import Context
from zmq.sugar.socket import Socket


logger = get_logger(__name__)


def as_server(
socket: "Socket[t.Any]",
) -> _utils.SocketSchemaTranslator[
Expand All @@ -62,11 +66,10 @@ def as_client(


def get_secure_socket(
context: "Context[t.Any]",
context: "zmq.Context[t.Any]",
socket_type: int,
is_server: bool,
authenticator: t.Optional[zmq.auth.thread.ThreadAuthenticator] = None,
) -> "t.Tuple[Socket[t.Any], zmq.auth.thread.ThreadAuthenticator]":
) -> "Socket[t.Any]":
"""Create secured socket that consumes & produces encrypted messages
:param context: ZMQ context object
Expand All @@ -76,29 +79,18 @@ def get_secure_socket(
:param is_server: Pass `True` to secure the socket as server. Pass `False`
to secure the socket as a client.
:type is_server: bool
:param authenticator: (optional) An existing authenticator that will be used
to authenticate secure communications.
:type authenticator: Optional[zmq.auth.thread.ThreadAuthenticator]
:returns: the secured socket prepared for sending encrypted messages and
an active authenticator if one was not supplied as a parameter.
:rtype: Tuple[zmq.Socket, zmq.auth.thread.ThreadAuthenticator]"""
:returns: the secured socket prepared for sending encrypted messages
:rtype: zmq.Socket"""
config = get_config()
socket = context.socket(socket_type)
socket: "Socket[t.Any]" = context.socket(socket_type)

key_manager = KeyManager(config, as_server=is_server, as_client=not is_server)
server_keys, client_keys = key_manager.get_keys()

# start an auth thread to provide encryption services on the socket
if authenticator is None:
authenticator = zmq.auth.thread.ThreadAuthenticator(context)

# allow all keys in the client key directory to connect
authenticator.configure_curve(domain="*", location=key_manager.client_keys_dir)

if not authenticator.is_alive():
authenticator.start()
logger.debug(f"Applying keys to socket: {server_keys}, {client_keys}")

if is_server:
logger.debug("Configuring socket as server")

# configure the server keys on the socket
socket.curve_secretkey = server_keys.private
socket.curve_publickey = client_keys.public
Expand All @@ -110,5 +102,32 @@ def get_secure_socket(

# set the server public key for decrypting incoming messages
socket.curve_serverkey = server_keys.public
return socket


def get_authenticator(
context: "zmq.Context[t.Any]",
) -> "zmq.auth.thread.ThreadAuthenticator":
"""Create an authenticator to handle encryption of ZMQ communications
:param context: ZMQ context object
:type context: zmq.Context
:returns: the activated `Authenticator`
:rtype: zmq.auth.thread.ThreadAuthenticator"""
config = get_config()

key_manager = KeyManager(config, as_client=True)
server_keys, client_keys = key_manager.get_keys()
logger.debug(f"Applying keys to authenticator: {server_keys}, {client_keys}")

authenticator = zmq.auth.thread.ThreadAuthenticator(context)

# allow all keys in the client key directory to connect
logger.debug(f"Securing with client keys in {key_manager.client_keys_dir}")
authenticator.configure_curve(domain="*", location=key_manager.client_keys_dir)

if not authenticator.is_alive():
logger.debug("Starting authenticator")
authenticator.start()

return socket, authenticator
return authenticator
Loading

0 comments on commit 7f6ecbe

Please sign in to comment.