diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py index bf0297397a7a1..64390d9e351ab 100644 --- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py @@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast from bson import json_util +from pymongo.command_cursor import CommandCursor +from pymongo.cursor import Cursor from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -96,7 +98,7 @@ def execute(self, context: Context): # Grab collection and execute query according to whether or not it is a pipeline if self.is_pipeline: - results = MongoHook(self.mongo_conn_id).aggregate( + results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate( mongo_collection=self.mongo_collection, aggregate_query=cast(list, self.mongo_query), mongo_db=self.mongo_db, @@ -109,6 +111,7 @@ def execute(self, context: Context): query=cast(dict, self.mongo_query), projection=self.mongo_projection, mongo_db=self.mongo_db, + find_one=False, ) # Performs transform then stringifies the docs results into json format diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py index b37a76dcf3800..96b5a1eb39a23 100644 --- a/airflow/providers/mongo/hooks/mongo.py +++ b/airflow/providers/mongo/hooks/mongo.py @@ -20,12 +20,14 @@ from ssl import CERT_NONE from types import TracebackType +from typing import Any, overload from urllib.parse import quote_plus, urlunsplit import pymongo from pymongo import MongoClient, ReplaceOne from airflow.hooks.base import BaseHook +from airflow.typing_compat import Literal class MongoHook(BaseHook): @@ -56,7 +58,7 @@ def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: self.mongo_conn_id = conn_id self.connection = self.get_connection(conn_id) self.extras = self.connection.extra_dejson.copy() - self.client = None + self.client: MongoClient | None = None self.uri = self._create_uri() def __enter__(self): @@ -134,15 +136,39 @@ def aggregate( return collection.aggregate(aggregate_query, **kwargs) + @overload def find( self, mongo_collection: str, query: dict, - find_one: bool = False, + find_one: Literal[False], mongo_db: str | None = None, projection: list | dict | None = None, **kwargs, ) -> pymongo.cursor.Cursor: + ... + + @overload + def find( + self, + mongo_collection: str, + query: dict, + find_one: Literal[True], + mongo_db: str | None = None, + projection: list | dict | None = None, + **kwargs, + ) -> Any | None: + ... + + def find( + self, + mongo_collection: str, + query: dict, + find_one: bool = False, + mongo_db: str | None = None, + projection: list | dict | None = None, + **kwargs, + ) -> pymongo.cursor.Cursor | Any | None: """ Runs a mongo find query and returns the results https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py index f2bd53318c3e8..ed07fdfab23bb 100644 --- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py @@ -96,7 +96,11 @@ def test_execute(self, mock_s3_hook, mock_mongo_hook): operator.execute(None) mock_mongo_hook.return_value.find.assert_called_once_with( - mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None + mongo_collection=MONGO_COLLECTION, + query=MONGO_QUERY, + find_one=False, + mongo_db=None, + projection=None, ) op_stringify = self.mock_operator._stringify @@ -119,7 +123,11 @@ def test_execute_compress(self, mock_s3_hook, mock_mongo_hook): operator.execute(None) mock_mongo_hook.return_value.find.assert_called_once_with( - mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None + mongo_collection=MONGO_COLLECTION, + query=MONGO_QUERY, + find_one=False, + mongo_db=None, + projection=None, ) op_stringify = self.mock_operator._stringify