diff --git a/pyproject.toml b/pyproject.toml index f1160926..e792c222 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ get-hashes = 'codemodder.scripts.get_hashes:main' [project.optional-dependencies] test = [ + "azure-ai-inference>=1.0.0b1,<2.0", "coverage>=7.6,<7.7", "coverage-threshold~=0.4", "defusedxml==0.7.1", @@ -86,6 +87,10 @@ complexity = [ openai = [ "openai>=1.50,<1.52", ] +azure = [ + "azure-ai-inference>=1.0.0b1,<2.0", +] + all = [ "codemodder[test]", "codemodder[complexity]", diff --git a/src/codemodder/context.py b/src/codemodder/context.py index 8d70947a..0c28e51a 100644 --- a/src/codemodder/context.py +++ b/src/codemodder/context.py @@ -17,7 +17,7 @@ build_failed_dependency_notification, ) from codemodder.file_context import FileContext -from codemodder.llm import setup_llm_client +from codemodder.llm import setup_azure_llama_llm_client, setup_openai_llm_client from codemodder.logging import log_list, logger from codemodder.project_analysis.file_parsers.package_store import PackageStore from codemodder.project_analysis.python_repo_manager import PythonRepoManager @@ -82,7 +82,8 @@ def __init__( self.max_workers = max_workers self.tool_result_files_map = tool_result_files_map or {} self.semgrep_prefilter_results = None - self.llm_client = setup_llm_client() + self.openai_llm_client = setup_openai_llm_client() + self.azure_llama_llm_client = setup_azure_llama_llm_client() def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]): self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets) diff --git a/src/codemodder/llm.py b/src/codemodder/llm.py index db615e92..cc566b2a 100644 --- a/src/codemodder/llm.py +++ b/src/codemodder/llm.py @@ -9,15 +9,24 @@ OpenAI = None AzureOpenAI = None +try: + from azure.ai.inference import ChatCompletionsClient + from azure.core.credentials import AzureKeyCredential +except ImportError: + ChatCompletionsClient = None + AzureKeyCredential = None if TYPE_CHECKING: from openai import OpenAI + from azure.ai.inference import ChatCompletionsClient + from azure.core.credentials import AzureKeyCredential from codemodder.logging import logger __all__ = [ "MODELS", - "setup_llm_client", + "setup_openai_llm_client", + "setup_azure_llama_llm_client", "MisconfiguredAIClient", ] @@ -46,7 +55,8 @@ def __getattr__(self, name): MODELS = ModelRegistry(models) -def setup_llm_client() -> OpenAI | None: +def setup_openai_llm_client() -> OpenAI | None: + """Configure either the Azure OpenAI LLM client or the OpenAI client, in that order.""" if not AzureOpenAI: logger.info("Azure OpenAI API client not available") return None @@ -81,5 +91,27 @@ def setup_llm_client() -> OpenAI | None: return OpenAI(api_key=api_key) +def setup_azure_llama_llm_client() -> ChatCompletionsClient | None: + """Configure the Azure Llama LLM client.""" + if not ChatCompletionsClient: + logger.info("Azure API client not available") + return None + + azure_llama_key = os.getenv("CODEMODDER_AZURE_LLAMA_API_KEY") + azure_llama_endpoint = os.getenv("CODEMODDER_AZURE_LLAMA_ENDPOINT") + if bool(azure_llama_key) ^ bool(azure_llama_endpoint): + raise MisconfiguredAIClient( + "Azure Llama API key and endpoint must both be set or unset" + ) + + if azure_llama_key and azure_llama_endpoint: + logger.info("Using Azure Llama API client") + return ChatCompletionsClient( + credential=AzureKeyCredential(azure_llama_key), + endpoint=azure_llama_endpoint, + ) + return None + + class MisconfiguredAIClient(ValueError): pass diff --git a/tests/test_context.py b/tests/test_context.py index 6a482ecb..d80b699c 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,6 +1,7 @@ import os import pytest +from azure.ai.inference import ChatCompletionsClient from openai import AzureOpenAI, OpenAI from codemodder.context import CodemodExecutionContext as Context @@ -90,7 +91,7 @@ def test_failed_dependency_description(self, mocker): in description ) - def test_setup_llm_client_no_env_vars(self, mocker): + def test_setup_llm_clients_no_env_vars(self, mocker): mocker.patch.dict(os.environ, clear=True) context = Context( mocker.Mock(), @@ -102,7 +103,8 @@ def test_setup_llm_client_no_env_vars(self, mocker): [], [], ) - assert context.llm_client is None + assert context.openai_llm_client is None + assert context.azure_llama_llm_client is None def test_setup_openai_llm_client(self, mocker): mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"}) @@ -116,7 +118,29 @@ def test_setup_openai_llm_client(self, mocker): [], [], ) - assert isinstance(context.llm_client, OpenAI) + assert isinstance(context.openai_llm_client, OpenAI) + + def test_setup_both_llm_clients(self, mocker): + mocker.patch.dict( + os.environ, + { + "CODEMODDER_OPENAI_API_KEY": "test", + "CODEMODDER_AZURE_LLAMA_API_KEY": "test", + "CODEMODDER_AZURE_LLAMA_ENDPOINT": "test", + }, + ) + context = Context( + mocker.Mock(), + True, + False, + load_registered_codemods(), + None, + PythonRepoManager(mocker.Mock()), + [], + [], + ) + assert isinstance(context.openai_llm_client, OpenAI) + assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient) def test_setup_azure_llm_client(self, mocker): mocker.patch.dict( @@ -136,8 +160,10 @@ def test_setup_azure_llm_client(self, mocker): [], [], ) - assert isinstance(context.llm_client, AzureOpenAI) - assert context.llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION + assert isinstance(context.openai_llm_client, AzureOpenAI) + assert ( + context.openai_llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION + ) @pytest.mark.parametrize( "env_var", @@ -157,6 +183,44 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var): [], ) + def test_setup_azure_llama_llm_client(self, mocker): + mocker.patch.dict( + os.environ, + { + "CODEMODDER_AZURE_LLAMA_API_KEY": "test", + "CODEMODDER_AZURE_LLAMA_ENDPOINT": "test", + }, + ) + context = Context( + mocker.Mock(), + True, + False, + load_registered_codemods(), + None, + PythonRepoManager(mocker.Mock()), + [], + [], + ) + assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient) + + @pytest.mark.parametrize( + "env_var", + ["CODEMODDER_AZURE_LLAMA_API_KEY", "CODEMODDER_AZURE_LLAMA_ENDPOINT"], + ) + def test_setup_azure_llama_llm_client_missing_one(self, mocker, env_var): + mocker.patch.dict(os.environ, {env_var: "test"}) + with pytest.raises(MisconfiguredAIClient): + Context( + mocker.Mock(), + True, + False, + load_registered_codemods(), + None, + PythonRepoManager(mocker.Mock()), + [], + [], + ) + def test_get_api_version_from_env(self, mocker): version = "fake-version" mocker.patch.dict( @@ -177,5 +241,5 @@ def test_get_api_version_from_env(self, mocker): [], [], ) - assert isinstance(context.llm_client, AzureOpenAI) - assert context.llm_client._api_version == version + assert isinstance(context.openai_llm_client, AzureOpenAI) + assert context.openai_llm_client._api_version == version