From 4a42cd5da4c2673122aaf2c3c24e7d59fd5b392a Mon Sep 17 00:00:00 2001
From: Мартынов Максим Сергеевич
Date: Fri, 26 Apr 2024 14:41:12 +0000
Subject: [PATCH] [DOP-15564] Avoid urlencoding JDBC params

---
 docker-compose.yml | 2 +
 docs/changelog/next_release/268.feature.rst | 1 +
 .../db_connection/clickhouse/connection.py | 13 ++--
 .../db_connection/greenplum/connection.py | 21 +++---
 .../jdbc_connection/connection.py | 40 ++----------
 .../db_connection/jdbc_mixin/connection.py | 27 ++++----
 .../db_connection/mssql/connection.py | 11 ++--
 .../db_connection/mysql/connection.py | 15 +++--
 .../db_connection/oracle/connection.py | 13 ++--
 .../db_connection/postgres/connection.py | 15 +++--
 .../db_connection/teradata/connection.py | 21 ++++--
 .../test_clickhouse_integration.py | 15 +++++
 .../test_greenplum_integration.py | 15 +++++
 .../test_mssql_integration.py | 15 +++++
 .../test_mysql_integration.py | 15 +++++
 .../test_oracle_integration.py | 16 +++++
 .../test_clickhouse_unit.py | 29 ++++++++-
 .../test_greenplum_unit.py | 33 ++++++++--
 .../test_jdbc_options_unit.py | 64 -------------------
 .../test_mssql_unit.py | 33 ++++++++--
 .../test_mysql_unit.py | 55 +++++++++++++---
 .../test_oracle_unit.py | 28 +++++++-
 .../test_postgres_unit.py | 46 +++++++++++--
 .../test_teradata_unit.py | 34 +++++++++-
 24 files changed, 395 insertions(+), 182 deletions(-)
 create mode 100644 docs/changelog/next_release/268.feature.rst

diff --git a/docker-compose.yml b/docker-compose.yml
index 54b2af91d..5951ca926 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -31,6 +31,8 @@ services:
       - 5433:5432
     networks:
       - onetl
+    sysctls:
+      - net.ipv6.conf.all.disable_ipv6=1

   clickhouse:
     image: ${CLICKHOUSE_IMAGE:-clickhouse/clickhouse-server:latest-alpine}
diff --git a/docs/changelog/next_release/268.feature.rst b/docs/changelog/next_release/268.feature.rst
new file mode 100644
index 000000000..0938462ed
--- /dev/null
+++ b/docs/changelog/next_release/268.feature.rst
@@ -0,0 +1 @@
+Allow passing JDBC connection extra params without urlencoding.
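To make the changelog entry above concrete: previously, `extra` values were embedded directly into the JDBC URL, so anything containing special characters had to be urlencoded by the caller; with this patch the URL stays minimal and the extras are exposed through the new `jdbc_params` property, which each connection hands to Spark as separate JDBC options. Below is a minimal sketch of the resulting behaviour, not part of the patch itself, assuming onetl with this change applied, a SparkSession able to load the ClickHouse JDBC driver, and hypothetical host/credentials (the values mirror the unit tests further down):

    from pyspark.sql import SparkSession

    from onetl.connection import Clickhouse

    # the ClickHouse JDBC driver must be available to Spark for the connection to work
    spark = (
        SparkSession.builder
        .appName("onetl-demo")
        .config("spark.jars.packages", ",".join(Clickhouse.get_packages()))
        .getOrCreate()
    )

    clickhouse = Clickhouse(
        host="some_host",  # hypothetical host
        user="user",
        password="passwd",
        database="database",
        # special characters in values no longer need to be urlencoded by the caller
        extra={"socket_timeout": 120000, "custom_http_params": "key1=value1,key2=value2"},
        spark=spark,
    )

    # the URL no longer carries the extra params ...
    assert clickhouse.jdbc_url == "jdbc:clickhouse://some_host:8123/database"
    # ... they are passed to Spark separately, together with user/password/driver/url
    assert clickhouse.jdbc_params["socket_timeout"] == 120000
    assert clickhouse.jdbc_params["custom_http_params"] == "key1=value1,key2=value2"
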
diff --git a/onetl/connection/db_connection/clickhouse/connection.py b/onetl/connection/db_connection/clickhouse/connection.py index 2f22de9bf..89b7ff463 100644 --- a/onetl/connection/db_connection/clickhouse/connection.py +++ b/onetl/connection/db_connection/clickhouse/connection.py @@ -162,13 +162,16 @@ def package(self) -> str: @property def jdbc_url(self) -> str: - extra = self.extra.dict(by_alias=True) - parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items())) - if self.database: - return f"jdbc:clickhouse://{self.host}:{self.port}/{self.database}?{parameters}".rstrip("?") + return f"jdbc:clickhouse://{self.host}:{self.port}/{self.database}" + + return f"jdbc:clickhouse://{self.host}:{self.port}" - return f"jdbc:clickhouse://{self.host}:{self.port}?{parameters}".rstrip("?") + @property + def jdbc_params(self) -> dict: + result = super().jdbc_params + result.update(self.extra.dict(by_alias=True)) + return result @staticmethod def _build_statement( diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index c3e6b3b9f..120d58008 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -250,15 +250,20 @@ def instance_url(self) -> str: @property def jdbc_url(self) -> str: - extra = { - key: value - for key, value in self.extra.dict(by_alias=True).items() - if not (key.startswith("server.") or key.startswith("pool.")) - } - extra["ApplicationName"] = extra.get("ApplicationName", self.spark.sparkContext.appName) + return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}" - parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items())) - return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}?{parameters}".rstrip("?") + @property + def jdbc_params(self) -> dict: + result = super().jdbc_params + result.update( + { + key: value + for key, value in self.extra.dict(by_alias=True).items() + if not (key.startswith("server.") or key.startswith("pool.")) + }, + ) + result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName) + return result @slot def read_source_as_df( diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index 3133d3671..616e5fd29 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -165,7 +165,7 @@ def write_df_to_target( options: JDBCWriteOptions | None = None, ) -> None: write_options = self.WriteOptions.parse(options) - jdbc_params = self.options_to_jdbc_params(write_options) + jdbc_properties = self._get_jdbc_properties(write_options, exclude={"if_exists"}, exclude_none=True) mode = ( "overwrite" @@ -173,7 +173,7 @@ def write_df_to_target( else write_options.if_exists.value ) log.info("|%s| Saving data to a table %r", self.__class__.__name__, target) - df.write.jdbc(table=target, mode=mode, **jdbc_params) + df.write.format("jdbc").mode(mode).options(dbtable=target, **jdbc_properties).save() log.info("|%s| Table %r successfully written", self.__class__.__name__, target) @slot @@ -196,38 +196,6 @@ def get_df_schema( return df.schema - def options_to_jdbc_params( - self, - options: JDBCReadOptions | JDBCWriteOptions, - ) -> dict: - # Have to replace the parameter with - # since the method takes the named parameter - # link to source below - # 
https://github.com/apache/spark/blob/2ef8ced27a6b0170a691722a855d3886e079f037/python/pyspark/sql/readwriter.py#L465 - - partition_column = getattr(options, "partition_column", None) - if partition_column: - options = options.copy( - update={"column": partition_column}, - exclude={"partition_column"}, - ) - - result = self._get_jdbc_properties( - options, - include=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS, - exclude={"if_exists"}, - exclude_none=True, - ) - - result["properties"] = self._get_jdbc_properties( - options, - exclude=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS | {"if_exists"}, - exclude_none=True, - ) - - result["properties"].pop("partitioningMode", None) - return result - @slot def get_min_max_values( self, @@ -275,8 +243,8 @@ def _query_on_executor( query: str, options: JDBCReadOptions, ) -> DataFrame: - jdbc_params = self.options_to_jdbc_params(options) - return self.spark.read.jdbc(table=f"({query}) T", **jdbc_params) + jdbc_properties = self._get_jdbc_properties(options, exclude={"partitioning_mode"}, exclude_none=True) + return self.spark.read.format("jdbc").options(dbtable=f"({query}) T", **jdbc_properties).load() def _exclude_partition_options( self, diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py index 856d387cf..dae2242b5 100644 --- a/onetl/connection/db_connection/jdbc_mixin/connection.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -76,6 +76,16 @@ class JDBCMixin(FrozenModel): def jdbc_url(self) -> str: """JDBC Connection URL""" + @property + def jdbc_params(self) -> dict: + """JDBC Connection params""" + return { + "user": self.user, + "password": self.password.get_secret_value() if self.password is not None else "", + "driver": self.DRIVER, + "url": self.jdbc_url, + } + @slot def close(self): """ @@ -312,20 +322,12 @@ def _get_jdbc_properties( self, options: JDBCMixinOptions, **kwargs, - ) -> dict: + ) -> dict[str, str]: """ Fills up human-readable Options class to a format required by Spark internal methods """ - - result = options.copy( - update={ - "user": self.user, - "password": self.password.get_secret_value() if self.password is not None else "", - "driver": self.DRIVER, - "url": self.jdbc_url, - }, - ).dict(by_alias=True, **kwargs) - + result = self.jdbc_params + result.update(options.dict(by_alias=True, **kwargs)) return stringify(result) def _options_to_connection_properties(self, options: JDBCMixinOptions): @@ -339,8 +341,7 @@ def _options_to_connection_properties(self, options: JDBCMixinOptions): * https://github.com/apache/spark/blob/v2.3.0/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala#L248-L255 """ - jdbc_properties = self._get_jdbc_properties(options, exclude_unset=True) - + jdbc_properties = self._get_jdbc_properties(options, exclude_none=True) jdbc_utils_package = self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc # type: ignore jdbc_options = jdbc_utils_package.JDBCOptions( self.jdbc_url, diff --git a/onetl/connection/db_connection/mssql/connection.py b/onetl/connection/db_connection/mssql/connection.py index 8fa91b47a..48143191d 100644 --- a/onetl/connection/db_connection/mssql/connection.py +++ b/onetl/connection/db_connection/mssql/connection.py @@ -200,11 +200,14 @@ def package(cls) -> str: @property def jdbc_url(self) -> str: - prop = self.extra.dict(by_alias=True) - prop["databaseName"] = self.database - parameters = ";".join(f"{k}={v}" for k, v in sorted(prop.items())) + return 
f"jdbc:sqlserver://{self.host}:{self.port}" - return f"jdbc:sqlserver://{self.host}:{self.port};{parameters}" + @property + def jdbc_params(self) -> dict: + result = super().jdbc_params + result.update(self.extra.dict(by_alias=True)) + result["databaseName"] = self.database + return result @property def instance_url(self) -> str: diff --git a/onetl/connection/db_connection/mysql/connection.py b/onetl/connection/db_connection/mysql/connection.py index da71de55b..a26f8f385 100644 --- a/onetl/connection/db_connection/mysql/connection.py +++ b/onetl/connection/db_connection/mysql/connection.py @@ -138,11 +138,14 @@ def package(cls) -> str: return "com.mysql:mysql-connector-j:8.3.0" @property - def jdbc_url(self): - prop = self.extra.dict(by_alias=True) - parameters = "&".join(f"{k}={v}" for k, v in sorted(prop.items())) - + def jdbc_url(self) -> str: if self.database: - return f"jdbc:mysql://{self.host}:{self.port}/{self.database}?{parameters}" + return f"jdbc:mysql://{self.host}:{self.port}/{self.database}" + + return f"jdbc:mysql://{self.host}:{self.port}" - return f"jdbc:mysql://{self.host}:{self.port}?{parameters}" + @property + def jdbc_params(self) -> dict: + result = super().jdbc_params + result.update(self.extra.dict(by_alias=True)) + return result diff --git a/onetl/connection/db_connection/oracle/connection.py b/onetl/connection/db_connection/oracle/connection.py index a2d8d35b9..d566fa275 100644 --- a/onetl/connection/db_connection/oracle/connection.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -226,13 +226,16 @@ def package(cls) -> str: @property def jdbc_url(self) -> str: - extra = self.extra.dict(by_alias=True) - parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items())) - if self.sid: - return f"jdbc:oracle:thin:@{self.host}:{self.port}:{self.sid}?{parameters}".rstrip("?") + return f"jdbc:oracle:thin:@{self.host}:{self.port}:{self.sid}" + + return f"jdbc:oracle:thin:@//{self.host}:{self.port}/{self.service_name}" - return f"jdbc:oracle:thin:@//{self.host}:{self.port}/{self.service_name}?{parameters}".rstrip("?") + @property + def jdbc_params(self) -> dict: + result = super().jdbc_params + result.update(self.extra.dict(by_alias=True)) + return result @property def instance_url(self) -> str: diff --git a/onetl/connection/db_connection/postgres/connection.py b/onetl/connection/db_connection/postgres/connection.py index 80cddbc11..16d317fee 100644 --- a/onetl/connection/db_connection/postgres/connection.py +++ b/onetl/connection/db_connection/postgres/connection.py @@ -20,6 +20,10 @@ class PostgresExtra(GenericOptions): # allows automatic conversion from text to target column type during write stringtype: str = "unspecified" + # avoid closing connections from server side + # while connector is moving data to executors before insert + tcpKeepAlive: str = "true" # noqa: N815 + class Config: extra = "allow" @@ -142,11 +146,14 @@ def package(cls) -> str: @property def jdbc_url(self) -> str: - extra = self.extra.dict(by_alias=True) - extra["ApplicationName"] = extra.get("ApplicationName", self.spark.sparkContext.appName) + return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}" - parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items())) - return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}?{parameters}".rstrip("?") + @property + def jdbc_params(self) -> dict[str, str]: + result = super().jdbc_params + result.update(self.extra.dict(by_alias=True)) + result["ApplicationName"] = result.get("ApplicationName", 
self.spark.sparkContext.appName) + return result @property def instance_url(self) -> str: diff --git a/onetl/connection/db_connection/teradata/connection.py b/onetl/connection/db_connection/teradata/connection.py index 93bd51468..cf135009d 100644 --- a/onetl/connection/db_connection/teradata/connection.py +++ b/onetl/connection/db_connection/teradata/connection.py @@ -5,6 +5,7 @@ import warnings from typing import ClassVar, Optional +from onetl._internal import stringify from onetl._util.classproperty import classproperty from onetl._util.version import Version from onetl.connection.db_connection.jdbc_connection import JDBCConnection @@ -162,12 +163,22 @@ def package(cls) -> str: @property def jdbc_url(self) -> str: - prop = self.extra.dict(by_alias=True) + # Teradata JDBC driver documentation specifically mentions that params from + # java.sql.DriverManager.getConnection(url, params) are used to only retrieve 'user' and 'password' values. + # Other params should be passed via url + properties = self.extra.dict(by_alias=True) if self.database: - prop["DATABASE"] = self.database + properties["DATABASE"] = self.database - prop["DBS_PORT"] = self.port + properties["DBS_PORT"] = self.port - conn = ",".join(f"{k}={v}" for k, v in sorted(prop.items())) - return f"jdbc:teradata://{self.host}/{conn}" + connection_params = [] + for key, value in sorted(properties.items()): + string_value = stringify(value) + if "," in string_value: + connection_params.append(f"{key}='{string_value}'") + else: + connection_params.append(f"{key}={string_value}") + + return f"jdbc:teradata://{self.host}/{','.join(connection_params)}" diff --git a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py index 18047d749..78656d834 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_clickhouse_integration.py @@ -46,6 +46,21 @@ def test_clickhouse_connection_check_fail(spark): clickhouse.check() +def test_clickhouse_connection_check_extra_is_handled_by_driver(spark, processing): + clickhouse = Clickhouse( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"socket_timeout": "wrong_type"}, + ) + + with pytest.raises(RuntimeError, match="Connection is unavailable"): + clickhouse.check() + + @pytest.mark.parametrize("suffix", ["", ";"]) def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix): clickhouse = Clickhouse( diff --git a/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py b/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py index 4514594c5..5c2d17115 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_greenplum_integration.py @@ -48,6 +48,21 @@ def test_greenplum_connection_check_fail(spark): greenplum.check() +def test_greenplum_connection_check_extra_is_handled_by_driver(spark, processing): + greenplum = Greenplum( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={**processing.extra, "connectTimeout": "wrong_type"}, + ) + + with pytest.raises(RuntimeError, 
match="Connection is unavailable"): + greenplum.check() + + @pytest.mark.parametrize("suffix", ["", ";"]) def test_greenplum_connection_fetch(spark, processing, load_table_data, suffix): greenplum = Greenplum( diff --git a/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py index 9a875671a..4fad8a754 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mssql_integration.py @@ -55,6 +55,21 @@ def test_mssql_connection_check_fail(spark): mssql.check() +def test_mssql_connection_check_extra_is_handled_by_driver(spark, processing): + mssql = MSSQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"trustServerCertificate": "false"}, + ) + + with pytest.raises(RuntimeError, match="Connection is unavailable"): + mssql.check() + + @pytest.mark.parametrize("suffix", ["", ";"]) def test_mssql_connection_sql(spark, processing, load_table_data, suffix): mssql = MSSQL( diff --git a/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py b/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py index 72a6b3b8f..8a4840329 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_mysql_integration.py @@ -46,6 +46,21 @@ def test_mysql_connection_check_fail(spark): mysql.check() +def test_mysql_connection_check_extra_is_handled_by_driver(spark, processing): + mysql = MySQL( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + database=processing.database, + spark=spark, + extra={"tcpKeepAlive": "wrong_type"}, + ) + + with pytest.raises(RuntimeError, match="Connection is unavailable"): + mysql.check() + + @pytest.mark.parametrize("suffix", ["", ";"]) def test_mysql_connection_sql(spark, processing, load_table_data, suffix): mysql = MySQL( diff --git a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py index 6bd96b259..485ca1911 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_oracle_integration.py @@ -55,6 +55,22 @@ def test_oracle_connection_check_fail(spark): oracle.check() +def test_oracle_connection_check_extra_is_handled_by_driver(spark, processing): + oracle = Oracle( + host=processing.host, + port=processing.port, + user=processing.user, + password=processing.password, + spark=spark, + sid=processing.sid, + service_name=processing.service_name, + extra={"defaultRowPrefetch": "wrong_type"}, + ) + + with pytest.raises(RuntimeError, match="Connection is unavailable"): + oracle.check() + + @pytest.mark.parametrize("suffix", ["", ";"]) def test_oracle_connection_sql(spark, processing, load_table_data, suffix): oracle = Oracle( diff --git a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py index 79fc13ddc..29478b6c9 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py 
@@ -121,6 +121,12 @@ def test_clickhouse(spark_mock): assert conn.database == "database" assert conn.jdbc_url == "jdbc:clickhouse://some_host:8123/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.clickhouse.jdbc.ClickHouseDriver", + "url": "jdbc:clickhouse://some_host:8123/database", + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -144,6 +150,12 @@ def test_clickhouse_with_port(spark_mock): assert conn.database == "database" assert conn.jdbc_url == "jdbc:clickhouse://some_host:5000/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.clickhouse.jdbc.ClickHouseDriver", + "url": "jdbc:clickhouse://some_host:5000/database", + } def test_clickhouse_without_database(spark_mock): @@ -157,6 +169,12 @@ def test_clickhouse_without_database(spark_mock): assert not conn.database assert conn.jdbc_url == "jdbc:clickhouse://some_host:8123" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.clickhouse.jdbc.ClickHouseDriver", + "url": "jdbc:clickhouse://some_host:8123", + } def test_clickhouse_with_extra(spark_mock): @@ -165,11 +183,18 @@ def test_clickhouse_with_extra(spark_mock): user="user", password="passwd", database="database", - extra={"socket_timeout": "120000", "query": "SELECT%201%3B"}, + extra={"socket_timeout": 120000, "custom_http_params": "key1=value1,key2=value2"}, spark=spark_mock, ) - assert conn.jdbc_url == "jdbc:clickhouse://some_host:8123/database?query=SELECT%201%3B&socket_timeout=120000" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.clickhouse.jdbc.ClickHouseDriver", + "url": "jdbc:clickhouse://some_host:8123/database", + "socket_timeout": 120000, + "custom_http_params": "key1=value1,key2=value2", + } def test_clickhouse_without_mandatory_args(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py index de24e5ce2..5c824d127 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py @@ -119,7 +119,15 @@ def test_greenplum(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:postgresql://some_host:5432/database?ApplicationName=abc&tcpKeepAlive=true" + assert conn.jdbc_url == "jdbc:postgresql://some_host:5432/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5432/database", + "ApplicationName": "abc", + "tcpKeepAlive": "true", + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -135,7 +143,15 @@ def test_greenplum_with_port(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:postgresql://some_host:5000/database?ApplicationName=abc&tcpKeepAlive=true" + assert conn.jdbc_url == "jdbc:postgresql://some_host:5000/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5000/database", + "ApplicationName": "abc", + "tcpKeepAlive": "true", + } def test_greenplum_without_database_error(spark_mock): @@ -161,9 +177,16 @@ def test_greenplum_with_extra(spark_mock): # `server.*` 
and `pool.*` options are ignored while generating jdbc_url # they are used only in `read_source_as_df` and `write_df_to_target` - assert conn.jdbc_url == ( - "jdbc:postgresql://some_host:5432/database?ApplicationName=override&autosave=always&tcpKeepAlive=false" - ) + assert conn.jdbc_url == "jdbc:postgresql://some_host:5432/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5432/database", + "ApplicationName": "override", + "tcpKeepAlive": "false", + "autosave": "always", + } def test_greenplum_without_mandatory_args(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py b/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py index 7c8ecfcca..47148c6b9 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py @@ -197,70 +197,6 @@ def test_jdbc_write_options_case(): assert camel_case == snake_case -def test_jdbc_read_options_to_jdbc(spark_mock): - connection = Postgres(host="local", user="admin", database="default", password="1234", spark=spark_mock) - jdbc_params = connection.options_to_jdbc_params( - options=Postgres.ReadOptions( - lowerBound=10, - upperBound=1000, - partitionColumn="some_column", - numPartitions=20, - fetchsize=1000, - sessionInitStatement="BEGIN execute immediate 'alter session set '_serial_direct_read'=true", - snake_case_option="left unchanged", - camelCaseOption="left unchanged", - CamelCaseOption="left unchanged", - ), - ) - - assert jdbc_params == { - "column": "some_column", - "lowerBound": "10", - "numPartitions": "20", - "properties": { - "driver": "org.postgresql.Driver", - "fetchsize": "1000", - "password": "1234", - "sessionInitStatement": "BEGIN execute immediate 'alter session set '_serial_direct_read'=true", - "user": "admin", - "snake_case_option": "left unchanged", - "camelCaseOption": "left unchanged", - "CamelCaseOption": "left unchanged", - }, - "upperBound": "1000", - "url": "jdbc:postgresql://local:5432/default?ApplicationName=abc&stringtype=unspecified", - } - - -def test_jdbc_write_options_to_jdbc(spark_mock): - connection = Postgres(host="local", user="admin", database="default", password="1234", spark=spark_mock) - jdbc_params = connection.options_to_jdbc_params( - options=Postgres.WriteOptions( - batchsize=1000, - truncate=True, - isolation_level="NONE", - snake_case_option="left unchanged", - camelCaseOption="left unchanged", - CamelCaseOption="left unchanged", - ), - ) - - assert jdbc_params == { - "properties": { - "batchsize": "1000", - "driver": "org.postgresql.Driver", - "password": "1234", - "isolationLevel": "NONE", - "truncate": "true", - "user": "admin", - "snake_case_option": "left unchanged", - "camelCaseOption": "left unchanged", - "CamelCaseOption": "left unchanged", - }, - "url": "jdbc:postgresql://local:5432/default?ApplicationName=abc&stringtype=unspecified", - } - - @pytest.mark.parametrize( "options, value", [ diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py index 51a548166..e1069c20e 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py @@ -92,7 +92,14 @@ def test_mssql(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == 
"jdbc:sqlserver://some_host:1433;databaseName=database" + assert conn.jdbc_url == "jdbc:sqlserver://some_host:1433" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver", + "url": "jdbc:sqlserver://some_host:1433", + "databaseName": "database", + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -108,7 +115,14 @@ def test_mssql_with_port(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:sqlserver://some_host:5000;databaseName=database" + assert conn.jdbc_url == "jdbc:sqlserver://some_host:5000" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver", + "url": "jdbc:sqlserver://some_host:5000", + "databaseName": "database", + } def test_mssql_without_database_error(spark_mock): @@ -118,7 +132,6 @@ def test_mssql_without_database_error(spark_mock): user="user", password="passwd", spark=spark_mock, - extra={"trustServerCertificate": "true"}, ) @@ -132,10 +145,16 @@ def test_mssql_with_extra(spark_mock): spark=spark_mock, ) - assert ( - conn.jdbc_url - == "jdbc:sqlserver://some_host:1433;characterEncoding=UTF-8;databaseName=database;trustServerCertificate=true" - ) + assert conn.jdbc_url == "jdbc:sqlserver://some_host:1433" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver", + "url": "jdbc:sqlserver://some_host:1433", + "databaseName": "database", + "characterEncoding": "UTF-8", + "trustServerCertificate": "true", + } def test_mssql_with_extra_prohibited(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py index c071e1196..da9267586 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py @@ -79,7 +79,15 @@ def test_mysql(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:mysql://some_host:3306/database?characterEncoding=UTF-8&useUnicode=yes" + assert conn.jdbc_url == "jdbc:mysql://some_host:3306/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.mysql.cj.jdbc.Driver", + "url": "jdbc:mysql://some_host:3306/database", + "characterEncoding": "UTF-8", + "useUnicode": "yes", + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -95,7 +103,15 @@ def test_mysql_with_port(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:mysql://some_host:5000/database?characterEncoding=UTF-8&useUnicode=yes" + assert conn.jdbc_url == "jdbc:mysql://some_host:5000/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.mysql.cj.jdbc.Driver", + "url": "jdbc:mysql://some_host:5000/database", + "characterEncoding": "UTF-8", + "useUnicode": "yes", + } def test_mysql_without_database(spark_mock): @@ -108,7 +124,15 @@ def test_mysql_without_database(spark_mock): assert conn.password.get_secret_value() == "passwd" assert not conn.database - assert conn.jdbc_url == "jdbc:mysql://some_host:3306?characterEncoding=UTF-8&useUnicode=yes" + assert conn.jdbc_url == "jdbc:mysql://some_host:3306" + assert 
conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.mysql.cj.jdbc.Driver", + "url": "jdbc:mysql://some_host:3306", + "characterEncoding": "UTF-8", + "useUnicode": "yes", + } def test_mysql_with_extra(spark_mock): @@ -121,10 +145,17 @@ def test_mysql_with_extra(spark_mock): spark=spark_mock, ) - assert conn.jdbc_url == ( - "jdbc:mysql://some_host:3306/database?allowMultiQueries=true&characterEncoding=UTF-8&" - "requireSSL=true&useUnicode=yes" - ) + assert conn.jdbc_url == "jdbc:mysql://some_host:3306/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.mysql.cj.jdbc.Driver", + "url": "jdbc:mysql://some_host:3306/database", + "characterEncoding": "UTF-8", + "useUnicode": "yes", + "allowMultiQueries": "true", + "requireSSL": "true", + } conn = MySQL( host="some_host", @@ -135,7 +166,15 @@ def test_mysql_with_extra(spark_mock): spark=spark_mock, ) - assert conn.jdbc_url == ("jdbc:mysql://some_host:3306/database?characterEncoding=CP-1251&useUnicode=no") + assert conn.jdbc_url == "jdbc:mysql://some_host:3306/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.mysql.cj.jdbc.Driver", + "url": "jdbc:mysql://some_host:3306/database", + "characterEncoding": "CP-1251", + "useUnicode": "no", + } def test_mysql_without_mandatory_args(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py index cb2b9dc7b..d4db6940e 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py @@ -103,6 +103,12 @@ def test_oracle(spark_mock): assert conn.sid == "sid" assert conn.jdbc_url == "jdbc:oracle:thin:@some_host:1521:sid" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "oracle.jdbc.driver.OracleDriver", + "url": "jdbc:oracle:thin:@some_host:1521:sid", + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -119,12 +125,24 @@ def test_oracle_with_port(spark_mock): assert conn.sid == "sid" assert conn.jdbc_url == "jdbc:oracle:thin:@some_host:5000:sid" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "oracle.jdbc.driver.OracleDriver", + "url": "jdbc:oracle:thin:@some_host:5000:sid", + } def test_oracle_uri_with_service_name(spark_mock): conn = Oracle(host="some_host", user="user", password="passwd", service_name="service", spark=spark_mock) assert conn.jdbc_url == "jdbc:oracle:thin:@//some_host:1521/service" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "oracle.jdbc.driver.OracleDriver", + "url": "jdbc:oracle:thin:@//some_host:1521/service", + } def test_oracle_without_sid_and_service_name(spark_mock): @@ -167,7 +185,15 @@ def test_oracle_with_extra(spark_mock): spark=spark_mock, ) - assert conn.jdbc_url == "jdbc:oracle:thin:@some_host:1521:sid?connectTimeout=10&tcpKeepAlive=false" + assert conn.jdbc_url == "jdbc:oracle:thin:@some_host:1521:sid" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "oracle.jdbc.driver.OracleDriver", + "url": "jdbc:oracle:thin:@some_host:1521:sid", + "tcpKeepAlive": "false", + "connectTimeout": "10", + } def test_oracle_without_mandatory_args(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py index 
f4c00f30f..268525220 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py @@ -79,7 +79,16 @@ def test_postgres(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:postgresql://some_host:5432/database?ApplicationName=abc&stringtype=unspecified" + assert conn.jdbc_url == "jdbc:postgresql://some_host:5432/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5432/database", + "ApplicationName": "abc", + "tcpKeepAlive": "true", + "stringtype": "unspecified", + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -95,7 +104,16 @@ def test_postgres_with_port(spark_mock): assert conn.password.get_secret_value() == "passwd" assert conn.database == "database" - assert conn.jdbc_url == "jdbc:postgresql://some_host:5000/database?ApplicationName=abc&stringtype=unspecified" + assert conn.jdbc_url == "jdbc:postgresql://some_host:5000/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5000/database", + "ApplicationName": "abc", + "tcpKeepAlive": "true", + "stringtype": "unspecified", + } def test_postgres_without_database_error(spark_mock): @@ -109,14 +127,28 @@ def test_postgres_with_extra(spark_mock): user="user", password="passwd", database="database", - extra={"ssl": "true", "autosave": "always"}, + extra={ + "stringtype": "VARCHAR", + "autosave": "always", + "tcpKeepAlive": "false", + "ApplicationName": "override", + "ssl": "true", + }, spark=spark_mock, ) - assert ( - conn.jdbc_url - == "jdbc:postgresql://some_host:5432/database?ApplicationName=abc&autosave=always&ssl=true&stringtype=unspecified" - ) + assert conn.jdbc_url == "jdbc:postgresql://some_host:5432/database" + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "org.postgresql.Driver", + "url": "jdbc:postgresql://some_host:5432/database", + "stringtype": "VARCHAR", + "autosave": "always", + "tcpKeepAlive": "false", + "ApplicationName": "override", + "ssl": "true", + } def test_postgres_without_mandatory_args(spark_mock): diff --git a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py index fd90d31d4..b71d7e8d1 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py @@ -82,6 +82,12 @@ def test_teradata(spark_mock): "jdbc:teradata://some_host/CHARSET=UTF8,COLUMN_NAME=ON,DATABASE=database," "DBS_PORT=1025,FLATTEN=ON,MAYBENULL=ON,STRICT_NAMES=OFF" ) + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.teradata.jdbc.TeraDriver", + "url": conn.jdbc_url, + } assert "password='passwd'" not in str(conn) assert "password='passwd'" not in repr(conn) @@ -101,6 +107,12 @@ def test_teradata_with_port(spark_mock): "jdbc:teradata://some_host/CHARSET=UTF8,COLUMN_NAME=ON,DATABASE=database," "DBS_PORT=5000,FLATTEN=ON,MAYBENULL=ON,STRICT_NAMES=OFF" ) + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.teradata.jdbc.TeraDriver", + "url": conn.jdbc_url, + } def test_teradata_without_database(spark_mock): @@ -117,6 +129,12 @@ def 
test_teradata_without_database(spark_mock): "jdbc:teradata://some_host/CHARSET=UTF8,COLUMN_NAME=ON," "DBS_PORT=1025,FLATTEN=ON,MAYBENULL=ON,STRICT_NAMES=OFF" ) + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.teradata.jdbc.TeraDriver", + "url": conn.jdbc_url, + } def test_teradata_with_extra(spark_mock): @@ -125,14 +143,20 @@ def test_teradata_with_extra(spark_mock): user="user", password="passwd", database="database", - extra={"TMODE": "TERA", "LOGMECH": "LDAP"}, + extra={"TMODE": "TERA", "LOGMECH": "LDAP", "PARAM_WITH_COMMA": "some,value"}, spark=spark_mock, ) assert conn.jdbc_url == ( "jdbc:teradata://some_host/CHARSET=UTF8,COLUMN_NAME=ON,DATABASE=database," - "DBS_PORT=1025,FLATTEN=ON,LOGMECH=LDAP,MAYBENULL=ON,STRICT_NAMES=OFF,TMODE=TERA" + "DBS_PORT=1025,FLATTEN=ON,LOGMECH=LDAP,MAYBENULL=ON,PARAM_WITH_COMMA='some,value',STRICT_NAMES=OFF,TMODE=TERA" ) + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.teradata.jdbc.TeraDriver", + "url": conn.jdbc_url, + } conn = Teradata( host="some_host", @@ -147,6 +171,12 @@ def test_teradata_with_extra(spark_mock): "jdbc:teradata://some_host/CHARSET=CP-1251,COLUMN_NAME=OFF,DATABASE=database," "DBS_PORT=1025,FLATTEN=OFF,MAYBENULL=OFF,STRICT_NAMES=ON" ) + assert conn.jdbc_params == { + "user": "user", + "password": "passwd", + "driver": "com.teradata.jdbc.TeraDriver", + "url": conn.jdbc_url, + } def test_teradata_with_extra_prohibited(spark_mock):