From ac25db99be24c40b713597b5e88f4f164013dc63 Mon Sep 17 00:00:00 2001
From: Kosma Dunikowski <kosmadunikowski@gmail.com>
Date: Mon, 11 Nov 2024 14:04:49 +0100
Subject: [PATCH] Add Mypy type checking (#13)

* Add Mypy type checking

Fixes #12

Add Mypy type checking to the project.

* Add `mypy` as a development dependency in `pyproject.toml`.
* Add `[tool.mypy]` section in `pyproject.toml` to configure mypy to check all code files.
* Add type hints to all function and method signatures in `intentguard/cache.py`, `intentguard/intentguard_options.py`, and `intentguard/intentguard.py`.
* Update `.github/workflows/main.yml` to include a step to run `mypy` checks after the `ruff` checks.
* Update `generate_cache_key` in `intentguard/cache.py` to accept `IntentGuardOptions` object.
* Update `_format_code_objects` in `intentguard/intentguard.py` to handle specific types expected by `inspect.getsource`.
---
 .cdigestignore                     |  9 ++++
 .github/workflows/main.yml         |  3 ++
 .gitignore                         |  5 ++
 intentguard/cache.py               | 16 +++---
 intentguard/intentguard.py         | 43 +++++++++-------
 intentguard/intentguard_options.py |  6 +--
 poetry.lock                        | 79 +++++++++++++++++++++++++++++-
 pyproject.toml                     |  5 ++
 8 files changed, 136 insertions(+), 30 deletions(-)
 create mode 100644 .cdigestignore

diff --git a/.cdigestignore b/.cdigestignore
new file mode 100644
index 0000000..e71454d
--- /dev/null
+++ b/.cdigestignore
@@ -0,0 +1,9 @@
+.ruff_cache
+dist
+.venv
+design
+assets
+poetry.lock
+.cdigestignore
+LICENSE
+._codebase_digest.txt
\ No newline at end of file
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 624b54c..43cc607 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -40,6 +40,9 @@ jobs:
       - name: Run ruff format check
         run: |
           poetry run ruff format --check .
+      - name: Run mypy check
+        run: |
+          poetry run mypy .
       - name: Run tests
         env:
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
diff --git a/.gitignore b/.gitignore
index f7bdbab..f8fe946 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,8 @@
 /.venv
 __pycache__
 .intentguard
+.mypy_cache
+/design
+._codebase_digest.txt
+/.idea
+*.iml
\ No newline at end of file
diff --git a/intentguard/cache.py b/intentguard/cache.py
index 3b32480..0a1850c 100644
--- a/intentguard/cache.py
+++ b/intentguard/cache.py
@@ -6,12 +6,14 @@
 import os
 import hashlib
 import json
+from typing import Optional, Dict, Any
+from intentguard.intentguard_options import IntentGuardOptions
 
 # Directory where cache files will be stored
 CACHE_DIR = ".intentguard"
 
 
-def ensure_cache_dir_exists():
+def ensure_cache_dir_exists() -> None:
     """
     Creates the cache directory if it doesn't exist.
     This is called before any cache operations to ensure the cache directory is available.
@@ -20,7 +22,9 @@ def ensure_cache_dir_exists():
         os.makedirs(CACHE_DIR)
 
 
-def generate_cache_key(expectation: str, objects_text: str, options) -> str:
+def generate_cache_key(
+    expectation: str, objects_text: str, options: IntentGuardOptions
+) -> str:
     """
     Generates a unique cache key based on the input parameters and model configuration.
 
@@ -42,7 +46,7 @@ def generate_cache_key(expectation: str, objects_text: str, options) -> str:
     return hashlib.sha256(key_string.encode()).hexdigest()
 
 
-def read_cache(cache_key: str):
+def read_cache(cache_key: str) -> Optional[Dict[str, Any]]:
     """
     Retrieves cached results for a given cache key.
 
@@ -60,7 +64,7 @@ def read_cache(cache_key: str):
     return None
 
 
-def write_cache(cache_key: str, result):
+def write_cache(cache_key: str, result: Dict[str, Any]) -> None:
     """
     Stores a result in the cache using the provided cache key.
 
@@ -93,7 +97,7 @@ def __init__(self, result: bool, explanation: str = ""):
         self.result = result
         self.explanation = explanation
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         """
         Converts the CachedResult instance to a dictionary for JSON serialization.
 
@@ -103,7 +107,7 @@ def to_dict(self):
         return {"result": self.result, "explanation": self.explanation}
 
     @classmethod
-    def from_dict(cls, data: dict):
+    def from_dict(cls, data: Dict[str, Any]) -> "CachedResult":
         """
         Creates a CachedResult instance from a dictionary.
 
diff --git a/intentguard/intentguard.py b/intentguard/intentguard.py
index a99f780..33219bd 100644
--- a/intentguard/intentguard.py
+++ b/intentguard/intentguard.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any, Optional, cast
 import inspect
 import json
 from collections import Counter
@@ -6,6 +6,7 @@
 
 from litellm import completion
 from litellm.main import ModelResponse
+from litellm.types.utils import Choices
 
 from intentguard.intentguard_options import IntentGuardOptions
 from intentguard.prompts import system_prompt, reponse_schema, explanation_prompt
@@ -31,14 +32,14 @@ class IntentGuard:
     customizable options for the assertion process.
     """
 
-    def __init__(self, options: Optional[IntentGuardOptions] = None):
+    def __init__(self, options: Optional[IntentGuardOptions] = None) -> None:
         """
         Initialize the IntentGuard instance.
 
         Args:
             options: Configuration options for assertions. Uses default options if None.
         """
-        self.options = options or IntentGuardOptions()
+        self.options: IntentGuardOptions = options or IntentGuardOptions()
 
     def assert_code(
         self,
@@ -63,12 +64,12 @@ def assert_code(
         options = options or self.options
 
         # Prepare evaluation context
-        objects_text = self._format_code_objects(params)
-        prompt = self._create_evaluation_prompt(objects_text, expectation)
-        cache_key = generate_cache_key(expectation, objects_text, options)
+        objects_text: str = self._format_code_objects(params)
+        prompt: str = self._create_evaluation_prompt(objects_text, expectation)
+        cache_key: str = generate_cache_key(expectation, objects_text, options)
 
         # Check cache or perform evaluation
-        final_result = self._get_cached_or_evaluate(
+        final_result: CachedResult = self._get_cached_or_evaluate(
             cache_key=cache_key, prompt=prompt, options=options
         )
 
@@ -79,7 +80,7 @@ def assert_code(
                 f"Explanation: {final_result.explanation}"
             )
 
-    def _format_code_objects(self, params: Dict[str, object]) -> str:
+    def _format_code_objects(self, params: Dict[str, Any]) -> str:
         """
         Format code objects for LLM evaluation.
 
@@ -91,9 +92,9 @@ def _format_code_objects(self, params: Dict[str, object]) -> str:
         Returns:
             Formatted string containing source code of all objects
         """
-        formatted_objects = []
+        formatted_objects: List[str] = []
         for name, obj in params.items():
-            source = inspect.getsource(obj)
+            source: str = inspect.getsource(obj)
             formatted_objects.append(
                 f"""{{{name}}}:
 ```py
@@ -138,12 +139,14 @@ def _get_cached_or_evaluate(
             return CachedResult.from_dict(cached_result)
 
         # Perform multiple evaluations for consensus
-        results = [
+        results: List[bool] = [
             self._perform_single_evaluation(prompt, options)
             for _ in range(options.num_evaluations)
         ]
 
-        final_result = CachedResult(result=self._determine_consensus(results))
+        final_result: CachedResult = CachedResult(
+            result=self._determine_consensus(results)
+        )
 
         # Generate explanation for failed assertions
         if not final_result.result:
@@ -167,7 +170,7 @@ def _perform_single_evaluation(
         Returns:
             Boolean result of evaluation
         """
-        request = LLMRequest(
+        request: LLMRequest = LLMRequest(
             messages=[
                 {"content": system_prompt, "role": "system"},
                 {"content": prompt, "role": "user"},
@@ -179,8 +182,10 @@ def _perform_single_evaluation(
             },
         )
 
-        response = self._send_llm_request(request)
-        return json.loads(response.choices[0].message.content)["result"]
+        response: ModelResponse = self._send_llm_request(request)
+        return json.loads(
+            cast(str, cast(Choices, response.choices[0]).message.content)
+        )["result"]
 
     def _generate_failure_explanation(
         self, prompt: str, options: IntentGuardOptions
@@ -195,7 +200,7 @@ def _generate_failure_explanation(
         Returns:
             Detailed explanation of why assertion failed
         """
-        request = LLMRequest(
+        request: LLMRequest = LLMRequest(
             messages=[
                 {"content": explanation_prompt, "role": "system"},
                 {"content": prompt, "role": "user"},
@@ -203,8 +208,8 @@ def _generate_failure_explanation(
             model=options.model,
         )
 
-        response = self._send_llm_request(request)
-        return response.choices[0].message.content
+        response: ModelResponse = self._send_llm_request(request)
+        return cast(str, cast(Choices, response.choices[0]).message.content)
 
     def _send_llm_request(self, request: LLMRequest) -> ModelResponse:
         """
@@ -233,5 +238,5 @@ def _determine_consensus(self, results: List[bool]) -> bool:
         Returns:
             Consensus result
         """
-        vote_count = Counter(results)
+        vote_count: Counter = Counter(results)
         return vote_count[True] > vote_count[False]
diff --git a/intentguard/intentguard_options.py b/intentguard/intentguard_options.py
index 761bc0e..bbf648b 100644
--- a/intentguard/intentguard_options.py
+++ b/intentguard/intentguard_options.py
@@ -10,7 +10,7 @@ def __init__(
         self,
         model: str = "gpt-4o-mini-2024-07-18",
         num_evaluations: int = 1,
-    ):
+    ) -> None:
         """
         Initialize IntentGuardOptions with the specified parameters.
 
@@ -21,5 +21,5 @@ def __init__(
                 for each assertion. The final result is determined by majority vote.
                 Defaults to 1.
         """
-        self.model = model
-        self.num_evaluations = num_evaluations
+        self.model: str = model
+        self.num_evaluations: int = num_evaluations
diff --git a/poetry.lock b/poetry.lock
index a2fb99a..4045cad 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -980,6 +980,70 @@ files = [
 [package.dependencies]
 typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""}
 
+[[package]]
+name = "mypy"
+version = "1.13.0"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"},
+    {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"},
+    {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"},
+    {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"},
+    {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"},
+    {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"},
+    {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"},
+    {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"},
+    {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"},
+    {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"},
+    {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"},
+    {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"},
+    {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"},
+    {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"},
+    {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"},
+    {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"},
+    {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"},
+    {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"},
+    {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"},
+    {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"},
+    {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"},
+    {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"},
+    {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"},
+    {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"},
+    {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"},
+    {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"},
+    {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"},
+    {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"},
+    {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"},
+    {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"},
+    {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"},
+    {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=1.0.0"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = ">=4.6.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+faster-cache = ["orjson"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
+[[package]]
+name = "mypy-extensions"
+version = "1.0.0"
+description = "Type system extensions for programs checked with the mypy type checker."
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
+    {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
+]
+
 [[package]]
 name = "openai"
 version = "1.51.2"
@@ -1775,6 +1839,17 @@ dev = ["tokenizers[testing]"]
 docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
 testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
 
+[[package]]
+name = "tomli"
+version = "2.0.2"
+description = "A lil' TOML parser"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
+    {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
+]
+
 [[package]]
 name = "tqdm"
 version = "4.66.5"
@@ -1957,4 +2032,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "a8444b990d1c8d17cc073fb99977cc8b5511133f964a2b9e8278c8ecd419639f"
+content-hash = "bec0c2170be5d8562dc02173a385adc2f60c95524a4f7310070972ec3666c05d"
diff --git a/pyproject.toml b/pyproject.toml
index 891289d..890a1d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,11 @@ litellm = "^1.49.2"
 
 [tool.poetry.group.dev.dependencies]
 ruff = "^0.6.9"
+mypy = "^1.13.0"
+
+[tool.mypy]
+files = "intentguard/**/*.py"
+ignore_missing_imports = true
 
 [build-system]
 requires = ["poetry-core"]