Skip to content

Commit

Permalink
Only write "raw" git commit file
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleFromNVIDIA committed Mar 14, 2024
1 parent 61e3e20 commit 39a7f53
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 74 deletions.
35 changes: 10 additions & 25 deletions rapids_build_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ class Config:
# required) and whether it may be overridden by an environment variable or a config
# setting.
config_options = {
"build-backend": (None, False, None),
"commit-file": ("", False, None),
"commit-file-type": ("python", False, {"python", "raw"}),
"disable-cuda-suffix": (False, True, None),
"only-release-deps": (False, True, None),
"require-cuda": (True, True, None),
"requires": ([], False, None),
"build-backend": (None, False),
"commit-file": ("", False),
"disable-cuda-suffix": (False, True),
"only-release-deps": (False, True),
"require-cuda": (True, True),
"requires": ([], False),
}

def __init__(self, dirname=".", config_settings=None):
Expand All @@ -32,9 +31,7 @@ def __init__(self, dirname=".", config_settings=None):
def __getattr__(self, name):
config_name = name.replace("_", "-")
if config_name in Config.config_options:
default_value, allows_override, allowed_values = Config.config_options[
config_name
]
default_value, allows_override = Config.config_options[config_name]

# If overrides are allowed environment variables take precedence over the
# config_settings dict.
Expand All @@ -50,31 +47,19 @@ def __getattr__(self, name):
f"{env_var} must be 'true' or 'false', not {str_val}"
)
return str_val == "true"
return self._check_value(
env_var, os.environ[env_var], allowed_values
)
return os.environ[env_var]

if config_name in self.config_settings:
if isinstance(default_value, bool):
return self.config_settings[config_name] == "true"
return self._check_value(
config_name, self.config_settings[config_name], allowed_values
)
return self.config_settings[config_name]

try:
return self._check_value(
config_name, self.config[config_name], allowed_values
)
return self.config[config_name]
except KeyError:
if default_value is not None:
return default_value

raise AttributeError(f"Config is missing required attribute {name}")
else:
raise AttributeError(f"Attempted to access unknown option {name}")

def _check_value(self, name, value, allowed_values):
if allowed_values is not None and value not in allowed_values:
formatted_list = "\n".join(f" - {v}" for v in allowed_values)
raise ValueError(f"{name} must be one of:\n{formatted_list}")
return value
30 changes: 5 additions & 25 deletions rapids_build_backend/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ def _edit_git_commit(config):
at build time.
"""
commit_file = config.commit_file
commit_file_type = config.commit_file_type
commit = _get_git_commit()

if commit_file != "" and commit is not None:
Expand All @@ -244,32 +243,13 @@ def _edit_git_commit(config):
f".{os.path.basename(commit_file)}.rapids-build-backend.bak",
)
try:
if commit_file_type == "python":
with open(commit_file) as f:
lines = f.readlines()

try:
shutil.move(commit_file, bkp_commit_file)
except FileNotFoundError:
bkp_commit_file = None

with open(commit_file, "w") as f:
wrote = False
for line in lines:
if "__git_commit__" in line:
f.write(f'__git_commit__ = "{commit}"\n')
wrote = True
else:
f.write(line)
# If no git commit line was found, write it at the end of the file.
if not wrote:
f.write(f'__git_commit__ = "{commit}"\n')

elif commit_file_type == "raw":
try:
shutil.move(commit_file, bkp_commit_file)
except FileNotFoundError:
bkp_commit_file = None

with open(commit_file, "w") as f:
f.write(f"{commit}\n")
with open(commit_file, "w") as f:
f.write(f"{commit}\n")

yield
finally:
Expand Down
35 changes: 11 additions & 24 deletions tests/test_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,15 @@


@pytest.mark.parametrize(
["commit_file_type", "initial_contents", "expected_contents"],
["initial_contents"],
[
(
"python",
'# Begin Python file\n__git_commit__ = ""\n# End Python file\n',
'# Begin Python file\n__git_commit__ = "abc123"\n# End Python file\n',
),
("python", None, FileNotFoundError),
("raw", "def456\n", "abc123\n"),
("raw", None, "abc123\n"),
("def456\n",),
("",),
(None,),
],
)
@patch("rapids_build_backend.impls._get_git_commit", Mock(return_value="abc123"))
def test_edit_git_commit(commit_file_type, initial_contents, expected_contents):
def test_edit_git_commit(initial_contents):
def check_initial_contents(filename):
if initial_contents is not None:
with open(filename) as f:
Expand All @@ -39,20 +34,12 @@ def check_initial_contents(filename):

config = Mock(
commit_file=commit_file,
commit_file_type=commit_file_type,
)
if isinstance(expected_contents, type) and issubclass(
expected_contents, Exception
):
with pytest.raises(expected_contents):
with _edit_git_commit(config):
pass
else:
with _edit_git_commit(config):
with open(commit_file) as f:
assert f.read() == expected_contents
check_initial_contents(
os.path.join(d, ".commit-file.rapids-build-backend.bak")
)
with _edit_git_commit(config):
with open(commit_file) as f:
assert f.read() == "abc123\n"
check_initial_contents(
os.path.join(d, ".commit-file.rapids-build-backend.bak")
)

check_initial_contents(commit_file)

0 comments on commit 39a7f53

Please sign in to comment.