diff --git a/lib/ramble/ramble/language/application_language.py b/lib/ramble/ramble/language/application_language.py index fd2c7b5b5..7c6514999 100644 --- a/lib/ramble/ramble/language/application_language.py +++ b/lib/ramble/ramble/language/application_language.py @@ -75,13 +75,42 @@ def _execute_workload(app): executable, executables, app.executables, "executable", "executables", "workload" ) - all_inputs = ramble.language.language_helpers.merge_definitions(input, inputs, app.inputs) + all_inputs = ramble.language.language_helpers.merge_definitions( + input, inputs, app.inputs, "input", "inputs", "workload" + ) app.workloads[name] = ramble.workload.Workload(name, all_execs, all_inputs, tags) return _execute_workload +@application_directive("workload_groups") +def workload_group(name, workloads=[], mode=None, **kwargs): + """Adds a workload group to this application + + Defines a new workload group that can be used within the context of its + application. + + Args: + name: The name of the group + workloads: A list of workloads to be grouped + """ + + def _execute_workload_groups(app): + if mode == "append": + app.workload_groups[name].update(set(workloads)) + else: + app.workload_groups[name] = set(workloads) + + # Apply any existing variables in the group to the workload + for workload in workloads: + if name in app.workload_group_vars: + for var in app.workload_group_vars[name]: + app.workloads[workload].add_variable(var) + + return _execute_workload_groups + + @application_directive("executables") def executable(name, template, **kwargs): """Adds an executable to this application @@ -156,7 +185,7 @@ def _execute_input_file(app): return _execute_input_file -@application_directive(dicts=()) +@application_directive("workload_group_vars") def workload_variable( name, default, @@ -164,6 +193,7 @@ def workload_variable( values=None, workload=None, workloads=None, + workload_group=None, expandable=True, **kwargs, ): @@ -177,20 +207,33 @@ def workload_variable( """ def _execute_workload_variable(app): - all_workloads = ramble.language.language_helpers.require_definition( + # Always apply passes workload/workloads + all_workloads = ramble.language.language_helpers.merge_definitions( workload, workloads, app.workloads, "workload", "workloads", "workload_variable" ) + workload_var = ramble.workload.WorkloadVariable( + name, default=default, description=description, values=values, expandable=expandable + ) + for wl_name in all_workloads: - app.workloads[wl_name].add_variable( - ramble.workload.WorkloadVariable( - name, - default=default, - description=description, - values=values, - expandable=expandable, - ) - ) + app.workloads[wl_name].add_variable(workload_var.copy()) + + if workload_group is not None: + workload_group_list = app.workload_groups[workload_group] + + if workload_group not in app.workload_group_vars: + app.workload_group_vars[workload_group] = [] + + # Track which vars we add to, to allow us to re-apply during inheritance + app.workload_group_vars[workload_group].append(workload_var.copy()) + + for wl_name in workload_group_list: + # Apply the variable + app.workloads[wl_name].add_variable(workload_var.copy()) + + if not all_workloads and workload_group is None: + raise DirectiveError("A workload or workload group is required") return _execute_workload_variable diff --git a/lib/ramble/ramble/language/language_helpers.py b/lib/ramble/ramble/language/language_helpers.py index 364ad8cdd..e247c0ba6 100644 --- a/lib/ramble/ramble/language/language_helpers.py +++ b/lib/ramble/ramble/language/language_helpers.py @@ -12,7 +12,46 @@ from ramble.language.language_base import DirectiveError -def merge_definitions(single_type, multiple_type, multiple_pattern_match): +def check_definition( + single_type, multiple_type, single_arg_name, multiple_arg_name, directive_name +): + """ + Sanity check definitions before merging or require + + Args: + single_type: Single string for type name + multiple_type: List of strings for type names, may contain wildcards + multiple_pattern_match: List of strings to match against patterns in multiple_type + single_arg_name: String name of the single_type argument in the directive + multiple_arg_name: String name of the multiple_type argument in the directive + directive_name: Name of the directive requiring a type + + Returns: + List of all type names (Merged if both single_type and multiple_type definitions are valid) + """ + if single_type and not isinstance(single_type, six.string_types): + raise DirectiveError( + f"Directive {directive_name} was given an invalid type " + f"for the {single_arg_name} argument. " + f"Type was {type(single_type)}" + ) + + if multiple_type and not isinstance(multiple_type, list): + raise DirectiveError( + f"Directive {directive_name} was given an invalid type " + f"for the {multiple_arg_name} argument. " + f"Type was {type(multiple_type)}" + ) + + +def merge_definitions( + single_type, + multiple_type, + multiple_pattern_match, + single_arg_name, + multiple_arg_name, + directive_name, +): """Merge definitions of a type This method will merge two optional definitions of single_type and @@ -22,11 +61,18 @@ def merge_definitions(single_type, multiple_type, multiple_pattern_match): single_type: Single string for type name multiple_type: List of strings for type names, may contain wildcards multiple_pattern_match: List of strings to match against patterns in multiple_type + single_arg_name: String name of the single_type argument in the directive + multiple_arg_name: String name of the multiple_type argument in the directive + directive_name: Name of the directive requiring a type Returns: List of all type names (Merged if both single_type and multiple_type definitions are valid) """ + check_definition( + single_type, multiple_type, single_arg_name, multiple_arg_name, directive_name + ) + all_types = [] if single_type: @@ -72,21 +118,14 @@ def require_definition( f"{single_arg_name} or {multiple_arg_name} to be defined." ) - if single_type and not isinstance(single_type, six.string_types): - raise DirectiveError( - f"Directive {directive_name} was given an invalid type " - f"for the {single_arg_name} argument. " - f"Type was {type(single_type)}" - ) - - if multiple_type and not isinstance(multiple_type, list): - raise DirectiveError( - f"Directive {directive_name} was given an invalid type " - f"for the {multiple_arg_name} argument. " - f"Type was {type(multiple_type)}" - ) - - return merge_definitions(single_type, multiple_type, multiple_pattern_match) + return merge_definitions( + single_type, + multiple_type, + multiple_pattern_match, + single_arg_name, + multiple_arg_name, + directive_name, + ) def expand_patterns(multiple_type: list, multiple_pattern_match: list): diff --git a/lib/ramble/ramble/test/application_tests.py b/lib/ramble/ramble/test/application_tests.py index 5018f773a..cba645985 100644 --- a/lib/ramble/ramble/test/application_tests.py +++ b/lib/ramble/ramble/test/application_tests.py @@ -22,7 +22,9 @@ ) def test_app_features(mutable_mock_apps_repo, app): app_inst = mutable_mock_apps_repo.get(app) + assert hasattr(app_inst, "workloads") + assert hasattr(app_inst, "workload_groups") assert hasattr(app_inst, "executables") assert hasattr(app_inst, "figures_of_merit") assert hasattr(app_inst, "inputs") @@ -498,3 +500,46 @@ def test_class_attributes(mutable_mock_apps_repo): assert "added_workload" in basic_copy.workloads assert "added_workload" not in basic_inst.workloads + + +def test_workload_groups(mutable_mock_apps_repo): + workload_group_inst = mutable_mock_apps_repo.get("workload-groups") + + assert "test_wl" in workload_group_inst.workloads + + assert "empty" in workload_group_inst.workload_groups + assert "test_wlg" in workload_group_inst.workload_groups + + my_var = workload_group_inst.workloads["test_wl"].find_variable("test_var") + assert my_var is not None + assert my_var.default == "2.0" + assert my_var.description == "Test workload vars and groups" + + my_mixed_var_wl = workload_group_inst.workloads["test_wl"].find_variable("test_var_mixed") + assert my_mixed_var_wl is not None + assert my_mixed_var_wl.default == "3.0" + assert my_mixed_var_wl.description == "Test vars for workload and groups" + + +def test_workload_groups_inherited(mutable_mock_apps_repo): + wlgi_inst = mutable_mock_apps_repo.get("workload-groups-inherited") + + assert "test_wl" in wlgi_inst.workloads + assert "test_wl3" in wlgi_inst.workloads + + # check we inherit groups we don't touch + assert "empty" in wlgi_inst.workload_groups + assert "test_wlg" in wlgi_inst.workload_groups + + assert "test_wl" in wlgi_inst.workload_groups["test_wlg"] + + # Ensure a new workload can obtain the parent level vars via groups + my_var = wlgi_inst.workloads["test_wl3"].find_variable("test_var") + assert my_var is not None + assert my_var.default == "2.0" + assert my_var.description == "Test workload vars and groups" + + for wl in ["test_wl", "test_wl3"]: + my_mixed_var_wl = wlgi_inst.workloads[wl].find_variable("test_var_mixed") + assert my_mixed_var_wl is not None + assert my_mixed_var_wl.default == "3.0" diff --git a/lib/ramble/ramble/workload.py b/lib/ramble/ramble/workload.py index 0517b3167..5011dc1ee 100644 --- a/lib/ramble/ramble/workload.py +++ b/lib/ramble/ramble/workload.py @@ -7,6 +7,8 @@ # except according to those terms. from typing import List +import copy + import ramble.util.colors as rucolor @@ -62,6 +64,9 @@ def as_str(self, n_indent: int = 0): out_str += f'{indentation} {name}: {str(attr_val).replace("@", "@@")}\n' return out_str + def copy(self): + return copy.deepcopy(self) + class WorkloadEnvironmentVariable(object): """Class representing an environment variable in a workload""" @@ -99,6 +104,9 @@ def as_str(self, n_indent: int = 0): out_str += f'{indentation} {name}: {attr_val.replace("@", "@@")}\n' return out_str + def copy(self): + return copy.deepcopy(self) + class Workload(object): """Class representing a single workload""" diff --git a/var/ramble/repos/builtin.mock/applications/workload-groups-inherited/application.py b/var/ramble/repos/builtin.mock/applications/workload-groups-inherited/application.py new file mode 100644 index 000000000..0771ea808 --- /dev/null +++ b/var/ramble/repos/builtin.mock/applications/workload-groups-inherited/application.py @@ -0,0 +1,22 @@ +# Copyright 2022-2024 The Ramble Authors +# +# Licensed under the Apache License, Version 2.0 or the MIT license +# , at your +# option. This file may not be copied, modified, or distributed +# except according to those terms. + +from ramble.appkit import * + +from ramble.app.builtin.mock.workload_groups import WorkloadGroups + + +class WorkloadGroupsInherited(WorkloadGroups): + name = "workload-groups-inherited" + + workload('test_wl3', executable='baz') + + # Test populated group applies existing vars to new workload + workload_group('test_wlg', + workloads=['test_wl3'], + mode='append') diff --git a/var/ramble/repos/builtin.mock/applications/workload-groups/application.py b/var/ramble/repos/builtin.mock/applications/workload-groups/application.py new file mode 100644 index 000000000..43cb03caa --- /dev/null +++ b/var/ramble/repos/builtin.mock/applications/workload-groups/application.py @@ -0,0 +1,38 @@ +# Copyright 2022-2024 The Ramble Authors +# +# Licensed under the Apache License, Version 2.0 or the MIT license +# , at your +# option. This file may not be copied, modified, or distributed +# except according to those terms. + +from ramble.appkit import * + + +class WorkloadGroups(ExecutableApplication): + name = "workload-groups" + + executable('foo', 'echo "bar"', use_mpi=False) + executable('bar', 'echo "baz"', use_mpi=False) + + workload('test_wl', executable='foo') + workload('test_wl2', executable='bar') + + # Test empty group + workload_group('empty', + workloads=[]) + + # Test populated group + workload_group('test_wlg', + workloads=['test_wl', 'test_wl2']) + + # Test workload_variable that uses a group + workload_variable('test_var', default='2.0', + description='Test workload vars and groups', + workload_group='test_wlg') + + # Test passing both groups an explicit workloads + workload_variable('test_var_mixed', default='3.0', + description='Test vars for workload and groups', + workload='test_wl', + workload_group='test_wlg')