diff --git a/docs/changelog/next_release/252.feature.rst b/docs/changelog/next_release/252.feature.rst new file mode 100644 index 000000000..497b3c4ae --- /dev/null +++ b/docs/changelog/next_release/252.feature.rst @@ -0,0 +1 @@ +:class:`Oracle` connection now uses Oracle JDBC driver ``23.3.0.0.23.09``, upgraded from ``23.2.0.0``, and supports passing custom versions: ``Oracle.get_packages(java_version=..., package_version=...)``. diff --git a/onetl/connection/db_connection/oracle/connection.py b/onetl/connection/db_connection/oracle/connection.py index f239b85b2..a2d8d35b9 100644 --- a/onetl/connection/db_connection/oracle/connection.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -179,15 +179,18 @@ class Oracle(JDBCConnection): @classmethod def get_packages( cls, - java_version: str | Version | None = None, + java_version: str | None = None, + package_version: str | None = None, ) -> list[str]: """ - Get package names to be downloaded by Spark. |support_hooks| + Get package names to be downloaded by Spark. Allows specifying custom JDBC driver versions for Oracle. |support_hooks| Parameters ---------- - java_version : str, default ``8`` - Java major version. + java_version : str, optional + Java major version, defaults to "8". Must be "8" or "11". + package_version : str, optional + Specifies the version of the Oracle JDBC driver to use. Defaults to "23.3.0.0.23.09". Examples -------- @@ -197,25 +200,29 @@ def get_packages( from onetl.connection import Oracle Oracle.get_packages() - Oracle.get_packages(java_version="8") + # specify Java and package versions + Oracle.get_packages(java_version="8", package_version="23.2.0.0") """ - if java_version is None: - java_version = "8" - java_ver = Version(java_version) + default_java_version = "8" + default_package_version = "23.3.0.23.09" + + java_ver = Version(java_version or default_java_version) if java_ver.major < 8: - raise ValueError(f"Java version must be at least 8, got {java_ver}") + raise ValueError(f"Java version must be at least 8, got {java_ver.major}") jre_ver = "8" if java_ver.major < 11 else "11" - return [f"com.oracle.database.jdbc:ojdbc{jre_ver}:23.2.0.0"] + jdbc_version = Version(package_version or default_package_version).min_digits(4) + + return [f"com.oracle.database.jdbc:ojdbc{jre_ver}:{jdbc_version}"] @classproperty def package(cls) -> str: """Get package name to be downloaded by Spark.""" msg = "`Oracle.package` will be removed in 1.0.0, use `Oracle.get_packages()` instead" warnings.warn(msg, UserWarning, stacklevel=3) - return "com.oracle.database.jdbc:ojdbc8:23.2.0.0" + return "com.oracle.database.jdbc:ojdbc8:23.3.0.23.09" @property def jdbc_url(self) -> str: diff --git a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py index 6a875b8f7..cb2b9dc7b 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py @@ -14,11 +14,11 @@ def test_oracle_class_attributes(): def test_oracle_package(): warning_msg = re.escape("will be removed in 1.0.0, use `Oracle.get_packages()` instead") with pytest.warns(UserWarning, match=warning_msg): - assert Oracle.package == "com.oracle.database.jdbc:ojdbc8:23.2.0.0" + assert Oracle.package == "com.oracle.database.jdbc:ojdbc8:23.3.0.23.09" def test_oracle_get_packages_no_input(): - assert Oracle.get_packages() == ["com.oracle.database.jdbc:ojdbc8:23.2.0.0"] + assert Oracle.get_packages() == ["com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"] @pytest.mark.parametrize("java_version", ["7", "6"]) @@ -28,17 +28,44 @@ def test_oracle_get_packages_java_version_not_supported(java_version): @pytest.mark.parametrize( - "java_version, package", + "java_version, package_version, expected_packages", [ - ("8", "com.oracle.database.jdbc:ojdbc8:23.2.0.0"), - ("9", "com.oracle.database.jdbc:ojdbc8:23.2.0.0"), - ("11", "com.oracle.database.jdbc:ojdbc11:23.2.0.0"), - ("17", "com.oracle.database.jdbc:ojdbc11:23.2.0.0"), - ("20", "com.oracle.database.jdbc:ojdbc11:23.2.0.0"), + (None, None, ["com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"]), + ("8", None, ["com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"]), + ("8", "23.3.0.23.09", ["com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"]), + ("8", "21.13.0.0", ["com.oracle.database.jdbc:ojdbc8:21.13.0.0"]), + ("9", None, ["com.oracle.database.jdbc:ojdbc8:23.3.0.23.09"]), + ("9", "21.13.0.0", ["com.oracle.database.jdbc:ojdbc8:21.13.0.0"]), + ("11", None, ["com.oracle.database.jdbc:ojdbc11:23.3.0.23.09"]), + ("11", "21.13.0.0", ["com.oracle.database.jdbc:ojdbc11:21.13.0.0"]), + ("17", "21.13.0.0", ["com.oracle.database.jdbc:ojdbc11:21.13.0.0"]), + ("20", "23.3.0.23.09", ["com.oracle.database.jdbc:ojdbc11:23.3.0.23.09"]), ], ) -def test_oracle_get_packages(java_version, package): - assert Oracle.get_packages(java_version=java_version) == [package] +def test_oracle_get_packages(java_version, package_version, expected_packages): + assert Oracle.get_packages(java_version=java_version, package_version=package_version) == expected_packages + + +@pytest.mark.parametrize( + "java_version, package_version", + [ + ("8", "23.3.0"), + ("11", "23.3"), + ("11", "a.b.c.d"), + ], +) +def test_oracle_get_packages_invalid_version(java_version, package_version): + with pytest.raises( + ValueError, + match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 4\).", + ): + Oracle.get_packages(java_version=java_version, package_version=package_version) + + +@pytest.mark.parametrize("java_version", ["7", "6"]) +def test_oracle_get_packages_java_version_not_supported(java_version): + with pytest.raises(ValueError, match=f"Java version must be at least 8, got {java_version}"): + Oracle.get_packages(java_version=java_version) def test_oracle_missing_package(spark_no_packages):