diff --git a/reorder_python_imports.py b/reorder_python_imports.py index 8c5cf07..0a35a0a 100644 --- a/reorder_python_imports.py +++ b/reorder_python_imports.py @@ -153,17 +153,9 @@ def _inner() -> Generator[CodePartition, None, None]: def add_imports( - partitions: Iterable[CodePartition], + partitions: list[CodePartition], to_add: tuple[str, ...] = (), ) -> list[CodePartition]: - partitions = list(partitions) - if not _partitions_to_src(partitions).strip(): - return partitions - - # If we don't have a trailing newline, this refactor is wrong - if not partitions[-1].src.endswith('\n'): - partitions[-1] = partitions[-1]._replace(src=partitions[-1].src + '\n') - return partitions + [ CodePartition(CodeType.IMPORT, imp_statement.strip() + '\n') for imp_statement in to_add @@ -400,7 +392,11 @@ def fix_file_contents( ) -> str: # internally use `'\n` as the newline and normalize at the very end nl = _most_common_line_ending(contents) - contents = contents.replace('\r\n', '\n').replace('\r', '\n') + contents = contents.replace('\r\n', '\n').replace('\r', '\n').rstrip() + if contents: + contents += '\n' + else: + return '' partitioned = partition_source(contents) partitioned = combine_trailing_code_chunks(partitioned) diff --git a/tests/reorder_python_imports_test.py b/tests/reorder_python_imports_test.py index 5b290c0..aec6ca1 100644 --- a/tests/reorder_python_imports_test.py +++ b/tests/reorder_python_imports_test.py @@ -346,6 +346,15 @@ def test_add_import_trivial(): ) == '' +def test_add_imports_empty_file(): + assert fix_file_contents( + '\n\n', + to_add=('from __future__ import absolute_import',), + to_remove=set(), + to_replace=Replacements.make([]), + ) == '' + + def test_add_import_import_already_there(): assert fix_file_contents( 'from __future__ import absolute_import\n',