Skip to content

Commit

Permalink
Remove deprecations from Weaviate Provider (apache#44745)
Browse files Browse the repository at this point in the history
* remove weaviate deprecations

* update provider name in change log

* fix docs
  • Loading branch information
vatsrahul1001 authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent 130d82e commit 327a1c4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 44 deletions.
9 changes: 9 additions & 0 deletions providers/src/airflow/providers/weaviate/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
Changelog
---------

main
....

.. warning::
All deprecated classes, parameters and features have been removed from the weaviate provider package.
The following breaking changes were introduced:

* Removed deprecated ``input_json`` parameter from ``WeaviateIngestOperator``. Use ``input_data`` instead.

2.1.0
.....

Expand Down
20 changes: 3 additions & 17 deletions providers/src/airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

from __future__ import annotations

import warnings
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook

Expand Down Expand Up @@ -51,11 +49,9 @@ class WeaviateIngestOperator(BaseOperator):
:param vector_col: key/column name in which the vectors are stored.
:param hook_params: Optional config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on
(or provides custom vectors) and store them in the Weaviate class.
"""

template_fields: Sequence[str] = ("input_json", "input_data")
template_fields: Sequence[str] = ("input_data",)

def __init__(
self,
Expand All @@ -66,29 +62,19 @@ def __init__(
uuid_column: str = "id",
tenant: str | None = None,
hook_params: dict | None = None,
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.collection_name = collection_name
self.conn_id = conn_id
self.vector_col = vector_col
self.input_json = input_json
self.uuid_column = uuid_column
self.tenant = tenant
self.input_data = input_data
self.hook_params = hook_params or {}

if (self.input_data is None) and (input_json is not None):
warnings.warn(
"Passing 'input_json' to WeaviateIngestOperator is deprecated and"
" you should use 'input_data' instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.input_data = input_json
elif self.input_data is None and input_json is None:
raise TypeError("Either input_json or input_data is required")
if self.input_data is None:
raise TypeError("input_data is required")

@cached_property
def hook(self) -> WeaviateHook:
Expand Down
27 changes: 0 additions & 27 deletions providers/tests/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.utils.task_instance_session import set_current_task_instance_session

pytest.importorskip("weaviate")
Expand All @@ -47,30 +46,6 @@ def test_constructor(self, operator):
assert operator.input_data == [{"data": "sample_data"}]
assert operator.hook_params == {}

@patch("airflow.providers.weaviate.operators.weaviate.WeaviateIngestOperator.log")
def test_execute_with_input_json(self, mock_log, operator):
with pytest.warns(
AirflowProviderDeprecationWarning,
match="Passing 'input_json' to WeaviateIngestOperator is deprecated and you should use 'input_data' instead",
):
operator = WeaviateIngestOperator(
task_id="weaviate_task",
conn_id="weaviate_conn",
collection_name="my_collection",
input_json=[{"data": "sample_data"}],
)
operator.hook.batch_data = MagicMock()

operator.execute(context=None)

operator.hook.batch_data.assert_called_once_with(
collection_name="my_collection",
data=[{"data": "sample_data"}],
vector_col="Vector",
uuid_col="id",
)
mock_log.debug.assert_called_once_with("Input data: %s", [{"data": "sample_data"}])

@patch("airflow.providers.weaviate.operators.weaviate.WeaviateIngestOperator.log")
def test_execute_with_input_data(self, mock_log, operator):
operator.hook.batch_data = MagicMock()
Expand All @@ -94,12 +69,10 @@ def test_templates(self, create_task_instance_of_operator):
task_id="task-id",
conn_id="weaviate_conn",
collection_name="my_collection",
input_json="{{ dag.dag_id }}",
input_data="{{ dag.dag_id }}",
)
ti.render_templates()

assert dag_id == ti.task.input_json
assert dag_id == ti.task.input_data

@pytest.mark.db_test
Expand Down

0 comments on commit 327a1c4

Please sign in to comment.