Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecations from Weaviate Provider #44745

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions providers/src/airflow/providers/weaviate/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
Changelog
---------

main
....

.. warning::
All deprecated classes, parameters and features have been removed from the weaviate provider package.
The following breaking changes were introduced:

* Removed deprecated ``input_json`` parameter from ``WeaviateIngestOperator``. Use ``input_data`` instead.

2.1.0
.....

Expand Down
20 changes: 3 additions & 17 deletions providers/src/airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook

Expand Down Expand Up @@ -51,11 +49,9 @@ class WeaviateIngestOperator(BaseOperator):
:param vector_col: key/column name in which the vectors are stored.
:param hook_params: Optional config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on
(or provides custom vectors) and store them in the Weaviate class.
"""

template_fields: Sequence[str] = ("input_json", "input_data")
template_fields: Sequence[str] = ("input_data",)

def __init__(
self,
Expand All @@ -66,29 +62,19 @@ def __init__(
uuid_column: str = "id",
tenant: str | None = None,
hook_params: dict | None = None,
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.collection_name = collection_name
self.conn_id = conn_id
self.vector_col = vector_col
self.input_json = input_json
self.uuid_column = uuid_column
self.tenant = tenant
self.input_data = input_data
self.hook_params = hook_params or {}

if (self.input_data is None) and (input_json is not None):
warnings.warn(
"Passing 'input_json' to WeaviateIngestOperator is deprecated and"
" you should use 'input_data' instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.input_data = input_json
elif self.input_data is None and input_json is None:
raise TypeError("Either input_json or input_data is required")
if self.input_data is None:
raise TypeError("input_data is required")

@cached_property
def hook(self) -> WeaviateHook:
Expand Down
27 changes: 0 additions & 27 deletions providers/tests/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.utils.task_instance_session import set_current_task_instance_session

pytest.importorskip("weaviate")
Expand All @@ -47,30 +46,6 @@ def test_constructor(self, operator):
assert operator.input_data == [{"data": "sample_data"}]
assert operator.hook_params == {}

@patch("airflow.providers.weaviate.operators.weaviate.WeaviateIngestOperator.log")
def test_execute_with_input_json(self, mock_log, operator):
with pytest.warns(
AirflowProviderDeprecationWarning,
match="Passing 'input_json' to WeaviateIngestOperator is deprecated and you should use 'input_data' instead",
):
operator = WeaviateIngestOperator(
task_id="weaviate_task",
conn_id="weaviate_conn",
collection_name="my_collection",
input_json=[{"data": "sample_data"}],
)
operator.hook.batch_data = MagicMock()

operator.execute(context=None)

operator.hook.batch_data.assert_called_once_with(
collection_name="my_collection",
data=[{"data": "sample_data"}],
vector_col="Vector",
uuid_col="id",
)
mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}])

@patch("airflow.providers.weaviate.operators.weaviate.WeaviateIngestOperator.log")
def test_execute_with_input_data(self, mock_log, operator):
operator.hook.batch_data = MagicMock()
Expand All @@ -94,12 +69,10 @@ def test_templates(self, create_task_instance_of_operator):
task_id="task-id",
conn_id="weaviate_conn",
collection_name="my_collection",
input_json="{{ dag.dag_id }}",
input_data="{{ dag.dag_id }}",
)
ti.render_templates()

assert dag_id == ti.task.input_json
assert dag_id == ti.task.input_data

@pytest.mark.db_test
Expand Down