diff --git a/docker/conda/environments/cuda11.8_dev.yml b/docker/conda/environments/cuda11.8_dev.yml
index 600827ffa6..45ab6cc17d 100644
--- a/docker/conda/environments/cuda11.8_dev.yml
+++ b/docker/conda/environments/cuda11.8_dev.yml
@@ -43,6 +43,7 @@ dependencies:
   - dill
   - docker-py=5.0
   - docutils
+  - elasticsearch==8.9.0
   - faker=12.3.0
   - feedparser=6.0.10
   - flake8
diff --git a/docs/source/modules/core/multiplexer.md b/docs/source/modules/core/multiplexer.md
deleted file mode 100644
index cd4dc64611..0000000000
--- a/docs/source/modules/core/multiplexer.md
+++ /dev/null
@@ -1,35 +0,0 @@
-
-## Multiplexer Module
-
-The multiplexer receives data packets from one or more input ports and interleaves them into a single output.
-
-### Configurable Parameters
-
-| Parameter     | Type         | Description                                         | Example Value             | Default Value |
-|---------------|--------------|-----------------------------------------------------|---------------------------|---------------|
-| `input_ports` | list[string] | Input ports data streams to be combined             | `["intput_1", "input_2"]` | `None`        |
-| `output_port` | string       | Output port where the combined streams to be passed | `output`                  | `None`        |
-
-### Example JSON Configuration
-
-```json
-{
-    "input_ports": ["intput_1", "input_2"],
-    "output_port": "output"
-}
-```
diff --git a/docs/source/modules/core/write_to_elasticsearch.md b/docs/source/modules/core/write_to_elasticsearch.md
new file mode 100644
index 0000000000..7689b6e3fc
--- /dev/null
+++ b/docs/source/modules/core/write_to_elasticsearch.md
@@ -0,0 +1,45 @@
+
+## Write to Elasticsearch Module
+
+This module reads an input data stream, converts each row of data to a document format suitable for Elasticsearch, and writes the documents to the specified Elasticsearch index using the Elasticsearch bulk API.
+
+### Configurable Parameters
+
+| Parameter             | Type | Description                                                                                                    | Example Value                                       | Default Value |
+|-----------------------|------|----------------------------------------------------------------------------------------------------------------|-----------------------------------------------------|---------------|
+| `index`               | str  | Elasticsearch index.                                                                                           | `"my_index"`                                        | `[Required]`  |
+| `connection_kwargs`   | dict | Elasticsearch connection kwargs configuration.                                                                 | `{"hosts": [{"host": "localhost", ...}]}`           | `[Required]`  |
+| `raise_on_exception`  | bool | Raise or suppress exceptions when writing to Elasticsearch.                                                    | `true`                                              | `false`       |
+| `pickled_func_config` | dict | Pickled custom function configuration used to update `connection_kwargs` as needed for the client connection. | See below                                           | `None`        |
+| `refresh_period_secs` | int  | Time in seconds to refresh the client connection.                                                              | `3600`                                              | `2400`        |
+
+### Example JSON Configuration
+
+```json
+{
+  "index": "test_index",
+  "connection_kwargs": {"hosts": [{"host": "localhost", "port": 9200, "scheme": "http"}]},
+  "raise_on_exception": true,
+  "pickled_func_config": {
+    "pickled_func_str": "pickled function as a string",
+    "encoding": "latin1"
+  },
+  "refresh_period_secs": 2400
+}
+```
diff --git a/docs/source/modules/index.md b/docs/source/modules/index.md
index 042c35104c..d80fbe367a 100644
--- a/docs/source/modules/index.md
+++ b/docs/source/modules/index.md
@@ -31,10 +31,10 @@ limitations under the License.
 ./core/filter_detections.md
 ./core/from_control_message.md
 ./core/mlflow_model_writer.md
-./core/multiplexer.md
 ./core/payload_batcher.md
 ./core/serialize.md
 ./core/to_control_message.md
+./core/write_to_elasticsearch.md
 ./core/write_to_file.md
 ```
diff --git a/morpheus/cli/commands.py b/morpheus/cli/commands.py
index fae791cf0f..14687cf219 100644
--- a/morpheus/cli/commands.py
+++ b/morpheus/cli/commands.py
@@ -675,6 +675,9 @@ def post_pipeline(ctx: click.Context, *args, **kwargs):
 add_command("preprocess", "morpheus.stages.preprocess.preprocess_nlp_stage.PreprocessNLPStage", modes=NLP_ONLY)
 add_command("serialize", "morpheus.stages.postprocess.serialize_stage.SerializeStage", modes=ALL)
 add_command("timeseries", "morpheus.stages.postprocess.timeseries_stage.TimeSeriesStage", modes=AE_ONLY)
+add_command("to-elasticsearch",
+            "morpheus.stages.output.write_to_elasticsearch_stage.WriteToElasticsearchStage",
+            modes=ALL)
 add_command("to-file", "morpheus.stages.output.write_to_file_stage.WriteToFileStage", modes=ALL)
 add_command("to-kafka", "morpheus.stages.output.write_to_kafka_stage.WriteToKafkaStage", modes=ALL)
 add_command("to-http", "morpheus.stages.output.http_client_sink_stage.HttpClientSinkStage", modes=ALL)
diff --git a/morpheus/controllers/elasticsearch_controller.py b/morpheus/controllers/elasticsearch_controller.py
new file mode 100644
index 0000000000..1c4bca4dfa
--- /dev/null
+++ b/morpheus/controllers/elasticsearch_controller.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+
+import pandas as pd
+from elasticsearch import ConnectionError as ESConnectionError
+from elasticsearch import Elasticsearch
+from elasticsearch.helpers import parallel_bulk
+
+logger = logging.getLogger(__name__)
+
+
+class ElasticsearchController:
+    """
+    ElasticsearchController to perform read and write operations using the Elasticsearch service.
+
+    Parameters
+    ----------
+    connection_kwargs : dict
+        Keyword arguments to configure the Elasticsearch connection.
+    raise_on_exception : bool, optional, default: False
+        Whether to raise exceptions on Elasticsearch errors.
+    refresh_period_secs : int, optional, default: 2400
+        The refresh period in seconds for client refreshing.
+    """
+
+    def __init__(self, connection_kwargs: dict, raise_on_exception: bool = False, refresh_period_secs: int = 2400):
+
+        self._client = None
+        self._last_refresh_time = None
+        self._raise_on_exception = raise_on_exception
+        self._refresh_period_secs = refresh_period_secs
+
+        if not connection_kwargs:
+            raise ValueError("Connection kwargs cannot be None or empty.")
+
+        self._connection_kwargs = connection_kwargs
+
+        logger.debug("Creating Elasticsearch client with configuration: %s", connection_kwargs)
+
+        self.refresh_client(force=True)
+
+        logger.debug("Elasticsearch cluster info: %s", self._client.info())
+        logger.debug("Creating Elasticsearch client... Done!")
Done!") + + def refresh_client(self, force: bool = False) -> bool: + """ + Refresh the Elasticsearch client instance. + + Parameters + ---------- + force : bool, optional, default: False + Force a client refresh. + + Returns + ------- + bool + Returns true if client is refreshed, otherwise false. + """ + + is_refreshed = False + time_now = time.time() + if force or self._client is None or time_now - self._last_refresh_time >= self._refresh_period_secs: + if self._client: + try: + # Close the existing client + self.close_client() + except Exception as ex: + logger.warning("Ignoring client close error: %s", ex) + logger.debug("Refreshing Elasticsearch client....") + + # Create Elasticsearch client + self._client = Elasticsearch(**self._connection_kwargs) + + # Check if the client is connected + if self._client.ping(): + logger.debug("Elasticsearch client is connected.") + else: + raise ESConnectionError("Elasticsearch client is not connected.") + + logger.debug("Refreshing Elasticsearch client.... Done!") + self._last_refresh_time = time.time() + is_refreshed = True + + return is_refreshed + + def parallel_bulk_write(self, actions) -> None: + """ + Perform parallel bulk writes to Elasticsearch. + + Parameters + ---------- + actions : list + List of actions to perform in parallel. + """ + + self.refresh_client() + + for success, info in parallel_bulk(self._client, actions=actions, raise_on_exception=self._raise_on_exception): + if not success: + logger.error("Error writing to ElasticSearch: %s", str(info)) + + def search_documents(self, index: str, query: dict, **kwargs) -> dict: + """ + Search for documents in Elasticsearch based on the given query. + + Parameters + ---------- + index : str + The name of the index to search. + query : dict + The DSL query for the search. + **kwargs + Additional keyword arguments that are supported by the Elasticsearch search method. + + Returns + ------- + dict + The search result returned by Elasticsearch. + """ + + try: + self.refresh_client() + result = self._client.search(index=index, query=query, **kwargs) + return result + except Exception as exc: + logger.error("Error searching documents: %s", exc) + if self._raise_on_exception: + raise RuntimeError(f"Error searching documents: {exc}") from exc + + return {} + + def df_to_parallel_bulk_write(self, index: str, df: pd.DataFrame) -> None: + """ + Converts DataFrames to actions and parallel bulk writes to Elasticsearch. + + Parameters + ---------- + index : str + The name of the index to write. + df : pd.DataFrame + DataFrame entries that require writing to Elasticsearch. + """ + + actions = [{"_index": index, "_source": row} for row in df.to_dict("records")] + + self.parallel_bulk_write(actions) # Parallel bulk upload to Elasticsearch + + def close_client(self) -> None: + """ + Close the Elasticsearch client connection. + """ + self._client.close() diff --git a/morpheus/modules/__init__.py b/morpheus/modules/__init__.py index 8d5d36e95a..fb5bb53025 100644 --- a/morpheus/modules/__init__.py +++ b/morpheus/modules/__init__.py @@ -14,6 +14,7 @@ """ Morpheus module definitions, each module is automatically registered when imported """ +from morpheus._lib import modules # When segment modules are imported, they're added to the module registry. # To avoid flake8 warnings about unused code, the noqa flag is used during import. 
 from morpheus.modules import file_batcher
@@ -26,8 +27,8 @@
 from morpheus.modules import payload_batcher
 from morpheus.modules import serialize
 from morpheus.modules import to_control_message
+from morpheus.modules import write_to_elasticsearch
 from morpheus.modules import write_to_file
-from morpheus._lib import modules

 __all__ = [
     "file_batcher",
@@ -41,5 +42,6 @@
     "payload_batcher",
     "serialize",
     "to_control_message",
-    "write_to_file"
+    "write_to_elasticsearch",
+    "write_to_file"
 ]
diff --git a/morpheus/modules/write_to_elasticsearch.py b/morpheus/modules/write_to_elasticsearch.py
new file mode 100644
index 0000000000..9d64a5e186
--- /dev/null
+++ b/morpheus/modules/write_to_elasticsearch.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import pickle
+
+import mrc
+from mrc.core import operators as ops
+
+from morpheus.controllers.elasticsearch_controller import ElasticsearchController
+from morpheus.messages import ControlMessage
+from morpheus.utils.module_ids import MORPHEUS_MODULE_NAMESPACE
+from morpheus.utils.module_ids import WRITE_TO_ELASTICSEARCH
+from morpheus.utils.module_utils import register_module
+
+logger = logging.getLogger(__name__)
+
+
+@register_module(WRITE_TO_ELASTICSEARCH, MORPHEUS_MODULE_NAMESPACE)
+def write_to_elasticsearch(builder: mrc.Builder):
+    """
+    This module reads an input data stream, converts each row of data to a document format suitable for
+    Elasticsearch, and writes the documents to the specified Elasticsearch index using the Elasticsearch bulk API.
+
+    Parameters
+    ----------
+    builder : mrc.Builder
+        An mrc Builder object.
+ """ + + config = builder.get_current_module_config() + + index = config.get("index", None) + + if index is None: + raise ValueError("Index must not be None.") + + connection_kwargs = config.get("connection_kwargs") + + if not isinstance(connection_kwargs, dict): + raise ValueError(f"Expects `connection_kwargs` as a dictionary, but it is of type {type(connection_kwargs)}") + + raise_on_exception = config.get("raise_on_exception", False) + pickled_func_config = config.get("pickled_func_config", None) + refresh_period_secs = config.get("refresh_period_secs", 2400) + + if pickled_func_config: + pickled_func_str = pickled_func_config.get("pickled_func_str") + encoding = pickled_func_config.get("encoding") + + if pickled_func_str and encoding: + connection_kwargs_update_func = pickle.loads(bytes(pickled_func_str, encoding)) + connection_kwargs = connection_kwargs_update_func(connection_kwargs) + + controller = ElasticsearchController(connection_kwargs=connection_kwargs, + raise_on_exception=raise_on_exception, + refresh_period_secs=refresh_period_secs) + + def on_data(message: ControlMessage): + + df = message.payload().df.to_pandas() + + controller.df_to_parallel_bulk_write(index=index, df=df) + + return message + + node = builder.make_node(WRITE_TO_ELASTICSEARCH, ops.map(on_data), ops.on_completed(controller.close_client)) + + # Register input and output port for a module. + builder.register_module_input("input", node) + builder.register_module_output("output", node) diff --git a/morpheus/stages/output/write_to_elasticsearch_stage.py b/morpheus/stages/output/write_to_elasticsearch_stage.py new file mode 100644 index 0000000000..d0bb062d6c --- /dev/null +++ b/morpheus/stages/output/write_to_elasticsearch_stage.py @@ -0,0 +1,131 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Write to Elasticsearch stage.""" + +import logging +import typing + +import mrc +import mrc.core.operators as ops +import yaml + +import cudf + +from morpheus.cli.register_stage import register_stage +from morpheus.config import Config +from morpheus.controllers.elasticsearch_controller import ElasticsearchController +from morpheus.messages import MessageMeta +from morpheus.pipeline.single_port_stage import SinglePortStage +from morpheus.pipeline.stream_pair import StreamPair + +logger = logging.getLogger(__name__) + + +@register_stage("to-elasticsearch", ignore_args=["connection_kwargs_update_func"]) +class WriteToElasticsearchStage(SinglePortStage): + """ + + This class writes the messages as documents to Elasticsearch. + + Parameters + ---------- + config : `morpheus.config.Config` + Pipeline configuration instance. + index : str + Logical namespace that holds a collection of documents. + connection_conf_file : str + YAML configuration file for Elasticsearch connection kwargs settings. + raise_on_exception : bool, optional, default: False + Whether to raise exceptions on Elasticsearch errors. 
+    refresh_period_secs : int, optional, default: 2400
+        The refresh period in seconds for client refreshing.
+    connection_kwargs_update_func : typing.Callable, optional, default: None
+        Custom function to update connection parameters.
+    """
+
+    def __init__(self,
+                 config: Config,
+                 index: str,
+                 connection_conf_file: str,
+                 raise_on_exception: bool = False,
+                 refresh_period_secs: int = 2400,
+                 connection_kwargs_update_func: typing.Callable = None):
+
+        super().__init__(config)
+
+        self._index = index
+
+        try:
+            with open(connection_conf_file, "r", encoding="utf-8") as file:
+                connection_kwargs = yaml.safe_load(file)
+        except FileNotFoundError as exc:
+            raise FileNotFoundError(
+                f"The specified connection configuration file '{connection_conf_file}' does not exist.") from exc
+        except Exception as exc:
+            raise RuntimeError(f"An error occurred while loading the configuration file: {exc}") from exc
+
+        if connection_kwargs_update_func:
+            connection_kwargs = connection_kwargs_update_func(connection_kwargs)
+
+        self._controller = ElasticsearchController(connection_kwargs=connection_kwargs,
+                                                   raise_on_exception=raise_on_exception,
+                                                   refresh_period_secs=refresh_period_secs)
+
+    @property
+    def name(self) -> str:
+        """Returns the name of this stage."""
+        return "to-elasticsearch"
+
+    def accepted_types(self) -> typing.Tuple:
+        """
+        Returns accepted input types for this stage.
+
+        Returns
+        -------
+        typing.Tuple(`morpheus.pipeline.messages.MessageMeta`, )
+            Accepted input types.
+        """
+        return (MessageMeta, )
+
+    def supports_cpp_node(self):
+        """Indicates whether this stage supports a C++ node."""
+        return False
+
+    def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair:
+
+        stream = input_stream[0]
+
+        def on_data(meta: MessageMeta) -> MessageMeta:
+
+            self._controller.refresh_client()
+
+            df = meta.copy_dataframe()
+            if isinstance(df, cudf.DataFrame):
+                df = df.to_pandas()
+                logger.debug("Converted cuDF DataFrame of %s rows to a pandas DataFrame.", len(df))
+
+            self._controller.df_to_parallel_bulk_write(index=self._index, df=df)
+
+            return meta
+
+        to_elasticsearch = builder.make_node(self.unique_name,
+                                             ops.map(on_data),
+                                             ops.on_completed(self._controller.close_client))
+
+        builder.make_edge(stream, to_elasticsearch)
+        stream = to_elasticsearch
+
+        # Messages pass through to the output unchanged
+        return stream, input_stream[1]
diff --git a/morpheus/utils/module_ids.py b/morpheus/utils/module_ids.py
index dfb704b45e..2626859871 100644
--- a/morpheus/utils/module_ids.py
+++ b/morpheus/utils/module_ids.py
@@ -21,9 +21,9 @@
 FILTER_DETECTIONS = "FilterDetections"
 FROM_CONTROL_MESSAGE = "FromControlMessage"
 MLFLOW_MODEL_WRITER = "MLFlowModelWriter"
-MULTIPLEXER = "Multiplexer"
 SERIALIZE = "Serialize"
 TO_CONTROL_MESSAGE = "ToControlMessage"
 WRITE_TO_FILE = "WriteToFile"
 FILTER_CM_FAILED = "FilterCmFailed"
 PAYLOAD_BATCHER = "PayloadBatcher"
+WRITE_TO_ELASTICSEARCH = "WriteToElasticsearch"
diff --git a/tests/controllers/test_elasticsearch_controller.py b/tests/controllers/test_elasticsearch_controller.py
new file mode 100644
index 0000000000..dac43d5aee
--- /dev/null
+++ b/tests/controllers/test_elasticsearch_controller.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python
+# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import typing
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+from elasticsearch import Elasticsearch
+
+from morpheus.controllers.elasticsearch_controller import ElasticsearchController
+
+
+@pytest.fixture(scope="function", autouse=True)
+def patch_elasticsearch() -> Elasticsearch:
+    with patch("morpheus.controllers.elasticsearch_controller.Elasticsearch", autospec=True):
+        yield
+
+
+@pytest.fixture(scope="module", name="connection_kwargs")
+def connection_kwargs_fixture() -> dict:
+    kwargs = {"hosts": [{"host": "localhost", "port": 9200, "scheme": "http"}]}
+    yield kwargs
+
+
+@pytest.fixture(scope="module", name="create_controller")
+def create_controller_fixture(connection_kwargs) -> typing.Callable[..., ElasticsearchController]:
+
+    def inner_create_controller(*, connection_kwargs=connection_kwargs, refresh_period_secs=10, **controller_kwargs):
+        return ElasticsearchController(connection_kwargs=connection_kwargs,
+                                       refresh_period_secs=refresh_period_secs,
+                                       **controller_kwargs)
+
+    yield inner_create_controller
+
+
+@pytest.mark.use_python
+def test_constructor(create_controller: typing.Callable[..., ElasticsearchController], connection_kwargs: dict):
+    assert create_controller(raise_on_exception=True)._raise_on_exception is True
+    assert create_controller(refresh_period_secs=1.5)._refresh_period_secs == 1.5
+    assert create_controller()._connection_kwargs == connection_kwargs
+
+
+@pytest.mark.use_python
+def test_refresh_client_force(create_controller: typing.Callable[..., ElasticsearchController]):
+    controller = create_controller(refresh_period_secs=1)
+
+    client = controller._client
+    is_refreshed = controller.refresh_client(force=True)
+
+    controller._client.close.assert_called_once()
+    assert client.ping.call_count == 2
+    assert is_refreshed is True
+    assert controller._last_refresh_time > 0
+
+
+@pytest.mark.use_python
+def test_refresh_client_not_needed(create_controller: typing.Callable[..., ElasticsearchController]):
+    controller = create_controller()
+    client = controller._client
+
+    # Simulate a refresh not needed scenario
+    is_refreshed = controller.refresh_client()
+
+    client.close.assert_not_called()
+    assert client.ping.call_count == 1
+    assert is_refreshed is False
+
+
+@pytest.mark.use_python
+def test_refresh_client_needed(create_controller: typing.Callable[..., ElasticsearchController]):
+
+    # Set a 1 second refresh period
+    controller = create_controller(refresh_period_secs=1)
+    client = controller._client
+
+    is_refreshed = False
+    # Now "sleep" for more than 1 second to trigger a new client
+    with patch("time.time", return_value=time.time() + 1):
+        is_refreshed = controller.refresh_client()
+
+    client.close.assert_called_once()
+    assert client.ping.call_count == 2
+    assert is_refreshed is True
+
+
+@pytest.mark.use_python
+@patch("morpheus.controllers.elasticsearch_controller.parallel_bulk", return_value=[(True, None)])
+def test_parallel_bulk_write(mock_parallel_bulk, create_controller: typing.Callable[..., ElasticsearchController]):
+    # Define your mock actions
+    mock_actions = [{"_index": "test_index", "_id": 1, "_source": {"field1": "value1"}}]
1, "_source": {"field1": "value1"}}] + + create_controller().parallel_bulk_write(actions=mock_actions) + mock_parallel_bulk.assert_called_once() + + +@pytest.mark.use_python +@patch("morpheus.controllers.elasticsearch_controller.parallel_bulk", return_value=[(True, None)]) +def test_df_to_parallel_bulk_write(mock_parallel_bulk: typing.Callable, + create_controller: typing.Callable[..., ElasticsearchController]): + data = {"field1": ["value1", "value2"], "field2": ["value3", "value4"]} + df = pd.DataFrame(data) + + expected_actions = [{ + "_index": "test_index", "_source": { + "field1": "value1", "field2": "value3" + } + }, { + "_index": "test_index", "_source": { + "field1": "value2", "field2": "value4" + } + }] + + controller = create_controller() + controller.df_to_parallel_bulk_write(index="test_index", df=df) + mock_parallel_bulk.assert_called_once_with(controller._client, + actions=expected_actions, + raise_on_exception=controller._raise_on_exception) + + +def test_search_documents_success(create_controller: typing.Callable[..., ElasticsearchController]): + controller = create_controller() + controller._client.search.return_value = {"hits": {"total": 1, "hits": [{"_source": {"field1": "value1"}}]}} + query = {"match": {"field1": "value1"}} + result = controller.search_documents(index="test_index", query=query) + + assert isinstance(result, dict) + assert "hits" in result + assert "total" in result["hits"] + assert result["hits"]["total"] == 1 + + +def test_search_documents_failure_supress_errors(create_controller: typing.Callable[..., ElasticsearchController]): + controller = create_controller() + controller._client.search.side_effect = ConnectionError("Connection error") + query = {"match": {"field1": "value1"}} + result = controller.search_documents(index="test_index", query=query) + + assert isinstance(result, dict) + assert not result + + +def test_search_documents_failure_raise_error(create_controller: typing.Callable[..., ElasticsearchController]): + controller = create_controller(raise_on_exception=True) + controller._client.search.side_effect = Exception + query = {"match": {"field1": "value1"}} + + with pytest.raises(RuntimeError): + controller.search_documents(index="test_index", query=query) diff --git a/tests/test_write_to_elasticsearch_stage_pipe.py b/tests/test_write_to_elasticsearch_stage_pipe.py new file mode 100644 index 0000000000..07c9929526 --- /dev/null +++ b/tests/test_write_to_elasticsearch_stage_pipe.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import typing
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+import yaml
+
+import cudf
+
+from morpheus.config import Config
+from morpheus.pipeline.linear_pipeline import LinearPipeline
+from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
+from morpheus.stages.output.write_to_elasticsearch_stage import WriteToElasticsearchStage
+
+
+def connection_kwargs_func(kwargs):
+    kwargs["retry_on_status"] = 3
+    kwargs["retry_on_timeout"] = 3 * 10
+
+    return kwargs
+
+
+@pytest.fixture(scope="function", name="connection_conf_file")
+def connection_conf_file_fixture(tmp_path):
+    connection_kwargs = {"hosts": [{"host": "localhost", "port": 9201, "scheme": "http"}]}
+
+    connection_conf_file = tmp_path / "connection_kwargs_conf.yaml"
+    with connection_conf_file.open(mode="w") as file:
+        yaml.dump(connection_kwargs, file)
+
+    yield connection_conf_file
+
+
+@pytest.mark.use_python
+@pytest.mark.parametrize("conf_file, exception", [("connection_conf.yaml", FileNotFoundError), (None, Exception)])
+def test_constructor_invalid_conf_file(config: Config, conf_file: str, exception: typing.Type[Exception]):
+    with pytest.raises(exception):
+        WriteToElasticsearchStage(config, index="t_index", connection_conf_file=conf_file)
+
+
+@pytest.mark.use_python
+@patch("morpheus.controllers.elasticsearch_controller.Elasticsearch")
+def test_constructor_with_custom_func(config: Config, connection_conf_file: str):
+    expected_connection_kwargs = {
+        "hosts": [{
+            "host": "localhost", "port": 9201, "scheme": "http"
+        }], "retry_on_status": 3, "retry_on_timeout": 30
+    }
+
+    stage = WriteToElasticsearchStage(config,
+                                      index="t_index",
+                                      connection_conf_file=connection_conf_file,
+                                      connection_kwargs_update_func=connection_kwargs_func)
+
+    assert stage._controller._connection_kwargs == expected_connection_kwargs
+
+
+@pytest.mark.use_python
+@patch("morpheus.stages.output.write_to_elasticsearch_stage.ElasticsearchController")
+def test_write_to_elasticsearch_stage_pipe(mock_controller: typing.Any,
+                                           connection_conf_file: str,
+                                           config: Config,
+                                           filter_probs_df: typing.Union[cudf.DataFrame, pd.DataFrame]):
+    mock_df_to_parallel_bulk_write = mock_controller.return_value.df_to_parallel_bulk_write
+    mock_refresh_client = mock_controller.return_value.refresh_client
+
+    # Create a pipeline
+    pipe = LinearPipeline(config)
+
+    # Add the source stage and the WriteToElasticsearchStage to the pipeline
+    pipe.set_source(InMemorySourceStage(config, [filter_probs_df]))
+    pipe.add_stage(WriteToElasticsearchStage(config, index="t_index", connection_conf_file=connection_conf_file))
+
+    # Run the pipeline
+    pipe.run()
+
+    if isinstance(filter_probs_df, cudf.DataFrame):
+        filter_probs_df = filter_probs_df.to_pandas()
+
+    expected_index = mock_df_to_parallel_bulk_write.call_args[1]["index"]
+    expected_df = mock_df_to_parallel_bulk_write.call_args[1]["df"]
+
+    mock_refresh_client.assert_called_once()
+    mock_df_to_parallel_bulk_write.assert_called_once()
+
+    assert expected_index == "t_index"
+    assert expected_df.equals(filter_probs_df)
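
A minimal sketch of how a `pickled_func_config` entry for the `WriteToElasticsearch` module could be produced, matching the module's `pickle.loads(bytes(pickled_func_str, encoding))` call above. The `update_kwargs` function and the `request_timeout` tweak are illustrative assumptions, not part of this change.

```python
import pickle


def update_kwargs(connection_kwargs: dict) -> dict:
    """Illustrative update function; any callable that takes and returns the kwargs dict works."""
    connection_kwargs["request_timeout"] = 60  # example client setting, not prescribed by the module
    return connection_kwargs


# Encode the pickled bytes as a latin1 string so the function can be embedded in a
# JSON module configuration; the module decodes it back with the same encoding.
pickled_func_config = {
    "pickled_func_str": pickle.dumps(update_kwargs).decode("latin1"),
    "encoding": "latin1",
}
```

Note that `pickle` serializes module-level functions by reference, so the function must be importable in the process that loads the configuration.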
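A minimal sketch of wiring the new stage into a pipeline via the Python API, following the same pattern as `tests/test_write_to_elasticsearch_stage_pipe.py`; the index name and the YAML path are placeholders, and the YAML file is assumed to hold the Elasticsearch connection kwargs (hosts, etc.) loaded by the stage.

```python
import cudf

from morpheus.config import Config
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.write_to_elasticsearch_stage import WriteToElasticsearchStage

config = Config()
pipeline = LinearPipeline(config)

# Source stage feeding a single in-memory DataFrame into the pipeline
pipeline.set_source(InMemorySourceStage(config, [cudf.DataFrame({"field1": ["value1", "value2"]})]))

# "my_index" and the configuration file path are placeholders
pipeline.add_stage(
    WriteToElasticsearchStage(config, index="my_index", connection_conf_file="connection_kwargs_conf.yaml"))

pipeline.run()
```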