From ac25db99be24c40b713597b5e88f4f164013dc63 Mon Sep 17 00:00:00 2001 From: Kosma Dunikowski <kosmadunikowski@gmail.com> Date: Mon, 11 Nov 2024 14:04:49 +0100 Subject: [PATCH] Add Mypy type checking (#13) * Add Mypy type checking Fixes #12 Add Mypy type checking to the project. * Add `mypy` as a development dependency in `pyproject.toml`. * Add `[tool.mypy]` section in `pyproject.toml` to configure mypy to check all code files. * Add type hints to all function and method signatures in `intentguard/cache.py`, `intentguard/intentguard_options.py`, and `intentguard/intentguard.py`. * Update `.github/workflows/main.yml` to include a step to run `mypy` checks after the `ruff` checks. * Update `generate_cache_key` in `intentguard/cache.py` to accept `IntentGuardOptions` object. * Update `_format_code_objects` in `intentguard/intentguard.py` to handle specific types expected by `inspect.getsource`. --- .cdigestignore | 9 ++++ .github/workflows/main.yml | 3 ++ .gitignore | 5 ++ intentguard/cache.py | 16 +++--- intentguard/intentguard.py | 43 +++++++++------- intentguard/intentguard_options.py | 6 +-- poetry.lock | 79 +++++++++++++++++++++++++++++- pyproject.toml | 5 ++ 8 files changed, 136 insertions(+), 30 deletions(-) create mode 100644 .cdigestignore diff --git a/.cdigestignore b/.cdigestignore new file mode 100644 index 0000000..e71454d --- /dev/null +++ b/.cdigestignore @@ -0,0 +1,9 @@ +.ruff_cache +dist +.venv +design +assets +poetry.lock +.cdigestignore +LICENSE +._codebase_digest.txt \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 624b54c..43cc607 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -40,6 +40,9 @@ jobs: - name: Run ruff format check run: | poetry run ruff format --check . + - name: Run mypy check + run: | + poetry run mypy . - name: Run tests env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.gitignore b/.gitignore index f7bdbab..f8fe946 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,8 @@ /.venv __pycache__ .intentguard +.mypy_cache +/design +._codebase_digest.txt +/.idea +*.iml \ No newline at end of file diff --git a/intentguard/cache.py b/intentguard/cache.py index 3b32480..0a1850c 100644 --- a/intentguard/cache.py +++ b/intentguard/cache.py @@ -6,12 +6,14 @@ import os import hashlib import json +from typing import Optional, Dict, Any +from intentguard.intentguard_options import IntentGuardOptions # Directory where cache files will be stored CACHE_DIR = ".intentguard" -def ensure_cache_dir_exists(): +def ensure_cache_dir_exists() -> None: """ Creates the cache directory if it doesn't exist. This is called before any cache operations to ensure the cache directory is available. @@ -20,7 +22,9 @@ def ensure_cache_dir_exists(): os.makedirs(CACHE_DIR) -def generate_cache_key(expectation: str, objects_text: str, options) -> str: +def generate_cache_key( + expectation: str, objects_text: str, options: IntentGuardOptions +) -> str: """ Generates a unique cache key based on the input parameters and model configuration. @@ -42,7 +46,7 @@ def generate_cache_key(expectation: str, objects_text: str, options) -> str: return hashlib.sha256(key_string.encode()).hexdigest() -def read_cache(cache_key: str): +def read_cache(cache_key: str) -> Optional[Dict[str, Any]]: """ Retrieves cached results for a given cache key. @@ -60,7 +64,7 @@ def read_cache(cache_key: str): return None -def write_cache(cache_key: str, result): +def write_cache(cache_key: str, result: Dict[str, Any]) -> None: """ Stores a result in the cache using the provided cache key. @@ -93,7 +97,7 @@ def __init__(self, result: bool, explanation: str = ""): self.result = result self.explanation = explanation - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Converts the CachedResult instance to a dictionary for JSON serialization. @@ -103,7 +107,7 @@ def to_dict(self): return {"result": self.result, "explanation": self.explanation} @classmethod - def from_dict(cls, data: dict): + def from_dict(cls, data: Dict[str, Any]) -> "CachedResult": """ Creates a CachedResult instance from a dictionary. diff --git a/intentguard/intentguard.py b/intentguard/intentguard.py index a99f780..33219bd 100644 --- a/intentguard/intentguard.py +++ b/intentguard/intentguard.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Optional +from typing import Dict, List, Any, Optional, cast import inspect import json from collections import Counter @@ -6,6 +6,7 @@ from litellm import completion from litellm.main import ModelResponse +from litellm.types.utils import Choices from intentguard.intentguard_options import IntentGuardOptions from intentguard.prompts import system_prompt, reponse_schema, explanation_prompt @@ -31,14 +32,14 @@ class IntentGuard: customizable options for the assertion process. """ - def __init__(self, options: Optional[IntentGuardOptions] = None): + def __init__(self, options: Optional[IntentGuardOptions] = None) -> None: """ Initialize the IntentGuard instance. Args: options: Configuration options for assertions. Uses default options if None. """ - self.options = options or IntentGuardOptions() + self.options: IntentGuardOptions = options or IntentGuardOptions() def assert_code( self, @@ -63,12 +64,12 @@ def assert_code( options = options or self.options # Prepare evaluation context - objects_text = self._format_code_objects(params) - prompt = self._create_evaluation_prompt(objects_text, expectation) - cache_key = generate_cache_key(expectation, objects_text, options) + objects_text: str = self._format_code_objects(params) + prompt: str = self._create_evaluation_prompt(objects_text, expectation) + cache_key: str = generate_cache_key(expectation, objects_text, options) # Check cache or perform evaluation - final_result = self._get_cached_or_evaluate( + final_result: CachedResult = self._get_cached_or_evaluate( cache_key=cache_key, prompt=prompt, options=options ) @@ -79,7 +80,7 @@ def assert_code( f"Explanation: {final_result.explanation}" ) - def _format_code_objects(self, params: Dict[str, object]) -> str: + def _format_code_objects(self, params: Dict[str, Any]) -> str: """ Format code objects for LLM evaluation. @@ -91,9 +92,9 @@ def _format_code_objects(self, params: Dict[str, object]) -> str: Returns: Formatted string containing source code of all objects """ - formatted_objects = [] + formatted_objects: List[str] = [] for name, obj in params.items(): - source = inspect.getsource(obj) + source: str = inspect.getsource(obj) formatted_objects.append( f"""{{{name}}}: ```py @@ -138,12 +139,14 @@ def _get_cached_or_evaluate( return CachedResult.from_dict(cached_result) # Perform multiple evaluations for consensus - results = [ + results: List[bool] = [ self._perform_single_evaluation(prompt, options) for _ in range(options.num_evaluations) ] - final_result = CachedResult(result=self._determine_consensus(results)) + final_result: CachedResult = CachedResult( + result=self._determine_consensus(results) + ) # Generate explanation for failed assertions if not final_result.result: @@ -167,7 +170,7 @@ def _perform_single_evaluation( Returns: Boolean result of evaluation """ - request = LLMRequest( + request: LLMRequest = LLMRequest( messages=[ {"content": system_prompt, "role": "system"}, {"content": prompt, "role": "user"}, @@ -179,8 +182,10 @@ def _perform_single_evaluation( }, ) - response = self._send_llm_request(request) - return json.loads(response.choices[0].message.content)["result"] + response: ModelResponse = self._send_llm_request(request) + return json.loads( + cast(str, cast(Choices, response.choices[0]).message.content) + )["result"] def _generate_failure_explanation( self, prompt: str, options: IntentGuardOptions @@ -195,7 +200,7 @@ def _generate_failure_explanation( Returns: Detailed explanation of why assertion failed """ - request = LLMRequest( + request: LLMRequest = LLMRequest( messages=[ {"content": explanation_prompt, "role": "system"}, {"content": prompt, "role": "user"}, @@ -203,8 +208,8 @@ def _generate_failure_explanation( model=options.model, ) - response = self._send_llm_request(request) - return response.choices[0].message.content + response: ModelResponse = self._send_llm_request(request) + return cast(str, cast(Choices, response.choices[0]).message.content) def _send_llm_request(self, request: LLMRequest) -> ModelResponse: """ @@ -233,5 +238,5 @@ def _determine_consensus(self, results: List[bool]) -> bool: Returns: Consensus result """ - vote_count = Counter(results) + vote_count: Counter = Counter(results) return vote_count[True] > vote_count[False] diff --git a/intentguard/intentguard_options.py b/intentguard/intentguard_options.py index 761bc0e..bbf648b 100644 --- a/intentguard/intentguard_options.py +++ b/intentguard/intentguard_options.py @@ -10,7 +10,7 @@ def __init__( self, model: str = "gpt-4o-mini-2024-07-18", num_evaluations: int = 1, - ): + ) -> None: """ Initialize IntentGuardOptions with the specified parameters. @@ -21,5 +21,5 @@ def __init__( for each assertion. The final result is determined by majority vote. Defaults to 1. """ - self.model = model - self.num_evaluations = num_evaluations + self.model: str = model + self.num_evaluations: int = num_evaluations diff --git a/poetry.lock b/poetry.lock index a2fb99a..4045cad 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -980,6 +980,70 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "openai" version = "1.51.2" @@ -1775,6 +1839,17 @@ dev = ["tokenizers[testing]"] docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] +[[package]] +name = "tomli" +version = "2.0.2" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, + {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, +] + [[package]] name = "tqdm" version = "4.66.5" @@ -1957,4 +2032,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a8444b990d1c8d17cc073fb99977cc8b5511133f964a2b9e8278c8ecd419639f" +content-hash = "bec0c2170be5d8562dc02173a385adc2f60c95524a4f7310070972ec3666c05d" diff --git a/pyproject.toml b/pyproject.toml index 891289d..890a1d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,11 @@ litellm = "^1.49.2" [tool.poetry.group.dev.dependencies] ruff = "^0.6.9" +mypy = "^1.13.0" + +[tool.mypy] +files = "intentguard/**/*.py" +ignore_missing_imports = true [build-system] requires = ["poetry-core"]