diff --git a/lib/ramble/ramble/language/application_language.py b/lib/ramble/ramble/language/application_language.py index 6e633c4dc..0ea268f38 100644 --- a/lib/ramble/ramble/language/application_language.py +++ b/lib/ramble/ramble/language/application_language.py @@ -108,6 +108,10 @@ def _execute_workload_groups(app): for var in app.workload_group_vars[name]: app.workloads[workload].add_variable(var) + if name in app.workload_group_env_vars: + for env_var in app.workload_group_env_vars[name]: + app.workloads[workload].add_environment_variable(env_var) + return _execute_workload_groups @@ -260,8 +264,10 @@ def _execute_workload_variable(app): return _execute_workload_variable -@application_directive(dicts=()) -def environment_variable(name, value, description, workload=None, workloads=None, **kwargs): +@application_directive("workload_group_env_vars") +def environment_variable( + name, value, description, workload=None, workloads=None, workload_group=None, **kwargs +): """Define an environment variable to be used in experiments Args: @@ -274,15 +280,29 @@ def environment_variable(name, value, description, workload=None, workloads=None """ def _execute_environment_variable(app): - all_workloads = ramble.language.language_helpers.require_definition( + all_workloads = ramble.language.language_helpers.merge_definitions( workload, workloads, app.workloads, "workload", "workloads", "environment_variable" ) + workload_env_var = ramble.workload.WorkloadEnvironmentVariable( + name, value=value, description=description + ) + for wl_name in all_workloads: - app.workloads[wl_name].add_environment_variable( - ramble.workload.WorkloadEnvironmentVariable( - name, value=value, description=description - ) - ) + app.workloads[wl_name].add_environment_variable(workload_env_var.copy()) + + if workload_group is not None: + workload_group_list = app.workload_groups[workload_group] + + if workload_group not in app.workload_group_env_vars: + app.workload_group_env_vars[workload_group] = [] + + app.workload_group_vars[workload_group].append(workload_env_var.copy()) + + for wl_name in workload_group_list: + app.workloads[wl_name].add_environment_variable(workload_env_var.copy()) + + if not all_workloads and workload_group is None: + raise DirectiveError("A workload or workload group is required") return _execute_environment_variable