From 7a3ed57dcfb09ce790681cad570173e43e408fc5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Thu, 11 Jul 2024 12:51:45 -0500 Subject: [PATCH] Separate get_repo_blob and get_repo_tree --- course/content.py | 56 ++++++++++++------- course/validation.py | 22 +++++--- course/versioning.py | 1 - relate/utils.py | 6 +- tests/base_test_mixins.py | 5 +- tests/test_content.py | 19 ++++--- tests/test_pages/test_inline.py | 4 +- .../test_validate_course_content.py | 2 +- 8 files changed, 69 insertions(+), 46 deletions(-) diff --git a/course/content.py b/course/content.py index 388226f48..5295f4624 100644 --- a/course/content.py +++ b/course/content.py @@ -42,6 +42,7 @@ from yaml import safe_load as load_yaml from course.constants import ATTRIBUTES_FILENAME +from course.validation import Blob_ish, Tree_ish from relate.utils import Struct, SubdirRepoWrapper, dict_to_struct @@ -627,7 +628,7 @@ def get_course_repo_path(course: Course) -> str: def get_course_repo(course: Course) -> Repo_ish: - from dulwich.repo.repo import Repo + from dulwich.repo import Repo repo = Repo(get_course_repo_path(course)) if course.course_root_path: @@ -694,8 +695,7 @@ def look_up_git_object(repo: dulwich.repo.Repo, return cur_lookup -def get_repo_blob(repo: Repo_ish, full_name: str, commit_sha: bytes, - allow_tree: bool = True) -> dulwich.objects.Blob: +def get_repo_tree(repo: Repo_ish, full_name: str, commit_sha: bytes) -> Tree_ish: """ :arg full_name: A Unicode string indicating the file name. 
:arg commit_sha: A byte string containing the commit hash @@ -713,23 +713,44 @@ def get_repo_blob(repo: Repo_ish, full_name: str, commit_sha: bytes, git_obj = look_up_git_object( dul_repo, root_tree=dul_repo[tree_sha], full_name=full_name) - from dulwich.objects import Blob, Tree + from dulwich.objects import Tree - from course.validation import FileSystemFakeRepoFile, FileSystemFakeRepoTree + from course.validation import FileSystemFakeRepoTree msg_full_name = full_name if full_name else _("(repo root)") if isinstance(git_obj, (Tree, FileSystemFakeRepoTree)): - if allow_tree: - return git_obj - else: - raise ObjectDoesNotExist( - _("resource '%s' is a directory, not a file") % msg_full_name) + return git_obj + else: + raise ObjectDoesNotExist(_("resource '%s' is not a tree") % msg_full_name) + + +def get_repo_blob(repo: Repo_ish, full_name: str, commit_sha: bytes) -> Blob_ish: + """ + :arg full_name: A Unicode string indicating the file name. + :arg commit_sha: A byte string containing the commit hash + :raises ObjectDoesNotExist: if the resource is not a file + """ + + dul_repo, full_name = get_true_repo_and_path(repo, full_name) + + try: + tree_sha = dul_repo[commit_sha].tree + except KeyError: + raise ObjectDoesNotExist( + _("commit sha '%s' not found") % commit_sha.decode()) + + git_obj = look_up_git_object( + dul_repo, root_tree=dul_repo[tree_sha], full_name=full_name) + + from dulwich.objects import Blob + + from course.validation import FileSystemFakeRepoFile if isinstance(git_obj, (Blob, FileSystemFakeRepoFile)): return git_obj else: - raise ObjectDoesNotExist(_("resource '%s' is not a file") % msg_full_name) + raise ObjectDoesNotExist(_("resource '%s' is not a file") % full_name) def get_repo_blob_data_cached( @@ -757,8 +778,7 @@ def get_repo_blob_data_cached( result: bytes | None = None if cache_key is None: - result = get_repo_blob(repo, full_name, commit_sha, - allow_tree=False).data + result = get_repo_blob(repo, full_name, commit_sha).data assert 
isinstance(result, bytes) return result @@ -777,8 +797,7 @@ def get_repo_blob_data_cached( assert isinstance(result, bytes), cache_key return result - result = get_repo_blob(repo, full_name, commit_sha, - allow_tree=False).data + result = get_repo_blob(repo, full_name, commit_sha).data assert result is not None if len(result) <= getattr(settings, "RELATE_CACHE_MAX_BYTES", 0): @@ -1023,8 +1042,7 @@ def get_raw_yaml_from_repo( yaml_str = expand_yaml_macros( repo, commit_sha, - get_repo_blob(repo, full_name, commit_sha, - allow_tree=False).data) + get_repo_blob(repo, full_name, commit_sha).data) result = load_yaml(yaml_str) # type: ignore @@ -1071,7 +1089,7 @@ def get_yaml_from_repo( return result yaml_bytestream = get_repo_blob( - repo, full_name, commit_sha, allow_tree=False).data + repo, full_name, commit_sha).data yaml_text = yaml_bytestream.decode("utf-8") if not tolerate_tabs and LINE_HAS_INDENTING_TABS_RE.search(yaml_text): @@ -2004,7 +2022,7 @@ def is_commit_sha_valid(repo: Repo_ish, commit_sha: str) -> bool: def list_flow_ids(repo: Repo_ish, commit_sha: bytes) -> list[str]: flow_ids = [] try: - flows_tree = get_repo_blob(repo, "flows", commit_sha) + flows_tree = get_repo_tree(repo, "flows", commit_sha) except ObjectDoesNotExist: # That's OK--no flows yet. 
pass diff --git a/course/validation.py b/course/validation.py index 3efcbaf18..d3562f4cf 100644 --- a/course/validation.py +++ b/course/validation.py @@ -26,8 +26,9 @@ import datetime import re import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union +import dulwich.objects from django.core.exceptions import ObjectDoesNotExist from django.utils.html import escape from django.utils.translation import gettext, gettext_lazy as _ @@ -38,7 +39,6 @@ FLOW_SESSION_EXPIRATION_MODE_CHOICES, participation_permission as pperm, ) -from course.content import get_repo_blob from relate.utils import Struct, string_concat @@ -1465,13 +1465,15 @@ def validate_course_content(repo, course_file, events_file, else: access_kinds = DEFAULT_ACCESS_KINDS + from course.content import get_repo_tree + check_attributes_yml( vctx, repo, "", - get_repo_blob(repo, "", validate_sha), + get_repo_tree(repo, "", validate_sha), access_kinds) try: - flows_tree = get_repo_blob(repo, "media", validate_sha) + get_repo_tree(repo, "media", validate_sha) except ObjectDoesNotExist: # That's great--no media directory. pass @@ -1485,7 +1487,7 @@ def validate_course_content(repo, course_file, events_file, # {{{ flows try: - flows_tree = get_repo_blob(repo, "flows", validate_sha) + flows_tree = get_repo_tree(repo, "flows", validate_sha) except ObjectDoesNotExist: # That's OK--no flows yet. 
pass @@ -1538,6 +1540,8 @@ def validate_course_content(repo, course_file, events_file, # }}} + from course.content import get_repo_blob + # {{{ static pages try: @@ -1591,7 +1595,7 @@ def tree(self): class FileSystemFakeRepoTreeEntry: # pragma: no cover - def __init__(self, path, mode): + def __init__(self, path: bytes, mode: int) -> None: self.path = path self.mode = mode @@ -1620,7 +1624,7 @@ def __getitem__(self, name): else: return stat_result.st_mode, FileSystemFakeRepoFile(name) - def items(self): + def items(self) -> list[FileSystemFakeRepoTreeEntry]: import os return [ FileSystemFakeRepoTreeEntry( @@ -1639,6 +1643,10 @@ def data(self): return inf.read() +Blob_ish = Union[dulwich.objects.Blob, FileSystemFakeRepoFile] +Tree_ish = Union[dulwich.objects.Tree, FileSystemFakeRepoTree] + + def validate_course_on_filesystem( root, course_file, events_file): # pragma: no cover fake_repo = FileSystemFakeRepo(root.encode("utf-8")) diff --git a/course/versioning.py b/course/versioning.py index 48a77eb1e..9d2ef11d7 100644 --- a/course/versioning.py +++ b/course/versioning.py @@ -35,7 +35,6 @@ ) import django.forms as forms -import dulwich.blob import dulwich.client import paramiko import paramiko.client diff --git a/relate/utils.py b/relate/utils.py index f1a2637d8..f532980eb 100644 --- a/relate/utils.py +++ b/relate/utils.py @@ -26,7 +26,6 @@ import datetime from typing import ( - TYPE_CHECKING, Any, Mapping, Union, @@ -34,14 +33,11 @@ import django.forms as forms import dulwich.repo +from django.http import HttpRequest from django.utils.text import format_lazy from django.utils.translation import gettext_lazy as _ -if TYPE_CHECKING: - from django.http import HttpRequest - - def string_concat(*strings: Any) -> str: return format_lazy("{}" * len(strings), *strings) diff --git a/tests/base_test_mixins.py b/tests/base_test_mixins.py index 37cd0d43e..27bdcc8d8 100644 --- a/tests/base_test_mixins.py +++ b/tests/base_test_mixins.py @@ -2837,7 +2837,7 @@ def __init__(self, 
yaml_file_name): data = f.read() self.data = data - def get_repo_side_effect(repo, full_name, commit_sha, allow_tree=True): + def get_repo_side_effect(repo, full_name, commit_sha): commit_sha_path_maps = COMMIT_SHA_MAP.get(full_name) if commit_sha_path_maps: assert isinstance(commit_sha_path_maps, list) @@ -2846,8 +2846,7 @@ def get_repo_side_effect(repo, full_name, commit_sha, allow_tree=True): path = cs_map[commit_sha.decode()]["path"] return Blob(path) - return get_repo_blob(repo, full_name, repo[b"HEAD"].id, - allow_tree=allow_tree) + return get_repo_blob(repo, full_name, repo[b"HEAD"].id) cls.batch_fake_get_repo_blob = mock.patch(cls.get_repo_blob_patching_path) cls.mock_get_repo_blob = cls.batch_fake_get_repo_blob.start() diff --git a/tests/test_content.py b/tests/test_content.py index ea2dfcfc3..e52ae457a 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -290,8 +290,7 @@ def test_repo_root_not_allow_tree_key_error(self): with self.pctx.repo as repo: with self.assertRaises(ObjectDoesNotExist) as cm: content.get_repo_blob( - repo, "", self.course.active_git_commit_sha.encode(), - allow_tree=False) + repo, "", self.course.active_git_commit_sha.encode()) expected_error_msg = "resource '(repo root)' is a directory, not a file" self.assertIn(expected_error_msg, str(cm.exception)) @@ -300,9 +299,8 @@ def test_access_directory_content_type_error(self): full_name = os.path.join(*path_parts) with self.pctx.repo as repo: with self.assertRaises(ObjectDoesNotExist) as cm: - content.get_repo_blob( - repo, full_name, self.course.active_git_commit_sha.encode(), - allow_tree=True) + content.get_repo_tree( + repo, full_name, self.course.active_git_commit_sha.encode()) expected_error_msg = ( "'%s' is not a directory, cannot lookup nested names" % path_parts[0]) @@ -313,8 +311,7 @@ def test_resource_is_a_directory_error(self): with self.pctx.repo as repo: with self.assertRaises(ObjectDoesNotExist) as cm: content.get_repo_blob( - repo, full_name, 
self.course.active_git_commit_sha.encode(), - allow_tree=False) + repo, full_name, self.course.active_git_commit_sha.encode()) expected_error_msg = ( "resource '%s' is a directory, not a file" % full_name) self.assertIn(expected_error_msg, str(cm.exception)) @@ -935,6 +932,12 @@ def setUp(self): self.repo = mock.MagicMock() self.commit_sha = mock.MagicMock() + fake_get_repo_tree = mock.patch("course.content.get_repo_tree") + self.mock_get_repo_tree = fake_get_repo_tree.start() + self.addCleanup(fake_get_repo_tree.stop) + self.repo = mock.MagicMock() + self.commit_sha = mock.MagicMock() + def test_object_does_not_exist(self): self.mock_get_repo_blob.side_effect = ObjectDoesNotExist() self.assertEqual(content.list_flow_ids(self.repo, self.commit_sha), []) @@ -947,7 +950,7 @@ def test_result(self): tree.add(b"flow_c.yml", stat.S_IFREG, b"flow_c content") tree.add(b"temp_dir", stat.S_IFDIR, b"a temp dir") - self.mock_get_repo_blob.return_value = tree + self.mock_get_repo_tree.return_value = tree self.assertEqual(content.list_flow_ids( self.repo, self.commit_sha), ["flow_a", "flow_b", "flow_c"]) diff --git a/tests/test_pages/test_inline.py b/tests/test_pages/test_inline.py index 81102db85..689aad238 100644 --- a/tests/test_pages/test_inline.py +++ b/tests/test_pages/test_inline.py @@ -664,11 +664,11 @@ """ -def get_repo_blob_side_effect(repo, full_name, commit_sha, allow_tree=True): +def get_repo_blob_side_effect(repo, full_name, commit_sha): # Fake the inline multiple question yaml for specific commit if not (full_name == "questions/multi-question-example.yml" and commit_sha == b"ec41a2de73a99e6022060518cb5c5c162b88cdf5"): - return get_repo_blob(repo, full_name, commit_sha, allow_tree) + return get_repo_blob(repo, full_name, commit_sha) else: class Blob: pass diff --git a/tests/test_validation/test_validate_course_content.py b/tests/test_validation/test_validate_course_content.py index 6bd0228e5..6846bd322 100644 --- 
a/tests/test_validation/test_validate_course_content.py +++ b/tests/test_validation/test_validate_course_content.py @@ -365,7 +365,7 @@ def setUp(self): self.addCleanup(fake_check_grade_identifier_link.stop) fake_get_repo_blob = ( - mock.patch("course.validation.get_repo_blob")) + mock.patch("course.content.get_repo_blob")) self.mock_get_repo_blob = fake_get_repo_blob.start() self.mock_get_repo_blob.side_effect = get_repo_blob_side_effect self.addCleanup(fake_get_repo_blob.stop)