Skip to content

Commit

Permalink
Define enum for stage
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmer committed Jul 20, 2023
1 parent 52f323d commit f33ecb1
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 58 deletions.
117 changes: 60 additions & 57 deletions dulwich/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import stat
import struct
import sys
from enum import Enum
from typing import (
Any,
BinaryIO,
Expand Down Expand Up @@ -90,19 +91,20 @@
DEFAULT_VERSION = 2


class Stage(Enum):
NORMAL = 0
MERGE_CONFLICT_ANCESTOR = 1
MERGE_CONFLICT_THIS = 2
MERGE_CONFLICT_OTHER = 3


class UnmergedEntriesInIndexEx(Exception):
def __init__(self, message):
super().__init__(message)


def read_stage(entry: IndexEntry) -> int:
"""Stage of an Entry
0 - normal
1 - merge conflict 'ancestor' entry
2 - merge conflict 'this' entry
3 - merge conflict 'other' entry
"""
return (entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT
def read_stage(entry: IndexEntry) -> Stage:
return Stage((entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT)


def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
Expand Down Expand Up @@ -155,7 +157,7 @@ def write_cache_time(f, t):
f.write(struct.pack(">LL", *t))


def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
def read_cache_entry(f, version: int) -> Tuple[bytes, IndexEntry]:
"""Read an entry from a cache file.
Args:
Expand Down Expand Up @@ -260,12 +262,12 @@ def read_index(f: BinaryIO):
yield read_cache_entry(f, version)


def read_index_dict(f) -> Dict[Tuple[bytes, int], IndexEntry]:
def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]:
"""Read an index file and return it as a dictionary.
Dict Key is tuple of path and stage number, as
path alone is not unique
Args:
f: File object to read fromls
f: File object to read fromls.
"""
ret = {}
for name, entry in read_index(f):
Expand All @@ -292,19 +294,19 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: O

def write_index_dict(
f: BinaryIO,
entries: Dict[Tuple[bytes, int], IndexEntry],
entries: Dict[Tuple[bytes, Stage] | bytes, IndexEntry],
version: Optional[int] = None,
) -> None:
"""Write an index file based on the contents of a dictionary.
being careful to sort by path and then by stage
being careful to sort by path and then by stage.
"""
entries_list = []
for key in sorted(entries):
if isinstance(key, tuple):
name, stage = key
else:
name = key
stage = 0
stage = Stage.NORMAL
entries_list.append((name, entries[(name, stage)]))
write_index(f, entries_list, version=version)

Expand Down Expand Up @@ -335,6 +337,8 @@ def cleanup_mode(mode: int) -> int:
class Index:
"""A Git Index file."""

_bynamestage: Dict[Tuple[bytes, Stage], IndexEntry]

def __init__(self, filename: Union[bytes, str], read=True) -> None:
"""Create an index object associated with the given filename.
Expand Down Expand Up @@ -385,40 +389,40 @@ def __len__(self) -> int:
"""Number of entries in this index file."""
return len(self._bynamestage)

def __getitem__(self, key: Union[Tuple[bytes, int], bytes]) -> IndexEntry:
def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry:
"""Retrieve entry by relative path and stage.
Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
flags)
"""
if isinstance(key, tuple):
return self._bynamestage[key]
if (key, 0) in self._bynamestage:
return self._bynamestage[(key, 0)]
if (key, Stage.NORMAL) in self._bynamestage:
return self._bynamestage[(key, Stage.NORMAL)]
# there is a conflict return 'this' entry
return self._bynamestage[(key, 2)]
return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)]

def __iter__(self) -> Iterator[bytes]:
"""Iterate over the paths and stages in this index."""
for (name, stage) in self._bynamestage:
if stage == 1 or stage == 3:
if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
continue
yield name

def __contains__(self, key):
if isinstance(key, tuple):
return key in self._bynamestage
if (key, 0) in self._bynamestage:
if (key, Stage.NORMAL) in self._bynamestage:
return True
if (key, 2) in self._bynamestage:
if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
return True
return False
def get_sha1(self, path: bytes, stage: int = 0) -> bytes:

def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes:
"""Return the (git object) SHA1 for the object at a path."""
return self[(path, stage)].sha

def get_mode(self, path: bytes, stage: int = 0) -> int:
def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int:
"""Return the POSIX file mode for the object at a path."""
return self[(path, stage)].mode

Expand All @@ -428,15 +432,15 @@ def iterobjects(self) -> Iterable[Tuple[bytes, bytes, int]]:
entry = self[path]
yield path, entry.sha, cleanup_mode(entry.mode)

def iterconflicts(self) -> Iterable[Tuple[int, bytes, int, bytes]]:
def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]:
"""Iterate over path, sha, mode tuples for use with commit_tree."""
for (name, stage), entry in self._bynamestage.items():
if stage > 0:
if stage != Stage.NORMAL:
yield cleanup_mode(entry.mode), entry.sha, stage, name

def has_conflicts(self):
for (name, stage) in self._bynamestage.keys():
if stage > 0:
if stage != Stage.NORMAL:
return True
return False

Expand All @@ -452,66 +456,66 @@ def set_merge_conflict(self, apath, stage, mode, sha, time):
sha,
stage << FLAG_STAGESHIFT,
0)
if (apath, 0) in self._bynamestage:
del self._bynamestage[(apath, 0)]
if (apath, Stage.NORMAL) in self._bynamestage:
del self._bynamestage[(apath, Stage.NORMAL)]
self._bynamestage[(apath, stage)] = entry

def clear(self):
"""Remove all contents from this index."""
self._bynamestage = {}

def __setitem__(self, key: Union[Tuple[bytes, int], bytes], x: IndexEntry) -> None:
def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None:
assert len(x) == len(IndexEntry._fields)
if isinstance(key, tuple):
name, stage = key
else:
name = key
stage = 0 # default when stage not explicitly specified
stage = Stage.NORMAL # default when stage not explicitly specified
assert isinstance(name, bytes)
# Remove merge conflict entries if new entry is stage 0
# Remove stage 0 entry if new entry has conflicts (stage > 0)
if stage == 0:
if (name, 1) in self._bynamestage:
del self._bynamestage[(name, 1)]
if (name, 2) in self._bynamestage:
del self._bynamestage[(name, 2)]
if (name, 3) in self._bynamestage:
del self._bynamestage[(name, 3)]
if stage > 0 and (name, 0) in self._bynamestage:
del self._bynamestage[(name, 0)]
# Remove normal stage entry if new entry has conflicts (stage > 0)
if stage == Stage.NORMAL:
if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage:
del self._bynamestage[(name, Stage.NORMAL)]
self._bynamestage[(name, stage)] = IndexEntry(*x)

def __delitem__(self, key: Union[Tuple[bytes, int], bytes]) -> None:
def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None:
if isinstance(key, tuple):
del self._bynamestage[key]
return
name = key
assert isinstance(name, bytes)
if (name, 0) in self._bynamestage:
del self._bynamestage[(name, 0)]
if (name, 1) in self._bynamestage:
del self._bynamestage[(name, 1)]
if (name, 2) in self._bynamestage:
del self._bynamestage[(name, 2)]
if (name, 3) in self._bynamestage:
del self._bynamestage[(name, 3)]
if (name, Stage.NORMAL) in self._bynamestage:
del self._bynamestage[(name, Stage.NORMAL)]
if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]

def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
for (name, stage), entry in self._bynamestage.items():
yield name, entry

def items(self) -> Iterator[Tuple[Tuple[bytes, int], IndexEntry]]:
def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]:
return self._bynamestage.items()

def update(self, entries: Dict[Tuple[bytes, int], IndexEntry]):
def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]):
for key, value in entries.items():
self[key] = value

def paths(self):
for (name, stage) in self._bynamestage.keys():
if stage == 0 or stage == 2: # normal or conflict 'this'
if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS:
yield name

def changes_from_tree(
self, object_store, tree: ObjectID, want_unchanged: bool = False):
"""Find the differences between the contents of this index and a tree.
Expand Down Expand Up @@ -561,7 +565,6 @@ def commit_tree(
Returns:
SHA1 of the created tree.
"""

trees: Dict[bytes, Any] = {b"": {}}

def add_tree(path):
Expand Down Expand Up @@ -855,7 +858,7 @@ def build_index_from_tree(
st = st.__class__(st_tuple)
# default to a stage 0 index entry (normal)
# when reading from the filesystem
index[(entry.path, 0)] = index_entry_from_stat(st, entry.sha, 0)
index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0)

index.write()

Expand Down Expand Up @@ -960,7 +963,7 @@ def get_unstaged_changes(
for tree_path, entry in index.iteritems():
full_path = _tree_to_fs_path(root_path, tree_path)
stage = read_stage(entry)
if stage == 1 or stage == 3:
if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
continue
try:
st = os.lstat(full_path)
Expand Down
3 changes: 2 additions & 1 deletion dulwich/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..index import (
Index,
IndexEntry,
Stage,
_fs_to_tree_path,
_tree_to_fs_path,
build_index_from_tree,
Expand Down Expand Up @@ -168,7 +169,7 @@ def tearDown(self):

def test_simple_write(self):
entries = {
(b"barbla", 0): IndexEntry(
(b"barbla", Stage.NORMAL): IndexEntry(
(1230680220, 0),
(1230680220, 0),
2050,
Expand Down

0 comments on commit f33ecb1

Please sign in to comment.