Skip to content

Commit

Permalink
Set ML models and scripts from driver scripts (#185)
Browse files Browse the repository at this point in the history
Adds utilities to set ML models and ML scripts directly from driver scripts, as opposed to only from application code.

[ committed by @al-rigazzi ]
[ reviewed by @Spartee ]
  • Loading branch information
al-rigazzi authored May 11, 2022
1 parent 3f9b583 commit 8798c28
Show file tree
Hide file tree
Showing 17 changed files with 1,738 additions and 30 deletions.
54 changes: 51 additions & 3 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
import threading
import time

from ..._core.utils.redis import db_is_active, set_ml_model, set_script
from ...database import Orchestrator
from ...entity import DBNode, EntityList, SmartSimEntity
from ...entity import DBNode, DBModel, DBObject, DBScript, EntityList, SmartSimEntity
from ...error import LauncherError, SmartSimError, SSInternalError, SSUnsupportedError
from ...log import get_logger
from ...status import STATUS_RUNNING, TERMINAL_STATUSES
Expand All @@ -40,6 +41,10 @@
from ..utils import check_cluster_status, create_cluster
from .jobmanager import JobManager

from smartredis import Client
from smartredis.error import RedisConnectionError, RedisReplyError


logger = get_logger(__name__)

# job manager lock
Expand Down Expand Up @@ -286,9 +291,12 @@ def _launch(self, manifest):
raise SmartSimError(msg)
self._launch_orchestrator(orchestrator)

for rc in manifest.ray_clusters:
for rc in manifest.ray_clusters: # cov-wlm
rc._update_workers()

if self.orchestrator_active:
self._set_dbobjects(manifest)

# create all steps prior to launch
steps = []
all_entity_lists = manifest.ensembles + manifest.ray_clusters
Expand All @@ -297,7 +305,7 @@ def _launch(self, manifest):
batch_step = self._create_batch_job_step(elist)
steps.append((batch_step, elist))
else:
# if ensemble is to be run as seperate job steps, aka not in a batch
# if ensemble is to be run as separate job steps, aka not in a batch
job_steps = [(self._create_job_step(e), e) for e in elist.entities]
steps.extend(job_steps)

Expand Down Expand Up @@ -586,3 +594,43 @@ def reload_saved_db(self, checkpoint_file):
finally:
JM_LOCK.release()


def _set_dbobjects(self, manifest):
    """Set all DBObjects (ML models and scripts) in the manifest on the
    launched Orchestrator through a SmartRedis client.

    :param manifest: manifest of entities to launch
    :raises SSInternalError: if the database is not active
    """
    if not manifest.has_db_objects:
        return

    db_addresses = self._jobs.get_db_host_addresses()

    # addresses are "host:port" strings; dedupe hosts and ports separately
    hosts = list({address.split(":")[0] for address in db_addresses})
    ports = list({address.split(":")[-1] for address in db_addresses})

    if not db_is_active(hosts=hosts, ports=ports, num_shards=len(db_addresses)):
        raise SSInternalError("Cannot set DB Objects, DB is not running")

    client = Client(address=db_addresses[0], cluster=len(db_addresses) > 1)

    for model in manifest.models:
        if model.colocated:
            continue
        for db_model in model._db_models:
            set_ml_model(db_model, client)
        for db_script in model._db_scripts:
            set_script(db_script, client)

    for ensemble in manifest.ensembles:
        for db_model in ensemble._db_models:
            set_ml_model(db_model, client)
        for db_script in ensemble._db_scripts:
            set_script(db_script, client)
        for entity in ensemble:
            if entity.colocated:
                continue
            # Set objects belonging only to the member entity, skipping
            # any already set at the ensemble level to avoid duplicates
            for db_model in entity._db_models:
                if db_model not in ensemble._db_models:
                    set_ml_model(db_model, client)
            for db_script in entity._db_scripts:
                if db_script not in ensemble._db_scripts:
                    set_script(db_script, client)
44 changes: 44 additions & 0 deletions smartsim/_core/control/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,47 @@ def __str__(self):

s += "\n"
return s

@property
def has_db_objects(self):
    """Check if any entity has DBObjects to set

    A DBObject is a DBModel or DBScript attached to a model, an
    ensemble, or an entity within an ensemble.

    :return: True if at least one entity carries a non-empty list of
             DBModels or DBScripts
    :rtype: bool
    """

    def _has_db_objects(entity):
        # A missing attribute or an empty list both mean "no objects".
        # (The previous implementation returned True for any entity that
        # merely *defined* an empty ``_db_models`` list, and its helper
        # closures could implicitly return None, making ``None | None``
        # raise TypeError.)
        return bool(getattr(entity, "_db_models", None)) or bool(
            getattr(entity, "_db_scripts", None)
        )

    if any(_has_db_objects(model) for model in self.models):
        return True

    # Check ensemble-level DBObjects first, then objects attached to
    # the individual entities within each ensemble
    for ensemble in self.ensembles:
        if _has_db_objects(ensemble):
            return True
        if any(_has_db_objects(entity) for entity in ensemble):
            return True

    return False
124 changes: 122 additions & 2 deletions smartsim/_core/entrypoints/colocated.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from pathlib import Path
from subprocess import PIPE, STDOUT

from smartredis import Client
from smartredis.error import RedisConnectionError, RedisReplyError
from smartsim._core.utils.network import current_ip
from smartsim.error import SSInternalError
from smartsim.log import get_logger
Expand All @@ -55,8 +57,107 @@
def handle_signal(signo, frame):
cleanup()

def launch_db_model(client: "Client", db_model: "List[str]"):
    """Parse options to set an ML model on the local DB

    :param client: SmartRedis client connected to local DB
    :type client: Client
    :param db_model: List of arguments defining the model
    :type db_model: List[str]
    :return: Name of model
    :rtype: str
    """
    parser = argparse.ArgumentParser("Set ML model on DB")
    parser.add_argument("--name", type=str)
    parser.add_argument("--file", type=str)
    parser.add_argument("--backend", type=str)
    parser.add_argument("--device", type=str)
    parser.add_argument("--devices_per_node", type=int)
    parser.add_argument("--batch_size", type=int, default=0)
    parser.add_argument("--min_batch_size", type=int, default=0)
    parser.add_argument("--tag", type=str, default="")
    parser.add_argument("--inputs", nargs="+", default=None)
    parser.add_argument("--outputs", nargs="+", default=None)

    # Unused if we use SmartRedis
    parser.add_argument("--min_batch_timeout", type=int, default=None)
    args = parser.parse_args(db_model)

    # Default to None so the SmartRedis call is valid even when the model
    # has no named inputs/outputs (previously these names were unbound and
    # raised a NameError at call time)
    inputs = list(args.inputs) if args.inputs else None
    outputs = list(args.outputs) if args.outputs else None

    if args.devices_per_node == 1:
        devices = [args.device]
    else:
        # One copy of the model per device, e.g. GPU:0, GPU:1, ...
        devices = [f"{args.device}:{num}" for num in range(args.devices_per_node)]

    for device in devices:
        client.set_model_from_file(
            args.name,
            args.file,
            args.backend,
            device,
            args.batch_size,
            args.min_batch_size,
            args.tag,
            inputs,
            outputs,
        )

    return args.name

def launch_db_script(client: "Client", db_script: "List[str]"):
    """Parse options to set a script on the local DB

    :param client: SmartRedis client connected to local DB
    :type client: Client
    :param db_script: List of arguments defining the script
    :type db_script: List[str]
    :return: Name of script
    :rtype: str
    """
    parser = argparse.ArgumentParser("Set script on DB")
    parser.add_argument("--name", type=str)
    parser.add_argument("--func", type=str)
    parser.add_argument("--file", type=str)
    parser.add_argument("--backend", type=str)
    parser.add_argument("--device", type=str)
    parser.add_argument("--devices_per_node", type=int)
    args = parser.parse_args(db_script)

    if args.devices_per_node == 1:
        devices = [args.device]
    else:
        # One copy of the script per device, e.g. GPU:0, GPU:1, ...
        devices = [f"{args.device}:{num}" for num in range(args.devices_per_node)]

    if args.func:
        # The launcher serializes newlines; restore them before upload
        func = args.func.replace("\\n", "\n")
        for device in devices:
            client.set_script(args.name, func, device)
    elif args.file:
        for device in devices:
            client.set_script_from_file(args.name, args.file, device)

    return args.name


def main(network_interface: str, db_cpus: int, command: List[str], db_models: List[List[str]], db_scripts: List[List[str]]):
global DBPID

try:
Expand Down Expand Up @@ -102,6 +203,23 @@ def main(network_interface: str, db_cpus: int, command: List[str]):
f"\tCommand: {' '.join(cmd)}\n\n"
)))

if db_models or db_scripts:
try:
client = Client(cluster=False)
for i, db_model in enumerate(db_models):
logger.debug("Uploading model")
model_name = launch_db_model(client, db_model)
logger.debug(f"Added model {model_name} ({i+1}/{len(db_models)})")
for i, db_script in enumerate(db_scripts):
logger.debug("Uploading script")
script_name = launch_db_script(client, db_script)
logger.debug(f"Added script {script_name} ({i+1}/{len(db_scripts)})")
# Make sure we don't keep this around
del client
except (RedisConnectionError, RedisReplyError):
raise SSInternalError("Failed to set model or script, could not connect to database")


for line in iter(p.stdout.readline, b""):
print(line.decode("utf-8").rstrip(), flush=True)

Expand Down Expand Up @@ -144,6 +262,8 @@ def cleanup():
parser.add_argument("+lockfile", type=str, help="Filename to create for single proc per host")
parser.add_argument("+db_cpus", type=int, default=2, help="Number of CPUs to use for DB")
parser.add_argument("+command", nargs="+", help="Command to run")
parser.add_argument("+db_model", nargs="+", action="append", default=[], help="Model to set on DB")
parser.add_argument("+db_script", nargs="+", action="append", default=[], help="Script to set on DB")
args = parser.parse_args()

tmp_lockfile = Path(tempfile.gettempdir()) / args.lockfile
Expand All @@ -160,7 +280,7 @@ def cleanup():
for sig in SIGNALS:
signal.signal(sig, handle_signal)

main(args.ifname, args.db_cpus, args.command)
main(args.ifname, args.db_cpus, args.command, args.db_model, args.db_script)

# gracefully exit the processes in the distributed application that
# we do not want to have start a colocated process. Only one process
Expand Down
Loading

0 comments on commit 8798c28

Please sign in to comment.