Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include module-level python statements chunk #4

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,7 @@ MCPunk is at a minimum usable state right now.
chunk types
- When listing chunks in a file, include chunk type, chunk id, number of
characters in chunk content
- Add a `__main__` chunk for Python, plus a chunk for "any not yet accounted for
module-level statements"
- Include module-level comments when extracting python module-level statements.
- Caching of a project, so it doesn't need to re-parse all files every time you
restart MCP client
- Handle changed files sensibly, so you don't need to restart MCP client
Expand Down
3 changes: 2 additions & 1 deletion mcpunk/file_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ class ChunkCategory(enum.StrEnum):
callable = "callable"
markdown_section = "markdown section"
imports = "imports"
module_level = "module_level"


# Seems if you annotate a FastMCP tool function with an enum it totally
# crashes claude desktop. So define an equivalent Literal type here.
ChunkCategoryLiteral = Literal["callable", "markdown section", "imports"]
ChunkCategoryLiteral = Literal["callable", "markdown section", "imports", "module_level"]
assert set(get_args(ChunkCategoryLiteral)) == set(ChunkCategory.__members__.values())


Expand Down
9 changes: 8 additions & 1 deletion mcpunk/file_chunkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

from mcpunk.file_chunk import Chunk, ChunkCategory
from mcpunk.python_file_analysis import Callable, extract_imports
from mcpunk.python_file_analysis import Callable, extract_imports, extract_module_statements


class BaseChunker:
Expand Down Expand Up @@ -40,8 +40,15 @@ def can_chunk(source_code: str, file_path: Path) -> bool: # noqa: ARG004
def chunk_file(self) -> list[Chunk]:
callables = Callable.from_source_code(self.source_code)
imports = "\n".join(extract_imports(self.source_code))
module_level_statements = "\n".join(extract_module_statements(self.source_code))
chunks: list[Chunk] = [
Chunk(category=ChunkCategory.imports, name="<imports>", line=None, content=imports),
Chunk(
category=ChunkCategory.module_level,
name="<module_level_statements>",
line=None,
content=module_level_statements,
),
]
chunks.extend(
Chunk(
Expand Down
29 changes: 29 additions & 0 deletions mcpunk/python_file_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,32 @@ def extract_imports(source_code: str) -> list[str]:
imports.append(import_code)

return imports


def extract_module_statements(source_code: str) -> list[str]:
"""Extract all module-level statements from source code.

function/class definitions are inserted like `def func1...` or `class MyClass...`
to provide context around where the module-level statements are defined.

Takes source code as input and returns a list of statement strings.
"""
# TODO: include comments

atok = asttokens.ASTTokens(source_code, parse=True)
statements: list[str] = []

for node in atok.tree.body: # type: ignore[union-attr]
if isinstance(node, ast.FunctionDef):
statements.append(f"def {node.name}...")
elif isinstance(node, ast.AsyncFunctionDef):
statements.append(f"async def {node.name}...")
elif isinstance(node, ast.ClassDef):
statements.append(f"class {node.name}...")
elif isinstance(node, ast.Import | ast.ImportFrom):
pass
else:
statement_code = atok.get_text(node)
statements.append(statement_code)

return statements
50 changes: 49 additions & 1 deletion tests/test_python_file_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import deepdiff
import pytest

from mcpunk.python_file_analysis import Callable, extract_imports
from mcpunk.python_file_analysis import Callable, extract_imports, extract_module_statements

# This makes pytest print nice diffs when asserts fail
# within this function.
Expand Down Expand Up @@ -769,3 +769,51 @@ def method(self):

result = extract_imports(source)
assert result == ["import os", "import json", "from typing import List"]


def test_extract_module_statements() -> None:
    """Check which top-level statements `extract_module_statements` reports.

    The sample module mixes imports, simple assignments, a function, a class,
    a multi-line assignment with an internal comment, a `__main__` guard, a
    standalone comment and an async function. The expected result shows that:

    - imports are omitted,
    - function / class / async function definitions become `...` placeholders,
    - other statements come back as their source text, including the comment
      inside the multi-line tuple,
    - the standalone `# A comment!` line is not reported.
    """
    # The sample has to be valid Python for it to be parsed, so nested bodies
    # are indented by four spaces; the multi-line expected strings below use
    # the same indentation.
    source = """\
import os
x = 1

def func1():
    y = 2
    return y

CONSTANT = "test"
import json

class MyClass:
    z = 3

    def method(self):
        pass

final_var = True

a = (
    1,
    # Internal comment
    2,
)

if __name__ == "__main__":
    print("hey!")

# A comment!

async def func2():
    pass
"""

    result = extract_module_statements(source)
    assert result == [
        "x = 1",
        "def func1...",
        'CONSTANT = "test"',
        "class MyClass...",
        "final_var = True",
        "a = (\n    1,\n    # Internal comment\n    2,\n)",
        'if __name__ == "__main__":\n    print("hey!")',
        "async def func2...",
    ]
Loading