diff --git a/docs/changelog/0.12.0.rst b/docs/changelog/0.12.0.rst
new file mode 100644
index 00000000..2a299fdc
--- /dev/null
+++ b/docs/changelog/0.12.0.rst
@@ -0,0 +1,4 @@
+Bug Fixes
+---------
+
+- Fix passing ``Greenplum(extra={"options": ...})`` during read/write operations. (:github:pull:`308`)
diff --git a/docs/changelog/index.rst b/docs/changelog/index.rst
index 4bdac946..a5ab166b 100644
--- a/docs/changelog/index.rst
+++ b/docs/changelog/index.rst
@@ -3,6 +3,7 @@
     :caption: Changelog
 
     DRAFT
+    0.11.2
     0.11.1
     0.11.0
     0.10.2
diff --git a/onetl/VERSION b/onetl/VERSION
index af88ba82..bc859cbd 100644
--- a/onetl/VERSION
+++ b/onetl/VERSION
@@ -1 +1 @@
-0.11.1
+0.11.2
diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py
index 7ed60539..5f77e04a 100644
--- a/onetl/connection/db_connection/greenplum/connection.py
+++ b/onetl/connection/db_connection/greenplum/connection.py
@@ -7,6 +7,7 @@
 import textwrap
 import warnings
 from typing import TYPE_CHECKING, Any, ClassVar
+from urllib.parse import quote, urlencode, urlparse, urlunparse
 
 from etl_entities.instance import Host
 
@@ -271,17 +272,20 @@ def instance_url(self) -> str:
     def jdbc_url(self) -> str:
         return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"
 
+    @property
+    def jdbc_custom_params(self) -> dict:
+        result = {
+            key: value
+            for key, value in self.extra.dict(by_alias=True).items()
+            if not (key.startswith("server.") or key.startswith("pool."))
+        }
+        result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName)
+        return result
+
     @property
     def jdbc_params(self) -> dict:
         result = super().jdbc_params
-        result.update(
-            {
-                key: value
-                for key, value in self.extra.dict(by_alias=True).items()
-                if not (key.startswith("server.") or key.startswith("pool."))
-            },
-        )
-        result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName)
+        result.update(self.jdbc_custom_params)
         return result
 
     @slot
@@ -302,7 +306,7 @@ def read_source_as_df(
         fake_query_for_log = self.dialect.get_sql_query(table=source, columns=columns, where=where, limit=limit)
         log_lines(log, fake_query_for_log)
 
-        df = self.spark.read.format("greenplum").options(**self._connector_params(source), **read_options).load()
+        df = self.spark.read.format("greenplum").options(**self._get_connector_params(source), **read_options).load()
         self._check_expected_jobs_number(df, action="read")
 
         if where:
@@ -337,7 +341,7 @@ def write_df_to_target(
             else write_options.if_exists.value
         )
         df.write.format("greenplum").options(
-            **self._connector_params(target),
+            **self._get_connector_params(target),
             **options_dict,
         ).mode(mode).save()
 
@@ -422,21 +426,31 @@ def _check_java_class_imported(cls, spark):
             raise ValueError(msg) from e
         return spark
 
-    def _connector_params(
+    def _get_connector_params(
         self,
         table: str,
     ) -> dict:
         schema, table_name = table.split(".")  # noqa: WPS414
         extra = self.extra.dict(by_alias=True, exclude_none=True)
-        extra = {key: value for key, value in extra.items() if key.startswith("server.") or key.startswith("pool.")}
+        greenplum_connector_options = {
+            key: value for key, value in extra.items() if key.startswith("server.") or key.startswith("pool.")
+        }
+
+        # Greenplum connector requires all JDBC params to be passed via JDBC URL:
+        # https://docs.vmware.com/en/VMware-Greenplum-Connector-for-Apache-Spark/2.3/greenplum-connector-spark/using_the_connector.html#specifying-session-parameters
+        parsed_jdbc_url = urlparse(self.jdbc_url)
+        sorted_jdbc_params = [(k, v) for k, v in sorted(self.jdbc_custom_params.items(), key=lambda x: x[0].lower())]
+        jdbc_url_query = urlencode(sorted_jdbc_params, quote_via=quote)
+        jdbc_url = urlunparse(parsed_jdbc_url._replace(query=jdbc_url_query))
+
         return {
             "driver": self.DRIVER,
-            "url": self.jdbc_url,
+            "url": jdbc_url,
             "user": self.user,
             "password": self.password.get_secret_value(),
             "dbschema": schema,
             "dbtable": table_name,
-            **extra,
+            **greenplum_connector_options,
         }
 
     def _options_to_connection_properties(self, options: JDBCFetchOptions | JDBCExecuteOptions):
diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
index 0d382d44..473562a5 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
@@ -128,6 +128,14 @@ def test_greenplum(spark_mock):
         "ApplicationName": "abc",
         "tcpKeepAlive": "true",
     }
+    assert conn._get_connector_params("some.table") == {
+        "user": "user",
+        "password": "passwd",
+        "driver": "org.postgresql.Driver",
+        "url": "jdbc:postgresql://some_host:5432/database?ApplicationName=abc&tcpKeepAlive=true",
+        "dbschema": "some",
+        "dbtable": "table",
+    }
 
     assert "password='passwd'" not in str(conn)
     assert "password='passwd'" not in repr(conn)
@@ -154,6 +162,14 @@ def test_greenplum_with_port(spark_mock):
         "ApplicationName": "abc",
         "tcpKeepAlive": "true",
     }
+    assert conn._get_connector_params("some.table") == {
+        "user": "user",
+        "password": "passwd",
+        "driver": "org.postgresql.Driver",
+        "url": "jdbc:postgresql://some_host:5000/database?ApplicationName=abc&tcpKeepAlive=true",
+        "dbschema": "some",
+        "dbtable": "table",
+    }
 
     assert conn.instance_url == "greenplum://some_host:5000/database"
 
@@ -173,6 +189,7 @@ def test_greenplum_with_extra(spark_mock):
             "autosave": "always",
             "tcpKeepAlive": "false",
             "ApplicationName": "override",
+            "options": "-c search_path=public",
             "server.port": 8000,
             "pool.maxSize": 40,
         },
@@ -190,6 +207,17 @@ def test_greenplum_with_extra(spark_mock):
         "ApplicationName": "override",
         "tcpKeepAlive": "false",
         "autosave": "always",
+        "options": "-c search_path=public",
+    }
+    assert conn._get_connector_params("some.table") == {
+        "user": "user",
+        "password": "passwd",
+        "driver": "org.postgresql.Driver",
+        "url": "jdbc:postgresql://some_host:5432/database?ApplicationName=override&autosave=always&options=-c%20search_path%3Dpublic&tcpKeepAlive=false",
+        "dbschema": "some",
+        "dbtable": "table",
+        "pool.maxSize": 40,
+        "server.port": 8000,
     }
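
Note: below is a minimal standalone sketch of the JDBC URL construction that the patched _get_connector_params performs. The connection URL and parameter values are illustrative, borrowed from the test fixtures above; this is not part of the patch itself, and the expected output matches the assertion in test_greenplum_with_extra.

    # Reproduces the URL-building logic from _get_connector_params: since the
    # Greenplum connector only accepts JDBC session params via the JDBC URL,
    # they are sorted case-insensitively and percent-encoded into the query string.
    from urllib.parse import quote, urlencode, urlparse, urlunparse

    jdbc_url = "jdbc:postgresql://some_host:5432/database"  # illustrative value
    jdbc_custom_params = {
        "ApplicationName": "override",
        "autosave": "always",
        "tcpKeepAlive": "false",
        "options": "-c search_path=public",
    }

    # quote_via=quote encodes spaces as %20 (not "+") and "=" inside values
    # as %3D, so multi-word values like "-c search_path=public" survive intact.
    sorted_params = sorted(jdbc_custom_params.items(), key=lambda x: x[0].lower())
    query = urlencode(sorted_params, quote_via=quote)
    print(urlunparse(urlparse(jdbc_url)._replace(query=query)))
    # jdbc:postgresql://some_host:5432/database?ApplicationName=override&autosave=always&options=-c%20search_path%3Dpublic&tcpKeepAlive=false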