diff --git a/docs/changelog/next_release/164.bugfix.rst b/docs/changelog/next_release/164.bugfix.rst
new file mode 100644
index 000000000..e9c591100
--- /dev/null
+++ b/docs/changelog/next_release/164.bugfix.rst
@@ -0,0 +1 @@
+Fix ``Hive.check()`` behavior when Hive Metastore is not available.
diff --git a/docs/changelog/next_release/164.improvement.rst b/docs/changelog/next_release/164.improvement.rst
new file mode 100644
index 000000000..09799b8d2
--- /dev/null
+++ b/docs/changelog/next_release/164.improvement.rst
@@ -0,0 +1 @@
+Add check to all DB and FileDF connections that Spark session is alive.
diff --git a/onetl/connection/db_connection/db_connection/connection.py b/onetl/connection/db_connection/db_connection/connection.py
index 315f5b17c..223c7c63c 100644
--- a/onetl/connection/db_connection/db_connection/connection.py
+++ b/onetl/connection/db_connection/db_connection/connection.py
@@ -17,7 +17,7 @@
 from logging import getLogger
 from typing import TYPE_CHECKING
 
-from pydantic import Field
+from pydantic import Field, validator
 
 from onetl._util.spark import try_import_pyspark
 from onetl.base import BaseDBConnection
@@ -48,6 +48,16 @@ def _forward_refs(cls) -> dict[str, type]:
         refs["SparkSession"] = SparkSession
         return refs
 
+    @validator("spark")
+    def _check_spark_session_alive(cls, spark):
+        try:
+            spark.sql("SELECT 1").collect()
+        except Exception as e:
+            msg = "Spark session is stopped. Please recreate Spark session."
+            raise ValueError(msg) from e
+
+        return spark
+
     def _log_parameters(self):
         log.info("|%s| Using connection parameters:", self.__class__.__name__)
         parameters = self.dict(exclude_none=True, exclude={"spark"})
diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py
index 740d09b44..97cc034f5 100644
--- a/onetl/connection/db_connection/hive/connection.py
+++ b/onetl/connection/db_connection/hive/connection.py
@@ -146,7 +146,7 @@ class Hive(DBConnection):
     # TODO: remove in v1.0.0
     slots = HiveSlots
 
-    _CHECK_QUERY: ClassVar[str] = "SELECT 1"
+    _CHECK_QUERY: ClassVar[str] = "SHOW DATABASES"
 
     @slot
     @classmethod
@@ -207,7 +207,7 @@ def check(self):
         log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)
 
         try:
-            self._execute_sql(self._CHECK_QUERY)
+            self._execute_sql(self._CHECK_QUERY).limit(1).collect()
             log.info("|%s| Connection is available.", self.__class__.__name__)
         except Exception as e:
             log.exception("|%s| Connection is unavailable", self.__class__.__name__)
diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py
index 860f7b215..280596d5d 100644
--- a/onetl/connection/db_connection/mongodb/connection.py
+++ b/onetl/connection/db_connection/mongodb/connection.py
@@ -507,6 +507,7 @@ def write_df_to_target(
         )
 
         if self._collection_exists(target):
+            # MongoDB connector does not support mode=ignore and mode=error
             if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR:
                 raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')")
             elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE:
diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py
index 7c1994182..e355f5c4b 100644
--- a/onetl/connection/file_df_connection/spark_file_df_connection.py
+++ b/onetl/connection/file_df_connection/spark_file_df_connection.py
@@ -19,7 +19,7 @@
 from logging import getLogger
 from typing import TYPE_CHECKING
 
-from pydantic import Field
+from pydantic import Field, validator
 
 from onetl._util.hadoop import get_hadoop_config
 from onetl._util.spark import try_import_pyspark
@@ -182,6 +182,16 @@ def _forward_refs(cls) -> dict[str, type]:
         refs["SparkSession"] = SparkSession
         return refs
 
+    @validator("spark")
+    def _check_spark_session_alive(cls, spark):
+        try:
+            spark.sql("SELECT 1").collect()
+        except Exception as e:
+            msg = "Spark session is stopped. Please recreate Spark session."
+            raise ValueError(msg) from e
+
+        return spark
+
     def _log_parameters(self):
         log.info("|%s| Using connection parameters:", self.__class__.__name__)
         parameters = self.dict(exclude_none=True, exclude={"spark"})
diff --git a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
index 6469b10c8..a9505eed1 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
@@ -60,8 +60,6 @@ def normalize_cluster_name(cluster: str) -> str:
 
 
 def test_hive_known_get_current_cluster_hook(request, spark_mock, mocker):
-    mocker.patch.object(Hive, "_execute_sql", return_value=None)
-
     # no exception
     Hive(cluster="rnd-prod", spark=spark_mock).check()
     Hive(cluster="rnd-dwh", spark=spark_mock).check()
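Below is a hypothetical usage sketch (not code from this PR) of the ``Hive.check()`` fix described in ``164.bugfix.rst``. It assumes PySpark with Hive support and onetl are installed; the session settings are illustrative, and the cluster name is reused from the unit test above.

# Hypothetical sketch, not part of this PR: Hive.check() behavior after the change.
# Assumptions: PySpark with Hive support and onetl installed; session settings illustrative.
from pyspark.sql import SparkSession

from onetl.connection import Hive

spark = SparkSession.builder.enableHiveSupport().getOrCreate()
hive = Hive(cluster="rnd-dwh", spark=spark)

# check() now runs "SHOW DATABASES" and forces execution via .limit(1).collect(),
# so it raises when the Hive Metastore is unreachable; the old "SELECT 1" query
# was never collected and did not need the Metastore, which is the bug fixed here.
hive.check()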
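A second hypothetical sketch (also not from this PR) of the Spark-session-alive validation described in ``164.improvement.rst``, which now applies to all DB and FileDF connections. The local session below is an assumption for illustration only.

# Hypothetical sketch, not part of this PR: the @validator("spark") added above
# rejects stopped Spark sessions at construction time.
from pyspark.sql import SparkSession

from onetl.connection import Hive

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.stop()

# The validator runs "SELECT 1" against the session, which fails for a stopped
# SparkSession, and re-raises ValueError("Spark session is stopped. Please
# recreate Spark session."); pydantic v1 surfaces this as a ValidationError,
# which itself subclasses ValueError.
try:
    Hive(cluster="rnd-dwh", spark=spark)
except ValueError as e:
    print(e)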