Skip to content

Commit

Permalink
[DOP-13852] - upgrade MSSQL packages
Browse files Browse the repository at this point in the history
  • Loading branch information
maxim-lixakov committed Apr 18, 2024
1 parent 8e0dc89 commit 4a5b97b
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/254.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:class:`MSSQL` connection now uses Microsoft SQL Server JDBC driver ``12.7.0``, upgraded from ``12.2.0``, and supports passing custom versions: ``MSSQL.get_packages(java_version=..., package_version=...)``.
33 changes: 22 additions & 11 deletions onetl/connection/db_connection/mssql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,43 +149,54 @@ class MSSQL(JDBCConnection):
@classmethod
def get_packages(
cls,
java_version: str | Version | None = None,
java_version: str | None = None,
package_version: str | None = None,
) -> list[str]:
"""
Get package names to be downloaded by Spark. |support_hooks|
Get package names to be downloaded by Spark. Allows specifying custom JDBC driver versions for MSSQL. |support_hooks|
Parameters
----------
java_version : str, default ``8``
Java major version.
java_version : str, optional
Java major version, defaults to ``8``. Must be ``8`` or ``11``.
package_version : str, optional
Specifies the version of the MSSQL JDBC driver to use. Defaults to ``12.7.0``.
Examples
--------
.. code:: python
from onetl.connection import MSSQL
MSSQL.get_packages()
MSSQL.get_packages(java_version="8")
# specify Java and package versions
MSSQL.get_packages(java_version="8", package_version="12.6.1.jre11")
"""
if java_version is None:
java_version = "8"
default_java_version = "8"
default_package_version = "12.7.0"

java_ver = Version(java_version)
java_ver = Version(java_version or default_java_version)
if java_ver.major < 8:
raise ValueError(f"Java version must be at least 8, got {java_ver}")

jre_ver = "8" if java_ver.major < 11 else "11"
return [f"com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre{jre_ver}"]
full_package_version = Version(package_version or default_package_version).min_digits(3)

# check if a JRE suffix is already included
if ".jre" in str(full_package_version):
jdbc_version = full_package_version
else:
jdbc_version = Version(f"{full_package_version}.jre{jre_ver}")

return [f"com.microsoft.sqlserver:mssql-jdbc:{jdbc_version}"]

@classproperty
def package(cls) -> str:
"""Get package name to be downloaded by Spark."""
msg = "`MSSQL.package` will be removed in 1.0.0, use `MSSQL.get_packages()` instead"
warnings.warn(msg, UserWarning, stacklevel=3)
return "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"
return "com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8"

@property
def jdbc_url(self) -> str:
Expand Down
49 changes: 32 additions & 17 deletions tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,46 @@ def test_mssql_class_attributes():
def test_mssql_package():
warning_msg = re.escape("will be removed in 1.0.0, use `MSSQL.get_packages()` instead")
with pytest.warns(UserWarning, match=warning_msg):
assert MSSQL.package == "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"
assert MSSQL.package == "com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8"


def test_mssql_get_packages_no_input():
assert MSSQL.get_packages() == ["com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"]


@pytest.mark.parametrize("java_version", ["7", "6"])
def test_mssql_get_packages_java_version_not_supported(java_version):
with pytest.raises(ValueError, match=f"Java version must be at least 8, got {java_version}"):
MSSQL.get_packages(java_version=java_version)
@pytest.mark.parametrize(
"java_version, package_version, expected_packages",
[
(None, None, ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8"]),
("8", None, ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8"]),
("11", None, ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre11"]),
("8", "12.6.1.jre8", ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre8"]),
("11", "12.6.1.jre11", ["com.microsoft.sqlserver:mssql-jdbc:12.6.1.jre11"]),
("11", "12.7.0.jre11-preview", ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre11-preview"]),
("8", "12.7.0.jre8-preview", ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8-preview"]),
("8", "12.7.0", ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre8"]),
("11", "12.7.0", ["com.microsoft.sqlserver:mssql-jdbc:12.7.0.jre11"]),
],
)
def test_mssql_get_packages(java_version, package_version, expected_packages):
assert MSSQL.get_packages(java_version=java_version, package_version=package_version) == expected_packages


@pytest.mark.parametrize(
"java_version, package",
"package_version",
[
("8", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"),
("9", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre8"),
("11", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11"),
("17", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11"),
("20", "com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11"),
"12.7",
"abc",
],
)
def test_mssql_get_packages(java_version, package):
assert MSSQL.get_packages(java_version=java_version) == [package]
def test_mssql_get_packages_invalid_version(package_version):
with pytest.raises(
ValueError,
match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 3\).",
):
MSSQL.get_packages(package_version=package_version)


@pytest.mark.parametrize("java_version", ["7", "6"])
def test_mssql_get_packages_java_version_not_supported(java_version):
with pytest.raises(ValueError, match=f"Java version must be at least 8, got {java_version}"):
MSSQL.get_packages(java_version=java_version)


def test_mssql_missing_package(spark_no_packages):
Expand Down

0 comments on commit 4a5b97b

Please sign in to comment.