Skip to content

Commit

Permalink
Add Azure Llama client support (#872)
Browse files Browse the repository at this point in the history
* breaking api: rename to openai_llm_client

* setup azure llama client

* export func

* fix pyproject.toml
  • Loading branch information
clavedeluna authored Oct 10, 2024
1 parent c0c3563 commit 7a41a84
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ get-hashes = 'codemodder.scripts.get_hashes:main'

[project.optional-dependencies]
test = [
"azure-ai-inference>=1.0.0b1,<2.0",
"coverage>=7.6,<7.7",
"coverage-threshold~=0.4",
"defusedxml==0.7.1",
Expand Down Expand Up @@ -86,6 +87,10 @@ complexity = [
openai = [
"openai>=1.50,<1.52",
]
azure = [
"azure-ai-inference>=1.0.0b1,<2.0",
]

all = [
"codemodder[test]",
"codemodder[complexity]",
Expand Down
5 changes: 3 additions & 2 deletions src/codemodder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
build_failed_dependency_notification,
)
from codemodder.file_context import FileContext
from codemodder.llm import setup_llm_client
from codemodder.llm import setup_azure_llama_llm_client, setup_openai_llm_client
from codemodder.logging import log_list, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down Expand Up @@ -82,7 +82,8 @@ def __init__(
self.max_workers = max_workers
self.tool_result_files_map = tool_result_files_map or {}
self.semgrep_prefilter_results = None
self.llm_client = setup_llm_client()
self.openai_llm_client = setup_openai_llm_client()
self.azure_llama_llm_client = setup_azure_llama_llm_client()

def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
Expand Down
36 changes: 34 additions & 2 deletions src/codemodder/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,24 @@
OpenAI = None
AzureOpenAI = None

try:
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential
except ImportError:
ChatCompletionsClient = None
AzureKeyCredential = None

if TYPE_CHECKING:
from openai import OpenAI
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

from codemodder.logging import logger

__all__ = [
"MODELS",
"setup_llm_client",
"setup_openai_llm_client",
"setup_azure_llama_llm_client",
"MisconfiguredAIClient",
]

Expand Down Expand Up @@ -46,7 +55,8 @@ def __getattr__(self, name):
MODELS = ModelRegistry(models)


def setup_llm_client() -> OpenAI | None:
def setup_openai_llm_client() -> OpenAI | None:
"""Configure either the Azure OpenAI LLM client or the OpenAI client, in that order."""
if not AzureOpenAI:
logger.info("Azure OpenAI API client not available")
return None
Expand Down Expand Up @@ -81,5 +91,27 @@ def setup_llm_client() -> OpenAI | None:
return OpenAI(api_key=api_key)


def setup_azure_llama_llm_client() -> ChatCompletionsClient | None:
"""Configure the Azure Llama LLM client."""
if not ChatCompletionsClient:
logger.info("Azure API client not available")
return None

azure_llama_key = os.getenv("CODEMODDER_AZURE_LLAMA_API_KEY")
azure_llama_endpoint = os.getenv("CODEMODDER_AZURE_LLAMA_ENDPOINT")
if bool(azure_llama_key) ^ bool(azure_llama_endpoint):
raise MisconfiguredAIClient(
"Azure Llama API key and endpoint must both be set or unset"
)

if azure_llama_key and azure_llama_endpoint:
logger.info("Using Azure Llama API client")
return ChatCompletionsClient(
credential=AzureKeyCredential(azure_llama_key),
endpoint=azure_llama_endpoint,
)
return None


class MisconfiguredAIClient(ValueError):
pass
78 changes: 71 additions & 7 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
from azure.ai.inference import ChatCompletionsClient
from openai import AzureOpenAI, OpenAI

from codemodder.context import CodemodExecutionContext as Context
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_failed_dependency_description(self, mocker):
in description
)

def test_setup_llm_client_no_env_vars(self, mocker):
def test_setup_llm_clients_no_env_vars(self, mocker):
mocker.patch.dict(os.environ, clear=True)
context = Context(
mocker.Mock(),
Expand All @@ -102,7 +103,8 @@ def test_setup_llm_client_no_env_vars(self, mocker):
[],
[],
)
assert context.llm_client is None
assert context.openai_llm_client is None
assert context.azure_llama_llm_client is None

def test_setup_openai_llm_client(self, mocker):
mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
Expand All @@ -116,7 +118,29 @@ def test_setup_openai_llm_client(self, mocker):
[],
[],
)
assert isinstance(context.llm_client, OpenAI)
assert isinstance(context.openai_llm_client, OpenAI)

def test_setup_both_llm_clients(self, mocker):
mocker.patch.dict(
os.environ,
{
"CODEMODDER_OPENAI_API_KEY": "test",
"CODEMODDER_AZURE_LLAMA_API_KEY": "test",
"CODEMODDER_AZURE_LLAMA_ENDPOINT": "test",
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
None,
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert isinstance(context.openai_llm_client, OpenAI)
assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient)

def test_setup_azure_llm_client(self, mocker):
mocker.patch.dict(
Expand All @@ -136,8 +160,10 @@ def test_setup_azure_llm_client(self, mocker):
[],
[],
)
assert isinstance(context.llm_client, AzureOpenAI)
assert context.llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
assert isinstance(context.openai_llm_client, AzureOpenAI)
assert (
context.openai_llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
)

@pytest.mark.parametrize(
"env_var",
Expand All @@ -157,6 +183,44 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
[],
)

def test_setup_azure_llama_llm_client(self, mocker):
mocker.patch.dict(
os.environ,
{
"CODEMODDER_AZURE_LLAMA_API_KEY": "test",
"CODEMODDER_AZURE_LLAMA_ENDPOINT": "test",
},
)
context = Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
None,
PythonRepoManager(mocker.Mock()),
[],
[],
)
assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient)

@pytest.mark.parametrize(
"env_var",
["CODEMODDER_AZURE_LLAMA_API_KEY", "CODEMODDER_AZURE_LLAMA_ENDPOINT"],
)
def test_setup_azure_llama_llm_client_missing_one(self, mocker, env_var):
mocker.patch.dict(os.environ, {env_var: "test"})
with pytest.raises(MisconfiguredAIClient):
Context(
mocker.Mock(),
True,
False,
load_registered_codemods(),
None,
PythonRepoManager(mocker.Mock()),
[],
[],
)

def test_get_api_version_from_env(self, mocker):
version = "fake-version"
mocker.patch.dict(
Expand All @@ -177,5 +241,5 @@ def test_get_api_version_from_env(self, mocker):
[],
[],
)
assert isinstance(context.llm_client, AzureOpenAI)
assert context.llm_client._api_version == version
assert isinstance(context.openai_llm_client, AzureOpenAI)
assert context.openai_llm_client._api_version == version

0 comments on commit 7a41a84

Please sign in to comment.