diff --git a/docs/changelog/next_release/254.feature.rst b/docs/changelog/next_release/254.feature.rst new file mode 100644 index 000000000..0a8aff089 --- /dev/null +++ b/docs/changelog/next_release/254.feature.rst @@ -0,0 +1 @@ +:class:`MSSQL` connection now uses Microsoft SQL Server JDBC driver ``12.6.1``, upgraded from ``12.2.0``, and supports passing custom versions: ``MSSQL.get_packages(java_version=..., package_version=...)``. diff --git a/onetl/connection/db_connection/mssql/connection.py b/onetl/connection/db_connection/mssql/connection.py index 1756146fa..8fa91b47a 100644 --- a/onetl/connection/db_connection/mssql/connection.py +++ b/onetl/connection/db_connection/mssql/connection.py @@ -149,43 +149,54 @@ class MSSQL(JDBCConnection): @classmethod def get_packages( cls, - java_version: str | Version | None = None, + java_version: str | None = None, + package_version: str | None = None, ) -> list[str]: """ - Get package names to be downloaded by Spark. |support_hooks| + Get package names to be downloaded by Spark. Allows specifying custom JDBC driver versions for MSSQL. |support_hooks| Parameters ---------- - java_version : str, default ``8`` - Java major version. + java_version : str, optional + Java major version, defaults to ``8``. Must be ``8`` or ``11``. + package_version : str, optional + Specifies the version of the MSSQL JDBC driver to use. Defaults to ``12.6.1.``. Examples -------- - .. code:: python from onetl.connection import MSSQL MSSQL.get_packages() - MSSQL.get_packages(java_version="8") + # specify Java and package versions + MSSQL.get_packages(java_version="8", package_version="12.6.1.jre11") """ - if java_version is None: - java_version = "8" + default_java_version = "8" + default_package_version = "12.6.1" - java_ver = Version(java_version) + java_ver = Version(java_version or default_java_version) if java_ver.major < 8: raise ValueError(f"Java version must be at least 8, got {java_ver}") jre_ver = "8" if java_ver.major < 11 else "11" - return [f"com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre{jre_ver}"] + full_package_version = Version(package_version or default_package_version).min_digits(3) + + # check if a JRE suffix is already included + if ".jre" in str(full_package_version): + jdbc_version = full_package_version + else: + jdbc_version = Version(f"{full_package_version}.jre{jre_ver}") + + return [f"com.microsoft.sqlserver:mssql-jdbc:{jdbc_version}"] @classproperty def package(cls) -> str: """Get package name to be downloaded by Spark.""" msg = "`MSSQL.package` will be removed in 1.0.0, use `MSSQL.get_packages()` instead" warnings.warn(msg, UserWarning, stacklevel=3) - return "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8" + return "com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8" @property def jdbc_url(self) -> str: diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py index 7b0328ca9..51a548166 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py @@ -14,31 +14,48 @@ def test_mssql_class_attributes(): def test_mssql_package(): warning_msg = re.escape("will be removed in 1.0.0, use `MSSQL.get_packages()` instead") with pytest.warns(UserWarning, match=warning_msg): - assert MSSQL.package == "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8" + assert MSSQL.package == "com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8" -def test_mssql_get_packages_no_input(): - assert MSSQL.get_packages() == ["com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"] - - -@pytest.mark.parametrize("java_version", ["7", "6"]) -def test_mssql_get_packages_java_version_not_supported(java_version): - with pytest.raises(ValueError, match=f"Java version must be at least 8, got {java_version}"): - MSSQL.get_packages(java_version=java_version) +@pytest.mark.parametrize( + "java_version, package_version, expected_packages", + [ + (None, None, ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8"]), + ("8", None, ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8"]), + ("9", None, ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8"]), + ("11", None, ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11"]), + ("20", None, ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11"]), + ("8", "12.6.1.jre8", ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8"]), + ("11", "12.6.1.jre11", ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11"]), + ("11", "12.7.0.jre11-preview", ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre11-preview"]), + ("8", "12.7.0.jre8-preview", ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8-preview"]), + ("8", "12.6.1", ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8"]), + ("11", "12.6.1", ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11"]), + ], +) +def test_mssql_get_packages(java_version, package_version, expected_packages): + assert MSSQL.get_packages(java_version=java_version, package_version=package_version) == expected_packages @pytest.mark.parametrize( - "java_version, package", + "package_version", [ - ("8", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"), - ("9", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"), - ("11", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11"), - ("17", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11"), - ("20", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11"), + "12.7", + "abc", ], ) -def test_mssql_get_packages(java_version, package): - assert MSSQL.get_packages(java_version=java_version) == [package] +def test_mssql_get_packages_invalid_version(package_version): + with pytest.raises( + ValueError, + match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 3\).", + ): + MSSQL.get_packages(package_version=package_version) + + +@pytest.mark.parametrize("java_version", ["7", "6"]) +def test_mssql_get_packages_java_version_not_supported(java_version): + with pytest.raises(ValueError, match=f"Java version must be at least 8, got {java_version}"): + MSSQL.get_packages(java_version=java_version) def test_mssql_missing_package(spark_no_packages):