Skip to content

Commit

Permalink
🏷️ Add type annotations (#21)
Browse files Browse the repository at this point in the history
* Simplify `CaseInsensitivePathCompleter` using `expanduser`

* Simplify `parse_extensions()` function

* Add type hints
  • Loading branch information
OLILHR authored Aug 22, 2024
1 parent d3ebb08 commit 0c2ac47
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 60 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ $ chonk

and follow the prompts by providing an input directory, an output file destination and optional filters.

Alternatively, the script can also be executed using a single command with the appropriate flags:
Alternatively, the script can be executed using a single command with the appropriate flags:

```sh
$ chonk -i <input_path> -o <output_path> -f <(optional) filters>
Expand Down
19 changes: 10 additions & 9 deletions chonk/filter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from typing import Any, Callable, List, Optional


def skip_ignore_list_comments(file_path):
ignore_list = []
def skip_ignore_list_comments(file_path: str) -> List[str]:
ignore_list: List[str] = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
Expand All @@ -11,21 +12,21 @@ def skip_ignore_list_comments(file_path):
return ignore_list


def read_chonkignore(project_root, extension_filter):
def read_chonkignore(project_root: str, extension_filter: Optional[List[str]]) -> Callable[[str], bool]:
"""
Excludes all files, extensions and directories specified in .chonkignore, located inside the root directory.
"""
chonkignore = os.path.join(project_root, ".chonkignore")
default_ignore_list = DEFAULT_IGNORE_LIST.copy()

ignore_list = []
ignore_list: List[str] = []
if os.path.exists(chonkignore):
with open(chonkignore, "r", encoding="utf-8") as f:
ignore_list = [line.strip() for line in f if line.strip() and not line.startswith("#")]

default_ignore_list.extend(ignore_list)

def exclude_files(file_path):
def exclude_files(file_path: str) -> bool:
file_path = file_path.replace(os.sep, "/")

if extension_filter:
Expand Down Expand Up @@ -58,7 +59,7 @@ def exclude_files(file_path):
return exclude_files


def filter_extensions(file_path, extensions):
def filter_extensions(file_path: str, extensions: Optional[List[str]]) -> bool:
"""
Optional filter to include only certain provided extensions in the consolidated markdown file. If no extensions are
provided, all files are considered except files, extensions and directories that are explicitly excluded in the
Expand All @@ -70,17 +71,17 @@ def filter_extensions(file_path, extensions):
return file_extension[1:] in extensions


def parse_extensions(_csx, _param, value):
def parse_extensions(_csx: Any, _param: Any, value: Optional[List[str]]) -> Optional[List[str]]:
"""
Converts a comma-separated string of file extensions into a list of individual extensions, which - in turn - is
parsed to the main function to filter files during the consolidation process.
"""
if not value:
return None
return [ext.strip() for item in value for ext in item.split(",")]
return [ext.strip() for item in value for ext in item.split(",")] if value else None


DEFAULT_IGNORE_LIST = [
DEFAULT_IGNORE_LIST: List[str] = [
".cache/",
".coverage",
"dist/",
Expand Down
75 changes: 34 additions & 41 deletions chonk/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import os
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional

import click
from prompt_toolkit import prompt
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.document import Document

from .filter import parse_extensions
from .utilities import NoMatchingExtensionError, consolidate
Expand All @@ -15,17 +17,17 @@
_logger = logging.getLogger(__name__)
_logger.setLevel(GLOBAL_LOG_LEVEL)

MAX_FILE_SIZE = 1024 * 1024 * 10 # 10 MB
MAX_FILE_SIZE: int = 1024 * 1024 * 10 # 10 MB


def get_project_root():
def get_project_root() -> str:
"""
Required for input/output path prompts to display the project root as default path.
"""

current_dir = os.path.abspath(os.getcwd())
current_dir: str = os.path.abspath(os.getcwd())

root_indicators = [
root_indicators: List[str] = [
".git",
"package.json",
"pdm.lock",
Expand All @@ -47,34 +49,28 @@ class CaseInsensitivePathCompleter(Completer):
only_directories: bool = False
expanduser: bool = True

def get_completions(self, document, complete_event):
text = document.text_before_cursor
if len(text) == 0:
def get_completions(self, document: Document, complete_event: Any) -> Iterable[Completion]:
text: str = os.path.expanduser(document.text_before_cursor)
if not text:
return

directory = os.path.dirname(text)
prefix = os.path.basename(text)

if os.path.isabs(text):
full_directory = os.path.abspath(directory)
else:
full_directory = os.path.abspath(os.path.join(os.getcwd(), directory))
directory: str = os.path.dirname(text)
prefix: str = os.path.basename(text)
full_directory: str = os.path.abspath(directory)

try:
suggestions = os.listdir(full_directory)
suggestions: List[str] = os.listdir(full_directory)
except OSError:
return

for suggestion in suggestions:
if suggestion.lower().startswith(prefix.lower()):
if self.only_directories and not os.path.isdir(os.path.join(full_directory, suggestion)):
continue
completion = suggestion[len(prefix) :]
display = suggestion
yield Completion(completion, start_position=0, display=display)
yield Completion(suggestion[len(prefix) :], start_position=0, display=suggestion)


def path_prompt(message, default, exists=False):
def path_prompt(message: str, default: str, exists: bool = False) -> str:
"""
Enables basic shell features, like relative path suggestion and autocompletion, for CLI prompts.
"""
Expand All @@ -84,8 +80,8 @@ def path_prompt(message, default, exists=False):
default += os.path.sep

while True:
path = prompt(f"{message} ", default=default, completer=path_completer)
full_path = os.path.abspath(os.path.expanduser(path))
path: str = prompt(f"{message} ", default=default, completer=path_completer)
full_path: str = os.path.abspath(os.path.expanduser(path))
if not exists or os.path.exists(full_path):
return full_path
print(f"🔴 {full_path} DOES NOT EXIST.")
Expand All @@ -100,26 +96,23 @@ def path_prompt(message, default, exists=False):
"extension_filter",
callback=parse_extensions,
multiple=True,
help="enables optional filtering by extensions, for instance: -f py,json", # markdown contains only .py/.json files
help="enables optional filtering by extensions, for instance: -f py,json",
)
# pylint: disable=too-many-locals
def generate_markdown(input_path, output_path, extension_filter):
no_flags_provided = input_path is None and output_path is None and not extension_filter
project_root = get_project_root()

if input_path is None:
input_path = path_prompt("📁 INPUT PATH OF YOUR TARGET DIRECTORY -", default=project_root, exists=True)
else:
input_path = os.path.abspath(input_path)

if output_path is None:
output_path = path_prompt("📁 OUTPUT PATH FOR THE MARKDOWN FILE -", default=project_root)
else:
output_path = os.path.abspath(output_path)
def generate_markdown(
input_path: Optional[str], output_path: Optional[str], extension_filter: Optional[List[str]]
) -> None:
no_flags_provided: bool = input_path is None and output_path is None and not extension_filter
project_root: str = get_project_root()

input_path = input_path or path_prompt(
"📁 INPUT PATH OF YOUR TARGET DIRECTORY -", default=project_root, exists=True
)
output_path = output_path or path_prompt("📁 OUTPUT PATH FOR THE MARKDOWN FILE -", default=project_root)

extensions = extension_filter
extensions: Optional[List[str]] = extension_filter
if no_flags_provided:
extensions_input = click.prompt(
extensions_input: str = click.prompt(
"🔎 OPTIONAL FILTER FOR SPECIFIC EXTENSIONS (COMMA-SEPARATED)",
default="",
show_default=False,
Expand All @@ -142,21 +135,21 @@ def generate_markdown(input_path, output_path, extension_filter):
_logger.error("\n" + "🔴 GENERATED CONTENT EXCEEDS 10 MB. CONSIDER ADDING LARGER FILES TO YOUR .chonkignore.")
return

chonk = os.path.join(output_path, "chonk.md")
chonk: str = os.path.join(output_path, "chonk.md")

os.makedirs(output_path, exist_ok=True)
with open(chonk, "w", encoding="utf-8") as f:
f.write(markdown_content)

chonk_size = os.path.getsize(chonk)
chonk_size: int = os.path.getsize(chonk)
if chonk_size < 1024:
file_size = f"{chonk_size} bytes"
file_size: str = f"{chonk_size} bytes"
elif chonk_size < 1024 * 1024:
file_size = f"{chonk_size / 1024:.2f} KB"
else:
file_size = f"{chonk_size / (1024 * 1024):.2f} MB"

file_type_distribution = " ".join(
file_type_distribution: str = " ".join(
f".{file_type} ({percentage:.0f}%)" for file_type, percentage in type_distribution
)

Expand Down
17 changes: 10 additions & 7 deletions chonk/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from collections import Counter
from dataclasses import dataclass
from typing import List, Optional, Tuple

import tiktoken
from tqdm import tqdm
Expand All @@ -12,13 +13,13 @@
_logger = logging.getLogger(__name__)


def remove_trailing_whitespace(content):
def remove_trailing_whitespace(content: str) -> str:
content = re.sub(r"\n{3,}", "\n\n", content)
content = re.sub(r" +$", "", content, flags=re.MULTILINE)
return content


def escape_markdown_characters(file_name):
def escape_markdown_characters(file_name: str) -> str:
"""
Escapes special characters in file names such as "__init__.py"
in order to display paths correctly inside the output markdown file.
Expand All @@ -27,7 +28,7 @@ def escape_markdown_characters(file_name):
return re.sub(special_chars, r"\\\1", file_name)


def count_lines_of_code(content):
def count_lines_of_code(content: str) -> int:
"""
Counts the lines of code within each code blocks in the output markdown file.
"""
Expand All @@ -36,7 +37,7 @@ def count_lines_of_code(content):
return lines_of_code


def get_file_type_distribution(markdown_content):
def get_file_type_distribution(markdown_content: str) -> List[Tuple[str, float]]:
"""
Returns a distribution of the four most common file types in the output markdown file.
"""
Expand All @@ -56,7 +57,7 @@ def get_file_type_distribution(markdown_content):
return type_distribution


def count_tokens(text):
def count_tokens(text: str) -> int:
"""
Encoding for GPT-3.5/GPT-4.0.
"""
Expand All @@ -74,7 +75,9 @@ class NoMatchingExtensionError(Exception):


# pylint: disable=too-many-locals
def consolidate(path, extensions=None):
def consolidate(
path: str, extensions: Optional[List[str]] = None
) -> Tuple[str, int, int, int, List[Tuple[str, float]]]:
"""
Gathers and formats the content and metadata of all files inside a provided input directory,
while taking into account optional extension filters as well as .chonkignore specific exceptions.
Expand All @@ -86,7 +89,7 @@ def consolidate(path, extensions=None):
token_count = 0
lines_of_code_count = 0

matching_filter_extensions = []
matching_filter_extensions: List[str] = []
for root, dirs, files in os.walk(path):
dirs[:] = [d for d in dirs if not exclude_files(os.path.relpath(str(os.path.join(root, d)), path))]
for file in files:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_consolidate_only_specified_filters(
)


def test_extension_filter_bypasses_chonkignore(
def test_filter_bypasses_chonkignore(
project_root, mock_project, mock_operations, mock_chonkignore
): # pylint: disable=unused-argument
filtered_chonk, *_ = consolidate(project_root, extensions=["svg"])
Expand Down Expand Up @@ -60,6 +60,6 @@ def test_filter_extensions_edge_cases():


def test_parse_extensions_edge_cases():
assert parse_extensions(None, None, "") is None
assert parse_extensions(None, None, []) is None
assert parse_extensions(None, None, ["py, js, css"]) == ["py", "js", "css"]
assert parse_extensions(None, None, ["py", "js", "css"]) == ["py", "js", "css"]

0 comments on commit 0c2ac47

Please sign in to comment.