Skip to content

Commit

Permalink
Cleaner and tested azure plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
satyaog committed May 22, 2024
1 parent 86598c3 commit 5b5700c
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 214 deletions.
23 changes: 14 additions & 9 deletions milabench/cli/cloud.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from copy import deepcopy
import os
import socket
import subprocess
import sys
import warnings

from coleo import Option, tooled
from omegaconf import OmegaConf
Expand All @@ -22,7 +22,7 @@
def _flatten_cli_args(**kwargs):
return sum(
(
(f"--{str(k).replace('_', '-')}", *([str(v)] if str(v) else []))
(f"--{str(k).replace('_', '-')}", *([str(v)] if v is not None else []))
for k, v in kwargs.items()
), ()
)
Expand All @@ -39,7 +39,7 @@ def manage_cloud(pack, run_on, action="setup"):
"hostname":(lambda v: ("ip",v)),
"username":(lambda v: ("user",v)),
"ssh_key_file":(lambda v: ("key",v)),
"env":(lambda v: ("env",[".", v, ";", "conda", "activate", "milabench", "&&"])),
# "env":(lambda v: ("env",[".", v, ";", "conda", "activate", "milabench", "&&"])),
}
plan_params = deepcopy(pack.config["system"]["cloud_profiles"][run_on])
run_on, *profile = run_on.split("__")
Expand All @@ -58,8 +58,9 @@ def manage_cloud(pack, run_on, action="setup"):
plan_params["state_prefix"] = plan_params.get("state_prefix", default_state_prefix)
plan_params["state_id"] = plan_params.get("state_id", default_state_id)
plan_params["cluster_size"] = max(len(pack.config["system"]["nodes"]), i + 1)
plan_params["keep_alive"] = None

import milabench.cli.covalent as cv
import milabench.scripts.covalent as cv

subprocess.run(
[
Expand Down Expand Up @@ -106,12 +107,16 @@ def manage_cloud(pack, run_on, action="setup"):
continue
try:
k, v = line_str.split("::>")
k, v = key_map[k](v)
if k == "ip" and n[k] != "1.1.1.1":
i, n = next(nodes)
n[k] = v
except ValueError:
pass
continue
try:
k, v = key_map[k](v)
except KeyError:
warnings.warn(f"Ignoring invalid key received: {k}:{v}")
continue
if k == "ip" and n[k] != "1.1.1.1":
i, n = next(nodes)
n[k] = v

_, stderr = p.communicate()
stderr = stderr.decode("utf-8").strip()
Expand Down
204 changes: 0 additions & 204 deletions milabench/cli/covalent/__main__.py

This file was deleted.

2 changes: 1 addition & 1 deletion milabench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def _push_reports(reports_repo, runs):
"partial": "yellow",
"failure": "red",
}
import milabench.cli.badges as badges
import milabench.scripts.badges as badges

_repo = git.repo.base.Repo(ROOT_FOLDER)
try:
Expand Down
File renamed without changes.
File renamed without changes.
103 changes: 103 additions & 0 deletions milabench/scripts/covalent/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import argparse
import subprocess
import sys


def serve(*argv):
return subprocess.run([
"covalent",
*argv
]).returncode


def _get_executor_kwargs(args):
return {
**{k:v for k,v in vars(args).items() if k not in ("setup", "teardown")},
}


def executor(executor_cls, args):
import covalent as ct

return_code = 0
try:
executor:ct.executor.BaseExecutor = executor_cls(
**_get_executor_kwargs(args),
)

if args.setup:
dispatch_id = ct.dispatch(
ct.lattice(executor.get_connection_attributes), disable_run=False
)()

result = ct.get_result(dispatch_id=dispatch_id, wait=True).result

assert result and result[0]

all_connection_attributes, _ = result
for hostname, connection_attributes in all_connection_attributes.items():
print(f"hostname::>{hostname}")
for attribute,value in connection_attributes.items():
if attribute == "hostname":
continue
print(f"{attribute}::>{value}")
finally:
if args.teardown:
executor.stop_cloud_instance({})

return return_code


def main(argv=None):
if argv is None:
argv = sys.argv[1:]

try:
import covalent as ct
except (KeyError, ImportError):
from ..utils import run_in_module_venv
check_if_module = "import covalent"
return run_in_module_venv(__file__, check_if_module, argv)

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
subparser = subparsers.add_parser("serve")
subparser.add_argument(f"argv", nargs=argparse.REMAINDER)
for p in ("azure","ec2"):
try:
config = ct.get_config(f"executors.{p}")
except KeyError:
continue
subparser = subparsers.add_parser(p)
subparser.add_argument(f"--setup", action="store_true")
subparser.add_argument(f"--teardown", action="store_true")
for param, default in config.items():
add_argument_kwargs = {}
if isinstance(default, bool):
add_argument_kwargs["action"] = "store_false" if default else "store_true"
else:
add_argument_kwargs["default"] = default
subparser.add_argument(f"--{param.replace('_', '-')}", **add_argument_kwargs)

try:
cv_argv, argv = argv[:argv.index("--")], argv[argv.index("--")+1:]
except ValueError:
cv_argv, argv = argv, []

args = parser.parse_args(cv_argv)

if cv_argv[0] == "serve":
assert not argv
return serve(*args.argv)
elif cv_argv[0] == "azure":
executor_cls = ct.executor.AzureExecutor
elif cv_argv[0] == "ec2":
executor_cls = ct.executor.EC2Executor
else:
raise

return executor(executor_cls, args)


if __name__ == "__main__":
sys.exit(main())
File renamed without changes.
File renamed without changes.

0 comments on commit 5b5700c

Please sign in to comment.