Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure the object template vars point to dest path correctly #827

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions lib/ramble/ramble/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,7 @@ def add_expand_vars(self, workspace):
self._set_input_path()

self._derive_variables_for_template_path(workspace)
self._define_object_template_vars()
self._define_object_template_vars(workspace)
self._vars_are_expanded = True

def _inputs_and_fetchers(self, workload=None):
Expand Down Expand Up @@ -1390,7 +1390,7 @@ def _make_experiments(self, workspace, app_inst=None):
)
os.chmod(expand_path, _DEFAULT_CONTENT_PERM)

self._render_object_templates(exec_vars, replacement_vars=workspace.workspace_paths())
self._render_object_templates(exec_vars, workspace)

experiment_script = workspace.experiments_script
experiment_script.write(self.expander.expand_var("{batch_submit}\n"))
Expand Down Expand Up @@ -2291,12 +2291,14 @@ def evaluate_success(self):

return True

def _object_templates(self):
def _object_templates(self, workspace):
"""Return templates defined from different objects associated with the app_inst"""
run_dir = self.expander.experiment_run_dir
replacements = workspace.workspace_paths()

def _get_template_config(obj, tpl_config, obj_type):
# Search up the object chain to resolve source path
found = False
# Search up the object chain
object_paths = [e[1] for e in ramble.repository.list_object_files(obj, obj_type)]
src_name = tpl_config["src_name"]
for obj_path in object_paths:
Expand All @@ -2306,18 +2308,22 @@ def _get_template_config(obj, tpl_config, obj_type):
break
if not found:
raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}")
return (obj, {**tpl_config, "src_path": src_path})

# Resolve the destination path
dest_path = ramble.util.path.substitute_path_variables(
tpl_config["dest_path"], local_replacements=replacements
)
if not os.path.isabs(dest_path):
dest_path = os.path.join(run_dir, dest_path)

return (obj, {**tpl_config, "src_path": src_path, "dest_path": dest_path})

for obj_type, obj in self._objects():
for tpl_conf in obj.templates.values():
yield _get_template_config(obj, tpl_conf, obj_type=obj_type)

def _render_object_templates(self, extra_vars, replacement_vars=None):
run_dir = self.expander.experiment_run_dir
replacements = {}
if replacement_vars is not None:
replacements = replacement_vars
for obj, tpl_config in self._object_templates():
def _render_object_templates(self, extra_vars, workspace):
for obj, tpl_config in self._object_templates(workspace):
src_path = tpl_config["src_path"]
with open(src_path) as f_in:
content = f_in.read()
Expand All @@ -2329,24 +2335,18 @@ def _render_object_templates(self, extra_vars, replacement_vars=None):
extra_vars_func = getattr(obj, extra_vars_func_name)
extra_vars.update(extra_vars_func())
rendered = self.expander.expand_var(content, extra_vars=extra_vars)
out_path = ramble.util.path.substitute_path_variables(
tpl_config["dest_name"], local_replacements=replacements
)
if not os.path.isabs(out_path):
out_path = os.path.join(run_dir, out_path)
out_path = tpl_config["dest_path"]
perm = tpl_config.get("content_perm", _DEFAULT_CONTENT_PERM)
with open(out_path, "w+") as f_out:
f_out.write(rendered)
f_out.write("\n")
os.chmod(out_path, perm)

def _define_object_template_vars(self):
run_dir = self.expander.experiment_run_dir
for _, tpl_config in self._object_templates():
def _define_object_template_vars(self, workspace):
for _, tpl_config in self._object_templates(workspace):
var_name = tpl_config["var_name"]
if var_name is not None:
path = os.path.join(run_dir, tpl_config["dest_name"])
self.variables[var_name] = path
self.variables[var_name] = tpl_config["dest_path"]

def _objects(self):
"""Return a tuple for each object instance associated with the app_inst.
Expand Down
12 changes: 7 additions & 5 deletions lib/ramble/ramble/language/shared_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,15 @@ def _execute_target_shells(obj):
def register_template(
name: str,
src_name: str,
dest_name: str,
dest_path: str,
define_var: bool = True,
extra_vars: Optional[dict] = None,
extra_vars_func: Optional[str] = None,
output_perm=None,
):
"""Directive to define an object-specific template to be rendered into experiment run_dir.

For instance, `register_template(name="foo", src_name="foo.tpl", dest_name="foo.sh")`
For instance, `register_template(name="foo", src_name="foo.tpl", dest_path="foo.sh")`
expects a "foo.tpl" template defined alongside the object source, and uses that to
render a file under "{experiment_run_dir}/foo.sh". The rendered path can also be
referenced with the `foo` variable name.
Expand All @@ -503,8 +503,10 @@ def register_template(
`define_var` is true.
src_name: The leaf name of the template. This is used to locate the
the template under the containing directory of the object.
dest_name: The leaf name of the rendered output under the experiment
run directory.
dest_path: The location of the rendered output. It can either point
to an absolute or a relative path. It knows how to resolve
workspace paths such as `$workspace_shared`. A relative path
is relative to the `experiment_run_dir`.
define_var: Controls if a variable named `name` should be defined.
extra_vars: If present, the variable dict is used as extra variables to
render the template.
Expand All @@ -520,7 +522,7 @@ def _define_template(obj):
extra_vars_func_name = f"_{extra_vars_func}" if extra_vars_func is not None else None
obj.templates[name] = {
"src_name": src_name,
"dest_name": dest_name,
"dest_path": dest_path,
"var_name": var_name,
"extra_vars": extra_vars,
"extra_vars_func_name": extra_vars_func_name,
Expand Down
8 changes: 5 additions & 3 deletions lib/ramble/ramble/test/end_to_end/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ def test_template():
assert "echo hello santa" in content
assert "echo not_exist" not in content
execute_path = os.path.join(run_dir, "execute_experiment")
script2_path = os.path.join(ws.shared_dir, "script.sh")
assert os.path.isfile(script2_path)
with open(execute_path) as f:
content = f.read()
assert script_path in content

script_path = os.path.join(ws.shared_dir, "script.sh")
assert os.path.isfile(script_path)
# The workspace path should be expanded
assert "$workspace_shared" not in content
assert script2_path in content


def test_template_inherited():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Template(ExecutableApplication):

name = "template"

executable("foo", template=["bash {bar}"])
executable("foo", template=["bash {bar}", "echo {test}"])

workload("test_template", executable="foo")

Expand All @@ -28,7 +28,7 @@ class Template(ExecutableApplication):
register_template(
name="bar",
src_name="bar.tpl",
dest_name="bar.sh",
dest_path="bar.sh",
# The `dynamic_hello_world` will be overridden by `_bar_vars`
extra_vars={
"dynamic_var1": "foobar",
Expand All @@ -45,6 +45,6 @@ def _bar_vars(self):
register_template(
name="test",
src_name="script.sh",
dest_name="$workspace_shared/script.sh",
dest_path="$workspace_shared/script.sh",
output_perm="755",
)
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,6 @@ class Hpcg(ExecutableApplication):
register_template(
name="hpcg_dat",
src_name="hpcg.dat.tpl",
dest_name="hpcg.dat",
dest_path="hpcg.dat",
define_var=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _isqrt(self, n):
register_template(
"hpl_dat",
src_name="HPL.dat.tpl",
dest_name="HPL.dat",
dest_path="HPL.dat",
define_var=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,26 @@ def __init__(self, file_path):
register_template(
name="batch_submit",
src_name="batch_submit.tpl",
dest_name="batch_submit",
dest_path="batch_submit",
)

register_template(
name="batch_query",
src_name="batch_query.tpl",
dest_name="batch_query",
dest_path="batch_query",
extra_vars={"declare_status_map": _declare_status_map()},
)

register_template(
name="batch_cancel",
src_name="batch_cancel.tpl",
dest_name="batch_cancel",
dest_path="batch_cancel",
)

register_template(
name="slurm_execute_experiment",
src_name="slurm_execute_experiment.tpl",
dest_name="slurm_execute_experiment",
dest_path="slurm_execute_experiment",
extra_vars_func="execute_vars",
)

Expand Down
Loading