diff --git a/lib/ramble/ramble/application.py b/lib/ramble/ramble/application.py index 5af8008af..798626043 100644 --- a/lib/ramble/ramble/application.py +++ b/lib/ramble/ramble/application.py @@ -1184,7 +1184,7 @@ def add_expand_vars(self, workspace): self._set_input_path() self._derive_variables_for_template_path(workspace) - self._define_object_template_vars() + self._define_object_template_vars(workspace) self._vars_are_expanded = True def _inputs_and_fetchers(self, workload=None): @@ -1390,7 +1390,7 @@ def _make_experiments(self, workspace, app_inst=None): ) os.chmod(expand_path, _DEFAULT_CONTENT_PERM) - self._render_object_templates(exec_vars, replacement_vars=workspace.workspace_paths()) + self._render_object_templates(exec_vars, workspace) experiment_script = workspace.experiments_script experiment_script.write(self.expander.expand_var("{batch_submit}\n")) @@ -2291,12 +2291,14 @@ def evaluate_success(self): return True - def _object_templates(self): + def _object_templates(self, workspace): """Return templates defined from different objects associated with the app_inst""" + run_dir = self.expander.experiment_run_dir + replacements = workspace.workspace_paths() def _get_template_config(obj, tpl_config, obj_type): + # Search up the object chain to resolve source path found = False - # Search up the object chain object_paths = [e[1] for e in ramble.repository.list_object_files(obj, obj_type)] src_name = tpl_config["src_name"] for obj_path in object_paths: @@ -2306,18 +2308,22 @@ def _get_template_config(obj, tpl_config, obj_type): break if not found: raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}") - return (obj, {**tpl_config, "src_path": src_path}) + + # Resolve the destination path + dest_path = ramble.util.path.substitute_path_variables( + tpl_config["dest_path"], local_replacements=replacements + ) + if not os.path.isabs(dest_path): + dest_path = os.path.join(run_dir, dest_path) + + return (obj, {**tpl_config, "src_path": src_path, "dest_path": dest_path}) for obj_type, obj in self._objects(): for tpl_conf in obj.templates.values(): yield _get_template_config(obj, tpl_conf, obj_type=obj_type) - def _render_object_templates(self, extra_vars, replacement_vars=None): - run_dir = self.expander.experiment_run_dir - replacements = {} - if replacement_vars is not None: - replacements = replacement_vars - for obj, tpl_config in self._object_templates(): + def _render_object_templates(self, extra_vars, workspace): + for obj, tpl_config in self._object_templates(workspace): src_path = tpl_config["src_path"] with open(src_path) as f_in: content = f_in.read() @@ -2329,24 +2335,18 @@ def _render_object_templates(self, extra_vars, replacement_vars=None): extra_vars_func = getattr(obj, extra_vars_func_name) extra_vars.update(extra_vars_func()) rendered = self.expander.expand_var(content, extra_vars=extra_vars) - out_path = ramble.util.path.substitute_path_variables( - tpl_config["dest_name"], local_replacements=replacements - ) - if not os.path.isabs(out_path): - out_path = os.path.join(run_dir, out_path) + out_path = tpl_config["dest_path"] perm = tpl_config.get("content_perm", _DEFAULT_CONTENT_PERM) with open(out_path, "w+") as f_out: f_out.write(rendered) f_out.write("\n") os.chmod(out_path, perm) - def _define_object_template_vars(self): - run_dir = self.expander.experiment_run_dir - for _, tpl_config in self._object_templates(): + def _define_object_template_vars(self, workspace): + for _, tpl_config in self._object_templates(workspace): var_name = tpl_config["var_name"] if var_name is not None: - path = os.path.join(run_dir, tpl_config["dest_name"]) - self.variables[var_name] = path + self.variables[var_name] = tpl_config["dest_path"] def _objects(self): """Return a tuple for each object instance associated with the app_inst. diff --git a/lib/ramble/ramble/language/shared_language.py b/lib/ramble/ramble/language/shared_language.py index 57ca08aec..4b96ed82b 100644 --- a/lib/ramble/ramble/language/shared_language.py +++ b/lib/ramble/ramble/language/shared_language.py @@ -484,7 +484,7 @@ def _execute_target_shells(obj): def register_template( name: str, src_name: str, - dest_name: str, + dest_path: str, define_var: bool = True, extra_vars: Optional[dict] = None, extra_vars_func: Optional[str] = None, @@ -492,7 +492,7 @@ def register_template( ): """Directive to define an object-specific template to be rendered into experiment run_dir. - For instance, `register_template(name="foo", src_name="foo.tpl", dest_name="foo.sh")` + For instance, `register_template(name="foo", src_name="foo.tpl", dest_path="foo.sh")` expects a "foo.tpl" template defined alongside the object source, and uses that to render a file under "{experiment_run_dir}/foo.sh". The rendered path can also be referenced with the `foo` variable name. @@ -503,8 +503,10 @@ def register_template( `define_var` is true. src_name: The leaf name of the template. This is used to locate the the template under the containing directory of the object. - dest_name: The leaf name of the rendered output under the experiment - run directory. + dest_path: The location of the rendered output. It can either point + to an absolute or a relative path. It knows how to resolve + workspace paths such as `$workspace_shared`. A relative path + is relative to the `experiment_run_dir`. define_var: Controls if a variable named `name` should be defined. extra_vars: If present, the variable dict is used as extra variables to render the template. @@ -520,7 +522,7 @@ def _define_template(obj): extra_vars_func_name = f"_{extra_vars_func}" if extra_vars_func is not None else None obj.templates[name] = { "src_name": src_name, - "dest_name": dest_name, + "dest_path": dest_path, "var_name": var_name, "extra_vars": extra_vars, "extra_vars_func_name": extra_vars_func_name, diff --git a/lib/ramble/ramble/test/end_to_end/test_template.py b/lib/ramble/ramble/test/end_to_end/test_template.py index 9d9d5f31f..775bbaaa1 100644 --- a/lib/ramble/ramble/test/end_to_end/test_template.py +++ b/lib/ramble/ramble/test/end_to_end/test_template.py @@ -54,12 +54,14 @@ def test_template(): assert "echo hello santa" in content assert "echo not_exist" not in content execute_path = os.path.join(run_dir, "execute_experiment") + script2_path = os.path.join(ws.shared_dir, "script.sh") + assert os.path.isfile(script2_path) with open(execute_path) as f: content = f.read() assert script_path in content - - script_path = os.path.join(ws.shared_dir, "script.sh") - assert os.path.isfile(script_path) + # The workspace path should be expanded + assert "$workspace_shared" not in content + assert script2_path in content def test_template_inherited(): diff --git a/var/ramble/repos/builtin.mock/applications/template/application.py b/var/ramble/repos/builtin.mock/applications/template/application.py index 4ced53659..0c9112835 100644 --- a/var/ramble/repos/builtin.mock/applications/template/application.py +++ b/var/ramble/repos/builtin.mock/applications/template/application.py @@ -14,7 +14,7 @@ class Template(ExecutableApplication): name = "template" - executable("foo", template=["bash {bar}"]) + executable("foo", template=["bash {bar}", "echo {test}"]) workload("test_template", executable="foo") @@ -28,7 +28,7 @@ class Template(ExecutableApplication): register_template( name="bar", src_name="bar.tpl", - dest_name="bar.sh", + dest_path="bar.sh", # The `dynamic_hello_world` will be overridden by `_bar_vars` extra_vars={ "dynamic_var1": "foobar", @@ -45,6 +45,6 @@ def _bar_vars(self): register_template( name="test", src_name="script.sh", - dest_name="$workspace_shared/script.sh", + dest_path="$workspace_shared/script.sh", output_perm="755", ) diff --git a/var/ramble/repos/builtin/base_applications/hpcg/base_application.py b/var/ramble/repos/builtin/base_applications/hpcg/base_application.py index 5007c606a..e324bfbb0 100644 --- a/var/ramble/repos/builtin/base_applications/hpcg/base_application.py +++ b/var/ramble/repos/builtin/base_applications/hpcg/base_application.py @@ -117,6 +117,6 @@ class Hpcg(ExecutableApplication): register_template( name="hpcg_dat", src_name="hpcg.dat.tpl", - dest_name="hpcg.dat", + dest_path="hpcg.dat", define_var=False, ) diff --git a/var/ramble/repos/builtin/base_applications/hpl/base_application.py b/var/ramble/repos/builtin/base_applications/hpl/base_application.py index 8ee831f2e..5dd2957e6 100644 --- a/var/ramble/repos/builtin/base_applications/hpl/base_application.py +++ b/var/ramble/repos/builtin/base_applications/hpl/base_application.py @@ -344,7 +344,7 @@ def _isqrt(self, n): register_template( "hpl_dat", src_name="HPL.dat.tpl", - dest_name="HPL.dat", + dest_path="HPL.dat", define_var=False, ) diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py index 9477befc0..bce03d3f5 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py +++ b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py @@ -86,26 +86,26 @@ def __init__(self, file_path): register_template( name="batch_submit", src_name="batch_submit.tpl", - dest_name="batch_submit", + dest_path="batch_submit", ) register_template( name="batch_query", src_name="batch_query.tpl", - dest_name="batch_query", + dest_path="batch_query", extra_vars={"declare_status_map": _declare_status_map()}, ) register_template( name="batch_cancel", src_name="batch_cancel.tpl", - dest_name="batch_cancel", + dest_path="batch_cancel", ) register_template( name="slurm_execute_experiment", src_name="slurm_execute_experiment.tpl", - dest_name="slurm_execute_experiment", + dest_path="slurm_execute_experiment", extra_vars_func="execute_vars", )