Skip to content

Commit

Permalink
feat: add default_headers for Azure embedders (#8699)
Browse files Browse the repository at this point in the history
* Add default_headers param to azure embedders
  • Loading branch information
Amnah199 authored Jan 12, 2025
1 parent 4f73b19 commit db76ae2
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 0 deletions.
6 changes: 6 additions & 0 deletions haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
embedding_separator: str = "\n",
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
*,
default_headers: Optional[Dict[str, str]] = None,
):
"""
Creates an AzureOpenAIDocumentEmbedder component.
Expand Down Expand Up @@ -95,6 +97,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
`OPENAI_TIMEOUT` environment variable, or 30 seconds.
:param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
:param default_headers: Default headers to send to the AzureOpenAI client.
"""
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
Expand All @@ -119,6 +122,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
self.embedding_separator = embedding_separator
self.timeout = timeout or float(os.environ.get("OPENAI_TIMEOUT", 30.0))
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.default_headers = default_headers or {}

self._client = AzureOpenAI(
api_version=api_version,
Expand All @@ -129,6 +133,7 @@ def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-p
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -161,6 +166,7 @@ def to_dict(self) -> Dict[str, Any]:
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
max_retries: Optional[int] = None,
prefix: str = "",
suffix: str = "",
*,
default_headers: Optional[Dict[str, str]] = None,
):
"""
Creates an AzureOpenAITextEmbedder component.
Expand Down Expand Up @@ -82,6 +84,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param default_headers: Default headers to send to the AzureOpenAI client.
"""
# Why is this here?
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
Expand All @@ -105,6 +108,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.prefix = prefix
self.suffix = suffix
self.default_headers = default_headers or {}

self._client = AzureOpenAI(
api_version=api_version,
Expand All @@ -115,6 +119,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
organization=organization,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -143,6 +148,7 @@ def to_dict(self) -> Dict[str, Any]:
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
max_retries=self.max_retries,
default_headers=self.default_headers,
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
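Added a `default_headers` parameter to `AzureOpenAIDocumentEmbedder` and `AzureOpenAITextEmbedder`. The headers are passed to the underlying `AzureOpenAI` client and are included when the component is serialized with `to_dict`.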
35 changes: 35 additions & 0 deletions test/components/embedders/test_azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_init_default(self, monkeypatch):
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.default_headers == {}

def test_to_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
Expand All @@ -45,9 +46,43 @@ def test_to_dict(self, monkeypatch):
"embedding_separator": "\n",
"max_retries": 5,
"timeout": 30.0,
"default_headers": {},
},
}

def test_from_dict(self, monkeypatch):
    """Rebuild a document embedder from its serialized dict and check the restored attributes."""
    # The serialized `api_key` entry below references this environment variable.
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")

    init_parameters = {
        "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
        "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
        "api_version": "2023-05-15",
        "azure_deployment": "text-embedding-ada-002",
        "dimensions": None,
        "azure_endpoint": "https://example-resource.azure.openai.com/",
        "organization": None,
        "prefix": "",
        "suffix": "",
        "batch_size": 32,
        "progress_bar": True,
        "meta_fields_to_embed": [],
        "embedding_separator": "\n",
        "max_retries": 5,
        "timeout": 30.0,
        "default_headers": {},
    }
    serialized = {
        "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
        "init_parameters": init_parameters,
    }

    embedder = AzureOpenAIDocumentEmbedder.from_dict(serialized)

    # Attribute name -> value expected after deserialization.
    expected_attributes = {
        "azure_deployment": "text-embedding-ada-002",
        "azure_endpoint": "https://example-resource.azure.openai.com/",
        "api_version": "2023-05-15",
        "max_retries": 5,
        "timeout": 30.0,
        "prefix": "",
        "suffix": "",
        "default_headers": {},
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(embedder, attribute) == expected


@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
Expand Down
31 changes: 31 additions & 0 deletions test/components/embedders/test_azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_init_default(self, monkeypatch):
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.default_headers == {}

def test_to_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
Expand All @@ -38,9 +39,39 @@ def test_to_dict(self, monkeypatch):
"timeout": 30.0,
"prefix": "",
"suffix": "",
"default_headers": {},
},
}

def test_from_dict(self, monkeypatch):
    """Rebuild a text embedder from its serialized dict and check the restored attributes."""
    # The serialized `api_key` entry below references this environment variable.
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")

    init_parameters = {
        "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
        "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
        "azure_deployment": "text-embedding-ada-002",
        "dimensions": None,
        "organization": None,
        "azure_endpoint": "https://example-resource.azure.openai.com/",
        "api_version": "2023-05-15",
        "max_retries": 5,
        "timeout": 30.0,
        "prefix": "",
        "suffix": "",
        "default_headers": {},
    }
    serialized = {
        "type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
        "init_parameters": init_parameters,
    }

    embedder = AzureOpenAITextEmbedder.from_dict(serialized)

    # Attribute name -> value expected after deserialization.
    expected_attributes = {
        "azure_deployment": "text-embedding-ada-002",
        "azure_endpoint": "https://example-resource.azure.openai.com/",
        "api_version": "2023-05-15",
        "max_retries": 5,
        "timeout": 30.0,
        "prefix": "",
        "suffix": "",
        "default_headers": {},
    }
    for attribute, expected in expected_attributes.items():
        assert getattr(embedder, attribute) == expected


@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
Expand Down

0 comments on commit db76ae2

Please sign in to comment.