Skip to content

Commit

Permalink
Add Mypy type checking (#13)
Browse files Browse the repository at this point in the history
* Add Mypy type checking

Fixes #12

Add Mypy type checking to the project.

* Add `mypy` as a development dependency in `pyproject.toml`.
* Add `[tool.mypy]` section in `pyproject.toml` to configure mypy to check all code files.
* Add type hints to all function and method signatures in `intentguard/cache.py`, `intentguard/intentguard_options.py`, and `intentguard/intentguard.py`.
* Update `.github/workflows/main.yml` to include a step to run `mypy` checks after the `ruff` checks.
* Update `generate_cache_key` in `intentguard/cache.py` to accept `IntentGuardOptions` object.
* Update `_format_code_objects` in `intentguard/intentguard.py` to handle specific types expected by `inspect.getsource`.
  • Loading branch information
kdunee authored Nov 11, 2024
1 parent 2423a8e commit ac25db9
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 30 deletions.
9 changes: 9 additions & 0 deletions .cdigestignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.ruff_cache
dist
.venv
design
assets
poetry.lock
.cdigestignore
LICENSE
._codebase_digest.txt
3 changes: 3 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ jobs:
- name: Run ruff format check
run: |
poetry run ruff format --check .
- name: Run mypy check
run: |
poetry run mypy .
- name: Run tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
/.venv
__pycache__
.intentguard
.mypy_cache
/design
._codebase_digest.txt
/.idea
*.iml
16 changes: 10 additions & 6 deletions intentguard/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import os
import hashlib
import json
from typing import Optional, Dict, Any
from intentguard.intentguard_options import IntentGuardOptions

# Directory where cache files will be stored
CACHE_DIR = ".intentguard"


def ensure_cache_dir_exists():
def ensure_cache_dir_exists() -> None:
"""
Creates the cache directory if it doesn't exist.
This is called before any cache operations to ensure the cache directory is available.
Expand All @@ -20,7 +22,9 @@ def ensure_cache_dir_exists():
os.makedirs(CACHE_DIR)


def generate_cache_key(expectation: str, objects_text: str, options) -> str:
def generate_cache_key(
expectation: str, objects_text: str, options: IntentGuardOptions
) -> str:
"""
Generates a unique cache key based on the input parameters and model configuration.
Expand All @@ -42,7 +46,7 @@ def generate_cache_key(expectation: str, objects_text: str, options) -> str:
return hashlib.sha256(key_string.encode()).hexdigest()


def read_cache(cache_key: str):
def read_cache(cache_key: str) -> Optional[Dict[str, Any]]:
"""
Retrieves cached results for a given cache key.
Expand All @@ -60,7 +64,7 @@ def read_cache(cache_key: str):
return None


def write_cache(cache_key: str, result):
def write_cache(cache_key: str, result: Dict[str, Any]) -> None:
"""
Stores a result in the cache using the provided cache key.
Expand Down Expand Up @@ -93,7 +97,7 @@ def __init__(self, result: bool, explanation: str = ""):
self.result = result
self.explanation = explanation

def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
"""
Converts the CachedResult instance to a dictionary for JSON serialization.
Expand All @@ -103,7 +107,7 @@ def to_dict(self):
return {"result": self.result, "explanation": self.explanation}

@classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: Dict[str, Any]) -> "CachedResult":
"""
Creates a CachedResult instance from a dictionary.
Expand Down
43 changes: 24 additions & 19 deletions intentguard/intentguard.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Dict, List, Any, Optional
from typing import Dict, List, Any, Optional, cast
import inspect
import json
from collections import Counter
from dataclasses import dataclass

from litellm import completion
from litellm.main import ModelResponse
from litellm.types.utils import Choices

from intentguard.intentguard_options import IntentGuardOptions
from intentguard.prompts import system_prompt, reponse_schema, explanation_prompt
Expand All @@ -31,14 +32,14 @@ class IntentGuard:
customizable options for the assertion process.
"""

def __init__(self, options: Optional[IntentGuardOptions] = None):
def __init__(self, options: Optional[IntentGuardOptions] = None) -> None:
"""
Initialize the IntentGuard instance.
Args:
options: Configuration options for assertions. Uses default options if None.
"""
self.options = options or IntentGuardOptions()
self.options: IntentGuardOptions = options or IntentGuardOptions()

def assert_code(
self,
Expand All @@ -63,12 +64,12 @@ def assert_code(
options = options or self.options

# Prepare evaluation context
objects_text = self._format_code_objects(params)
prompt = self._create_evaluation_prompt(objects_text, expectation)
cache_key = generate_cache_key(expectation, objects_text, options)
objects_text: str = self._format_code_objects(params)
prompt: str = self._create_evaluation_prompt(objects_text, expectation)
cache_key: str = generate_cache_key(expectation, objects_text, options)

# Check cache or perform evaluation
final_result = self._get_cached_or_evaluate(
final_result: CachedResult = self._get_cached_or_evaluate(
cache_key=cache_key, prompt=prompt, options=options
)

Expand All @@ -79,7 +80,7 @@ def assert_code(
f"Explanation: {final_result.explanation}"
)

def _format_code_objects(self, params: Dict[str, object]) -> str:
def _format_code_objects(self, params: Dict[str, Any]) -> str:
"""
Format code objects for LLM evaluation.
Expand All @@ -91,9 +92,9 @@ def _format_code_objects(self, params: Dict[str, object]) -> str:
Returns:
Formatted string containing source code of all objects
"""
formatted_objects = []
formatted_objects: List[str] = []
for name, obj in params.items():
source = inspect.getsource(obj)
source: str = inspect.getsource(obj)
formatted_objects.append(
f"""{{{name}}}:
```py
Expand Down Expand Up @@ -138,12 +139,14 @@ def _get_cached_or_evaluate(
return CachedResult.from_dict(cached_result)

# Perform multiple evaluations for consensus
results = [
results: List[bool] = [
self._perform_single_evaluation(prompt, options)
for _ in range(options.num_evaluations)
]

final_result = CachedResult(result=self._determine_consensus(results))
final_result: CachedResult = CachedResult(
result=self._determine_consensus(results)
)

# Generate explanation for failed assertions
if not final_result.result:
Expand All @@ -167,7 +170,7 @@ def _perform_single_evaluation(
Returns:
Boolean result of evaluation
"""
request = LLMRequest(
request: LLMRequest = LLMRequest(
messages=[
{"content": system_prompt, "role": "system"},
{"content": prompt, "role": "user"},
Expand All @@ -179,8 +182,10 @@ def _perform_single_evaluation(
},
)

response = self._send_llm_request(request)
return json.loads(response.choices[0].message.content)["result"]
response: ModelResponse = self._send_llm_request(request)
return json.loads(
cast(str, cast(Choices, response.choices[0]).message.content)
)["result"]

def _generate_failure_explanation(
self, prompt: str, options: IntentGuardOptions
Expand All @@ -195,16 +200,16 @@ def _generate_failure_explanation(
Returns:
Detailed explanation of why assertion failed
"""
request = LLMRequest(
request: LLMRequest = LLMRequest(
messages=[
{"content": explanation_prompt, "role": "system"},
{"content": prompt, "role": "user"},
],
model=options.model,
)

response = self._send_llm_request(request)
return response.choices[0].message.content
response: ModelResponse = self._send_llm_request(request)
return cast(str, cast(Choices, response.choices[0]).message.content)

def _send_llm_request(self, request: LLMRequest) -> ModelResponse:
"""
Expand Down Expand Up @@ -233,5 +238,5 @@ def _determine_consensus(self, results: List[bool]) -> bool:
Returns:
Consensus result
"""
vote_count = Counter(results)
vote_count: Counter = Counter(results)
return vote_count[True] > vote_count[False]
6 changes: 3 additions & 3 deletions intentguard/intentguard_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
self,
model: str = "gpt-4o-mini-2024-07-18",
num_evaluations: int = 1,
):
) -> None:
"""
Initialize IntentGuardOptions with the specified parameters.
Expand All @@ -21,5 +21,5 @@ def __init__(
for each assertion. The final result is determined by majority vote.
Defaults to 1.
"""
self.model = model
self.num_evaluations = num_evaluations
self.model: str = model
self.num_evaluations: int = num_evaluations
79 changes: 77 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ litellm = "^1.49.2"

[tool.poetry.group.dev.dependencies]
ruff = "^0.6.9"
mypy = "^1.13.0"

[tool.mypy]
files = "intentguard/**/*.py"
ignore_missing_imports = true

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit ac25db9

Please sign in to comment.