Add Azure Llama client support #872

Merged · 4 commits · Oct 10, 2024
pyproject.toml — 5 additions, 0 deletions

```diff
@@ -49,6 +49,7 @@ get-hashes = 'codemodder.scripts.get_hashes:main'

 [project.optional-dependencies]
 test = [
+    "azure-ai-inference>=1.0.0b1,<2.0",
     "coverage>=7.6,<7.7",
     "coverage-threshold~=0.4",
     "defusedxml==0.7.1",
@@ -86,6 +87,10 @@ complexity = [
 openai = [
     "openai>=1.50,<1.52",
 ]
+azure = [
+    "azure-ai-inference>=1.0.0b1,<2.0",
+]
+
 all = [
     "codemodder[test]",
     "codemodder[complexity]",
```
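Note: with this change the Azure client library ships as a new optional extra, installable with `pip install codemodder[azure]`. It is also pulled in by the `test` extra, and therefore by `codemodder[all]`, which includes `codemodder[test]`.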
src/codemodder/context.py — 3 additions, 2 deletions

```diff
@@ -17,7 +17,7 @@
     build_failed_dependency_notification,
 )
 from codemodder.file_context import FileContext
-from codemodder.llm import setup_llm_client
+from codemodder.llm import setup_azure_llama_llm_client, setup_openai_llm_client
 from codemodder.logging import log_list, logger
 from codemodder.project_analysis.file_parsers.package_store import PackageStore
 from codemodder.project_analysis.python_repo_manager import PythonRepoManager
@@ -82,7 +82,8 @@ def __init__(
         self.max_workers = max_workers
         self.tool_result_files_map = tool_result_files_map or {}
         self.semgrep_prefilter_results = None
-        self.llm_client = setup_llm_client()
+        self.openai_llm_client = setup_openai_llm_client()
+        self.azure_llama_llm_client = setup_azure_llama_llm_client()

     def add_changesets(self, codemod_name: str, change_sets: List[ChangeSet]):
         self._changesets_by_codemod.setdefault(codemod_name, []).extend(change_sets)
```
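Since the single `llm_client` attribute is now split in two, call sites have to pick a client explicitly. A minimal, hypothetical sketch of what that selection could look like downstream (the helper name and the preference order are illustrative, not part of this PR):

```python
from codemodder.context import CodemodExecutionContext


def pick_llm_client(context: CodemodExecutionContext):
    """Return whichever LLM client is configured, or None.

    Hypothetical helper: prefers the Azure Llama client when both are
    configured. The real preference policy is up to each codemod.
    """
    return context.azure_llama_llm_client or context.openai_llm_client
```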
src/codemodder/llm.py — 34 additions, 2 deletions

```diff
@@ -9,15 +9,24 @@
     OpenAI = None
     AzureOpenAI = None

+try:
+    from azure.ai.inference import ChatCompletionsClient
+    from azure.core.credentials import AzureKeyCredential
+except ImportError:
+    ChatCompletionsClient = None
+    AzureKeyCredential = None
+
 if TYPE_CHECKING:
     from openai import OpenAI
+    from azure.ai.inference import ChatCompletionsClient
+    from azure.core.credentials import AzureKeyCredential

 from codemodder.logging import logger

 __all__ = [
     "MODELS",
-    "setup_llm_client",
+    "setup_openai_llm_client",
+    "setup_azure_llama_llm_client",
     "MisconfiguredAIClient",
 ]
@@ -46,7 +55,8 @@ def __getattr__(self, name):
 MODELS = ModelRegistry(models)


-def setup_llm_client() -> OpenAI | None:
+def setup_openai_llm_client() -> OpenAI | None:
+    """Configure either the Azure OpenAI LLM client or the OpenAI client, in that order."""
     if not AzureOpenAI:
         logger.info("Azure OpenAI API client not available")
         return None
@@ -81,5 +91,27 @@
     return OpenAI(api_key=api_key)


+def setup_azure_llama_llm_client() -> ChatCompletionsClient | None:
+    """Configure the Azure Llama LLM client."""
+    if not ChatCompletionsClient:
+        logger.info("Azure API client not available")
+        return None
+
+    azure_llama_key = os.getenv("CODEMODDER_AZURE_LLAMA_API_KEY")
+    azure_llama_endpoint = os.getenv("CODEMODDER_AZURE_LLAMA_ENDPOINT")
+    if bool(azure_llama_key) ^ bool(azure_llama_endpoint):
```
An inline review comment on the `bool(azure_llama_key) ^ bool(azure_llama_endpoint)` check:

**Member:** This is my fault because I did this originally, but `!=` is definitely a lot clearer here 😅

**Contributor (author):** yeah I kinda agree, but I also got very used to reading this and it's nice :)
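For what it's worth, the two spellings behave identically once both operands are coerced to `bool`, so this is purely a readability question. A quick sanity check (illustrative, not code from the PR):

```python
# "Exactly one is set" check: XOR and != agree on boolean operands.
for key in ("secret", None):
    for endpoint in ("https://example", None):
        assert (bool(key) ^ bool(endpoint)) == (bool(key) != bool(endpoint))
```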
```diff
+        raise MisconfiguredAIClient(
+            "Azure Llama API key and endpoint must both be set or unset"
+        )
+
+    if azure_llama_key and azure_llama_endpoint:
+        logger.info("Using Azure Llama API client")
+        return ChatCompletionsClient(
+            credential=AzureKeyCredential(azure_llama_key),
+            endpoint=azure_llama_endpoint,
+        )
+    return None
+
+
 class MisconfiguredAIClient(ValueError):
     pass
```
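For context, here is a minimal sketch of how a client returned by `setup_azure_llama_llm_client` could be used, based on the `azure-ai-inference` chat-completions API. The prompt and response handling are illustrative assumptions, not code from this PR:

```python
import os

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

# Same environment variables the new helper reads.
client = ChatCompletionsClient(
    endpoint=os.environ["CODEMODDER_AZURE_LLAMA_ENDPOINT"],
    credential=AzureKeyCredential(os.environ["CODEMODDER_AZURE_LLAMA_API_KEY"]),
)

# Illustrative prompt; codemodder's actual prompts live in its codemods.
response = client.complete(
    messages=[
        SystemMessage(content="You are a helpful code-transformation assistant."),
        UserMessage(content="Summarize what this diff changes."),
    ],
)
print(response.choices[0].message.content)
```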
tests/test_context.py — 71 additions, 7 deletions

```diff
@@ -1,6 +1,7 @@
 import os

 import pytest
+from azure.ai.inference import ChatCompletionsClient
 from openai import AzureOpenAI, OpenAI

 from codemodder.context import CodemodExecutionContext as Context
@@ -90,7 +91,7 @@ def test_failed_dependency_description(self, mocker):
             in description
         )

-    def test_setup_llm_client_no_env_vars(self, mocker):
+    def test_setup_llm_clients_no_env_vars(self, mocker):
         mocker.patch.dict(os.environ, clear=True)
         context = Context(
             mocker.Mock(),
@@ -102,7 +103,8 @@ def test_setup_llm_clients_no_env_vars(self, mocker):
             [],
             [],
         )
-        assert context.llm_client is None
+        assert context.openai_llm_client is None
+        assert context.azure_llama_llm_client is None

     def test_setup_openai_llm_client(self, mocker):
         mocker.patch.dict(os.environ, {"CODEMODDER_OPENAI_API_KEY": "test"})
@@ -116,7 +118,29 @@ def test_setup_openai_llm_client(self, mocker):
             [],
             [],
         )
-        assert isinstance(context.llm_client, OpenAI)
+        assert isinstance(context.openai_llm_client, OpenAI)
+
+    def test_setup_both_llm_clients(self, mocker):
+        mocker.patch.dict(
+            os.environ,
+            {
+                "CODEMODDER_OPENAI_API_KEY": "test",
+                "CODEMODDER_AZURE_LLAMA_API_KEY": "test",
+                "CODEMODDER_AZURE_LLAMA_ENDPOINT": "test",
+            },
+        )
+        context = Context(
+            mocker.Mock(),
+            True,
+            False,
+            load_registered_codemods(),
+            None,
+            PythonRepoManager(mocker.Mock()),
+            [],
+            [],
+        )
+        assert isinstance(context.openai_llm_client, OpenAI)
+        assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient)

     def test_setup_azure_llm_client(self, mocker):
         mocker.patch.dict(
@@ -136,8 +160,10 @@
             [],
             [],
         )
-        assert isinstance(context.llm_client, AzureOpenAI)
-        assert context.llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
+        assert isinstance(context.openai_llm_client, AzureOpenAI)
+        assert (
+            context.openai_llm_client._api_version == DEFAULT_AZURE_OPENAI_API_VERSION
+        )

     @pytest.mark.parametrize(
         "env_var",
@@ -157,6 +183,44 @@ def test_setup_azure_llm_client_missing_one(self, mocker, env_var):
                 [],
             )

+    def test_setup_azure_llama_llm_client(self, mocker):
+        mocker.patch.dict(
+            os.environ,
+            {
+                "CODEMODDER_AZURE_LLAMA_API_KEY": "test",
+                "CODEMODDER_AZURE_LLAMA_ENDPOINT": "test",
+            },
+        )
+        context = Context(
+            mocker.Mock(),
+            True,
+            False,
+            load_registered_codemods(),
+            None,
+            PythonRepoManager(mocker.Mock()),
+            [],
+            [],
+        )
+        assert isinstance(context.azure_llama_llm_client, ChatCompletionsClient)
+
+    @pytest.mark.parametrize(
+        "env_var",
+        ["CODEMODDER_AZURE_LLAMA_API_KEY", "CODEMODDER_AZURE_LLAMA_ENDPOINT"],
+    )
+    def test_setup_azure_llama_llm_client_missing_one(self, mocker, env_var):
+        mocker.patch.dict(os.environ, {env_var: "test"})
+        with pytest.raises(MisconfiguredAIClient):
+            Context(
+                mocker.Mock(),
+                True,
+                False,
+                load_registered_codemods(),
+                None,
+                PythonRepoManager(mocker.Mock()),
+                [],
+                [],
+            )
+
     def test_get_api_version_from_env(self, mocker):
         version = "fake-version"
         mocker.patch.dict(
@@ -177,5 +241,5 @@ def test_get_api_version_from_env(self, mocker):
             [],
             [],
         )
-        assert isinstance(context.llm_client, AzureOpenAI)
-        assert context.llm_client._api_version == version
+        assert isinstance(context.openai_llm_client, AzureOpenAI)
+        assert context.openai_llm_client._api_version == version
```