diff --git a/rapids_build_backend/config.py b/rapids_build_backend/config.py index 3968c5c..a044197 100644 --- a/rapids_build_backend/config.py +++ b/rapids_build_backend/config.py @@ -12,13 +12,12 @@ class Config: # required) and whether it may be overridden by an environment variable or a config # setting. config_options = { - "build-backend": (None, False, None), - "commit-file": ("", False, None), - "commit-file-type": ("python", False, {"python", "raw"}), - "disable-cuda-suffix": (False, True, None), - "only-release-deps": (False, True, None), - "require-cuda": (True, True, None), - "requires": ([], False, None), + "build-backend": (None, False), + "commit-file": ("", False), + "disable-cuda-suffix": (False, True), + "only-release-deps": (False, True), + "require-cuda": (True, True), + "requires": ([], False), } def __init__(self, dirname=".", config_settings=None): @@ -32,9 +31,7 @@ def __init__(self, dirname=".", config_settings=None): def __getattr__(self, name): config_name = name.replace("_", "-") if config_name in Config.config_options: - default_value, allows_override, allowed_values = Config.config_options[ - config_name - ] + default_value, allows_override = Config.config_options[config_name] # If overrides are allowed environment variables take precedence over the # config_settings dict. @@ -50,21 +47,15 @@ def __getattr__(self, name): f"{env_var} must be 'true' or 'false', not {str_val}" ) return str_val == "true" - return self._check_value( - env_var, os.environ[env_var], allowed_values - ) + return os.environ[env_var] if config_name in self.config_settings: if isinstance(default_value, bool): return self.config_settings[config_name] == "true" - return self._check_value( - config_name, self.config_settings[config_name], allowed_values - ) + return self.config_settings[config_name] try: - return self._check_value( - config_name, self.config[config_name], allowed_values - ) + return self.config[config_name] except KeyError: if default_value is not None: return default_value @@ -72,9 +63,3 @@ def __getattr__(self, name): raise AttributeError(f"Config is missing required attribute {name}") else: raise AttributeError(f"Attempted to access unknown option {name}") - - def _check_value(self, name, value, allowed_values): - if allowed_values is not None and value not in allowed_values: - formatted_list = "\n".join(f" - {v}" for v in allowed_values) - raise ValueError(f"{name} must be one of:\n{formatted_list}") - return value diff --git a/rapids_build_backend/impls.py b/rapids_build_backend/impls.py index 93ca8b6..1ff3dec 100644 --- a/rapids_build_backend/impls.py +++ b/rapids_build_backend/impls.py @@ -235,7 +235,6 @@ def _edit_git_commit(config): at build time. """ commit_file = config.commit_file - commit_file_type = config.commit_file_type commit = _get_git_commit() if commit_file != "" and commit is not None: @@ -244,32 +243,13 @@ def _edit_git_commit(config): f".{os.path.basename(commit_file)}.rapids-build-backend.bak", ) try: - if commit_file_type == "python": - with open(commit_file) as f: - lines = f.readlines() - + try: shutil.move(commit_file, bkp_commit_file) + except FileNotFoundError: + bkp_commit_file = None - with open(commit_file, "w") as f: - wrote = False - for line in lines: - if "__git_commit__" in line: - f.write(f'__git_commit__ = "{commit}"\n') - wrote = True - else: - f.write(line) - # If no git commit line was found, write it at the end of the file. - if not wrote: - f.write(f'__git_commit__ = "{commit}"\n') - - elif commit_file_type == "raw": - try: - shutil.move(commit_file, bkp_commit_file) - except FileNotFoundError: - bkp_commit_file = None - - with open(commit_file, "w") as f: - f.write(f"{commit}\n") + with open(commit_file, "w") as f: + f.write(f"{commit}\n") yield finally: diff --git a/tests/test_impls.py b/tests/test_impls.py index 9f30641..4d8cb75 100644 --- a/tests/test_impls.py +++ b/tests/test_impls.py @@ -10,20 +10,15 @@ @pytest.mark.parametrize( - ["commit_file_type", "initial_contents", "expected_contents"], + ["initial_contents"], [ - ( - "python", - '# Begin Python file\n__git_commit__ = ""\n# End Python file\n', - '# Begin Python file\n__git_commit__ = "abc123"\n# End Python file\n', - ), - ("python", None, FileNotFoundError), - ("raw", "def456\n", "abc123\n"), - ("raw", None, "abc123\n"), + ("def456\n",), + ("",), + (None,), ], ) @patch("rapids_build_backend.impls._get_git_commit", Mock(return_value="abc123")) -def test_edit_git_commit(commit_file_type, initial_contents, expected_contents): +def test_edit_git_commit(initial_contents): def check_initial_contents(filename): if initial_contents is not None: with open(filename) as f: @@ -39,20 +34,12 @@ def check_initial_contents(filename): config = Mock( commit_file=commit_file, - commit_file_type=commit_file_type, ) - if isinstance(expected_contents, type) and issubclass( - expected_contents, Exception - ): - with pytest.raises(expected_contents): - with _edit_git_commit(config): - pass - else: - with _edit_git_commit(config): - with open(commit_file) as f: - assert f.read() == expected_contents - check_initial_contents( - os.path.join(d, ".commit-file.rapids-build-backend.bak") - ) + with _edit_git_commit(config): + with open(commit_file) as f: + assert f.read() == "abc123\n" + check_initial_contents( + os.path.join(d, ".commit-file.rapids-build-backend.bak") + ) check_initial_contents(commit_file)