From a1c57b7068a6fb5eb526d6b4edbce6a16e178070 Mon Sep 17 00:00:00 2001 From: Jib Date: Tue, 14 Nov 2023 15:35:06 -0500 Subject: [PATCH] MOTOR-1209: Motor's DriverInfo should not be overwritten (#233) --- motor/core.py | 29 +++++++++++++++++--- test/tornado_tests/test_motor_client.py | 35 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/motor/core.py b/motor/core.py index 3f2133c1..7f22af5a 100644 --- a/motor/core.py +++ b/motor/core.py @@ -140,9 +140,25 @@ def __init__(self, *args, **kwargs): self._io_loop = io_loop kwargs.setdefault("connect", False) - kwargs.setdefault( - "driver", DriverInfo("Motor", motor_version, self._framework.platform_info()) - ) + + driver_info = DriverInfo("Motor", motor_version, self._framework.platform_info()) + + if kwargs.get("driver"): + provided_info = kwargs.get("driver") + if not isinstance(provided_info, DriverInfo): + raise TypeError( + f"Incorrect type for `driver` {type(provided_info)};" + " expected value of type pymongo.driver_info.DriverInfo" + ) + added_version = f"|{provided_info.version}" if provided_info.version else "" + added_platform = f"|{provided_info.platform}" if provided_info.platform else "" + driver_info = DriverInfo( + f"{driver_info.name}|{provided_info.name}", + f"{driver_info.version}{added_version}", + f"{driver_info.platform}{added_platform}", + ) + + kwargs["driver"] = driver_info delegate = self.__delegate_class__(*args, **kwargs) super().__init__(delegate) @@ -1650,7 +1666,12 @@ def to_list(self, length): else: the_list = [] self._framework.add_future( - self.get_io_loop(), self._get_more(), self._to_list, length, the_list, future + self.get_io_loop(), + self._get_more(), + self._to_list, + length, + the_list, + future, ) return future diff --git a/test/tornado_tests/test_motor_client.py b/test/tornado_tests/test_motor_client.py index 8529ed45..8964c7c1 100644 --- a/test/tornado_tests/test_motor_client.py +++ b/test/tornado_tests/test_motor_client.py @@ -27,6 +27,7 @@ from bson import CodecOptions from mockupdb import OpQuery from pymongo import CursorType, ReadPreference, WriteConcern +from pymongo.driver_info import DriverInfo from pymongo.errors import ConnectionFailure, OperationFailure from tornado import gen from tornado.testing import gen_test @@ -305,6 +306,40 @@ async def test_handshake(self): except Exception: pass + @gen_test + async def test_driver_info(self): + server = self.server() + driver_info = DriverInfo(name="Foo", version="1.1.1", platform="FooPlat") + client = motor.MotorClient(server.uri, driver=driver_info) + + # Trigger connection. + future = client.db.command("ping") + handshake = await self.run_thread(server.receives, "ismaster") + meta = handshake.doc["client"] + self.assertEqual(f"PyMongo|Motor|{driver_info.name}", meta["driver"]["name"]) + self.assertIn("Tornado", meta["platform"]) + self.assertIn(f"|{driver_info.platform}", meta["platform"]) + self.assertTrue( + meta["driver"]["version"].endswith(f"{motor.version}|{driver_info.version}"), + "Version in handshake [%s] doesn't end with MotorVersion|Test version [%s]" + % (meta["driver"]["version"], f"{motor.version}|{driver_info.version}"), + ) + + handshake.ok() + server.stop() + client.close() + try: + await future + except Exception: + pass + + def test_incorrect_driver_info(self): + with self.assertRaises( + TypeError, + msg="Allowed invalid type parameter str, driver should only be of DriverInfo", + ): + motor.MotorClient(driver="string") + if __name__ == "__main__": unittest.main()