Skip to content

Commit

Permalink
Merge pull request #827 from linsword13/template-dest-path
Browse files Browse the repository at this point in the history
Ensure the object template vars point to dest path correctly
  • Loading branch information
douglasjacobsen authored Jan 16, 2025
2 parents 2a2c9c6 + 53b8647 commit f74a69c
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 38 deletions.
42 changes: 21 additions & 21 deletions lib/ramble/ramble/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,7 @@ def add_expand_vars(self, workspace):
self._set_input_path()

self._derive_variables_for_template_path(workspace)
self._define_object_template_vars()
self._define_object_template_vars(workspace)
self._vars_are_expanded = True

def _inputs_and_fetchers(self, workload=None):
Expand Down Expand Up @@ -1390,7 +1390,7 @@ def _make_experiments(self, workspace, app_inst=None):
)
os.chmod(expand_path, _DEFAULT_CONTENT_PERM)

self._render_object_templates(exec_vars, replacement_vars=workspace.workspace_paths())
self._render_object_templates(exec_vars, workspace)

experiment_script = workspace.experiments_script
experiment_script.write(self.expander.expand_var("{batch_submit}\n"))
Expand Down Expand Up @@ -2291,12 +2291,14 @@ def evaluate_success(self):

return True

def _object_templates(self):
def _object_templates(self, workspace):
"""Return templates defined from different objects associated with the app_inst"""
run_dir = self.expander.experiment_run_dir
replacements = workspace.workspace_paths()

def _get_template_config(obj, tpl_config, obj_type):
# Search up the object chain to resolve source path
found = False
# Search up the object chain
object_paths = [e[1] for e in ramble.repository.list_object_files(obj, obj_type)]
src_name = tpl_config["src_name"]
for obj_path in object_paths:
Expand All @@ -2306,18 +2308,22 @@ def _get_template_config(obj, tpl_config, obj_type):
break
if not found:
raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}")
return (obj, {**tpl_config, "src_path": src_path})

# Resolve the destination path
dest_path = ramble.util.path.substitute_path_variables(
tpl_config["dest_path"], local_replacements=replacements
)
if not os.path.isabs(dest_path):
dest_path = os.path.join(run_dir, dest_path)

return (obj, {**tpl_config, "src_path": src_path, "dest_path": dest_path})

for obj_type, obj in self._objects():
for tpl_conf in obj.templates.values():
yield _get_template_config(obj, tpl_conf, obj_type=obj_type)

def _render_object_templates(self, extra_vars, replacement_vars=None):
run_dir = self.expander.experiment_run_dir
replacements = {}
if replacement_vars is not None:
replacements = replacement_vars
for obj, tpl_config in self._object_templates():
def _render_object_templates(self, extra_vars, workspace):
for obj, tpl_config in self._object_templates(workspace):
src_path = tpl_config["src_path"]
with open(src_path) as f_in:
content = f_in.read()
Expand All @@ -2329,24 +2335,18 @@ def _render_object_templates(self, extra_vars, replacement_vars=None):
extra_vars_func = getattr(obj, extra_vars_func_name)
extra_vars.update(extra_vars_func())
rendered = self.expander.expand_var(content, extra_vars=extra_vars)
out_path = ramble.util.path.substitute_path_variables(
tpl_config["dest_name"], local_replacements=replacements
)
if not os.path.isabs(out_path):
out_path = os.path.join(run_dir, out_path)
out_path = tpl_config["dest_path"]
perm = tpl_config.get("content_perm", _DEFAULT_CONTENT_PERM)
with open(out_path, "w+") as f_out:
f_out.write(rendered)
f_out.write("\n")
os.chmod(out_path, perm)

def _define_object_template_vars(self):
run_dir = self.expander.experiment_run_dir
for _, tpl_config in self._object_templates():
def _define_object_template_vars(self, workspace):
for _, tpl_config in self._object_templates(workspace):
var_name = tpl_config["var_name"]
if var_name is not None:
path = os.path.join(run_dir, tpl_config["dest_name"])
self.variables[var_name] = path
self.variables[var_name] = tpl_config["dest_path"]

def _objects(self):
"""Return a tuple for each object instance associated with the app_inst.
Expand Down
12 changes: 7 additions & 5 deletions lib/ramble/ramble/language/shared_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,15 @@ def _execute_target_shells(obj):
def register_template(
name: str,
src_name: str,
dest_name: str,
dest_path: str,
define_var: bool = True,
extra_vars: Optional[dict] = None,
extra_vars_func: Optional[str] = None,
output_perm=None,
):
"""Directive to define an object-specific template to be rendered into experiment run_dir.
For instance, `register_template(name="foo", src_name="foo.tpl", dest_name="foo.sh")`
For instance, `register_template(name="foo", src_name="foo.tpl", dest_path="foo.sh")`
expects a "foo.tpl" template defined alongside the object source, and uses that to
render a file under "{experiment_run_dir}/foo.sh". The rendered path can also be
referenced with the `foo` variable name.
Expand All @@ -503,8 +503,10 @@ def register_template(
`define_var` is true.
src_name: The leaf name of the template. This is used to locate the
the template under the containing directory of the object.
dest_name: The leaf name of the rendered output under the experiment
run directory.
dest_path: The location of the rendered output. It can either point
to an absolute or a relative path. It knows how to resolve
workspace paths such as `$workspace_shared`. A relative path
is relative to the `experiment_run_dir`.
define_var: Controls if a variable named `name` should be defined.
extra_vars: If present, the variable dict is used as extra variables to
render the template.
Expand All @@ -520,7 +522,7 @@ def _define_template(obj):
extra_vars_func_name = f"_{extra_vars_func}" if extra_vars_func is not None else None
obj.templates[name] = {
"src_name": src_name,
"dest_name": dest_name,
"dest_path": dest_path,
"var_name": var_name,
"extra_vars": extra_vars,
"extra_vars_func_name": extra_vars_func_name,
Expand Down
8 changes: 5 additions & 3 deletions lib/ramble/ramble/test/end_to_end/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ def test_template():
assert "echo hello santa" in content
assert "echo not_exist" not in content
execute_path = os.path.join(run_dir, "execute_experiment")
script2_path = os.path.join(ws.shared_dir, "script.sh")
assert os.path.isfile(script2_path)
with open(execute_path) as f:
content = f.read()
assert script_path in content

script_path = os.path.join(ws.shared_dir, "script.sh")
assert os.path.isfile(script_path)
# The workspace path should be expanded
assert "$workspace_shared" not in content
assert script2_path in content


def test_template_inherited():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Template(ExecutableApplication):

name = "template"

executable("foo", template=["bash {bar}"])
executable("foo", template=["bash {bar}", "echo {test}"])

workload("test_template", executable="foo")

Expand All @@ -28,7 +28,7 @@ class Template(ExecutableApplication):
register_template(
name="bar",
src_name="bar.tpl",
dest_name="bar.sh",
dest_path="bar.sh",
# The `dynamic_hello_world` will be overridden by `_bar_vars`
extra_vars={
"dynamic_var1": "foobar",
Expand All @@ -45,6 +45,6 @@ def _bar_vars(self):
register_template(
name="test",
src_name="script.sh",
dest_name="$workspace_shared/script.sh",
dest_path="$workspace_shared/script.sh",
output_perm="755",
)
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,6 @@ class Hpcg(ExecutableApplication):
register_template(
name="hpcg_dat",
src_name="hpcg.dat.tpl",
dest_name="hpcg.dat",
dest_path="hpcg.dat",
define_var=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _isqrt(self, n):
register_template(
"hpl_dat",
src_name="HPL.dat.tpl",
dest_name="HPL.dat",
dest_path="HPL.dat",
define_var=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,26 +86,26 @@ def __init__(self, file_path):
register_template(
name="batch_submit",
src_name="batch_submit.tpl",
dest_name="batch_submit",
dest_path="batch_submit",
)

register_template(
name="batch_query",
src_name="batch_query.tpl",
dest_name="batch_query",
dest_path="batch_query",
extra_vars={"declare_status_map": _declare_status_map()},
)

register_template(
name="batch_cancel",
src_name="batch_cancel.tpl",
dest_name="batch_cancel",
dest_path="batch_cancel",
)

register_template(
name="slurm_execute_experiment",
src_name="slurm_execute_experiment.tpl",
dest_name="slurm_execute_experiment",
dest_path="slurm_execute_experiment",
extra_vars_func="execute_vars",
)

Expand Down

0 comments on commit f74a69c

Please sign in to comment.