From cc5d4504d06ac050c2c2c7b51edf9bd60d5337b6 Mon Sep 17 00:00:00 2001 From: Michael Aquilina Date: Fri, 16 Aug 2024 14:41:47 +0100 Subject: [PATCH] Add support for sections in requirements-fixer --- pre_commit_hooks/requirements_txt_fixer.py | 73 +++++++++++++--------- tests/requirements_txt_fixer_test.py | 19 ++++++ 2 files changed, 61 insertions(+), 31 deletions(-) diff --git a/pre_commit_hooks/requirements_txt_fixer.py b/pre_commit_hooks/requirements_txt_fixer.py index 07b57e18..aa5d4687 100644 --- a/pre_commit_hooks/requirements_txt_fixer.py +++ b/pre_commit_hooks/requirements_txt_fixer.py @@ -66,10 +66,7 @@ def append_value(self, value: bytes) -> None: def fix_requirements(f: IO[bytes]) -> int: - requirements: list[Requirement] = [] before = list(f) - after: list[bytes] = [] - before_string = b''.join(before) # adds new line in case one is missing @@ -81,17 +78,21 @@ def fix_requirements(f: IO[bytes]) -> int: if before_string.strip() == b'': return PASS + # split requirements into sections based on newlines + sections: list[list[Requirement]] = [[]] + for line in before: + requirements = sections[-1] + # If the most recent requirement object has a value, then it's # time to start building the next requirement object. - if not len(requirements) or requirements[-1].is_complete(): requirements.append(Requirement()) requirement = requirements[-1] # If we see a newline before any requirements, then this is a - # top of file comment. + # top of section comment. if len(requirements) == 1 and line.strip() == b'': if ( len(requirement.comments) and @@ -100,38 +101,48 @@ def fix_requirements(f: IO[bytes]) -> int: requirement.value = b'\n' else: requirement.comments.append(line) + elif line.lstrip() == b'': + sections.append([]) elif line.lstrip().startswith(b'#') or line.strip() == b'': requirement.comments.append(line) else: requirement.append_value(line) - # if a file ends in a comment, preserve it at the end - if requirements[-1].value is None: - rest = requirements.pop().comments - else: - rest = [] - - # find and remove pkg-resources==0.0.0 - # which is automatically added by broken pip package under Debian - requirements = [ - req for req in requirements - if req.value not in [ - b'pkg-resources==0.0.0\n', - b'pkg_resources==0.0.0\n', + after_string = b'' + + for requirements in sections: + after: list[bytes] = [] + + if len(after_string) > 0: + after_string += b'\n' + + # if a section ends in a comment, preserve it at the end + if requirements[-1].value is None: + rest = requirements.pop().comments + else: + rest = [] + + # find and remove pkg-resources==0.0.0 + # which is automatically added by broken pip package under Debian + requirements = [ + req for req in requirements + if req.value not in [ + b'pkg-resources==0.0.0\n', + b'pkg_resources==0.0.0\n', + ] ] - ] - - # sort the requirements and remove duplicates - prev = None - for requirement in sorted(requirements): - after.extend(requirement.comments) - assert requirement.value, requirement.value - if prev is None or requirement.value != prev.value: - after.append(requirement.value) - prev = requirement - after.extend(rest) - - after_string = b''.join(after) + + # sort the requirements and remove duplicates + prev = None + for requirement in sorted(requirements): + after.extend(requirement.comments) + assert requirement.value, requirement.value + if prev is None or requirement.value != prev.value: + after.append(requirement.value) + prev = requirement + after.extend(rest) + + after_string += b''.join(after) if before_string == after_string: return PASS diff --git a/tests/requirements_txt_fixer_test.py b/tests/requirements_txt_fixer_test.py index c0d2c65d..318f6241 100644 --- a/tests/requirements_txt_fixer_test.py +++ b/tests/requirements_txt_fixer_test.py @@ -107,6 +107,25 @@ PASS, b'a=2.0.0 \\\n --hash=sha256:abcd\nb==1.0.0\n', ), + ( + ( + b'-c external.txt\n' + b'\n' + b'Peach\n' + b'apple\n' + b'banana\n' + b'pear\n' + ), + FAIL, + ( + b'-c external.txt\n' + b'\n' + b'apple\n' + b'banana\n' + b'Peach\n' + b'pear\n' + ), + ), ), ) def test_integration(input_s, expected_retval, output, tmpdir):