Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add future-compatible mongo Hook typing #31289

Merged
merged 1 commit into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/transfers/mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast

from bson import json_util
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
Expand Down Expand Up @@ -96,7 +98,7 @@ def execute(self, context: Context):

# Grab collection and execute query according to whether or not it is a pipeline
if self.is_pipeline:
results = MongoHook(self.mongo_conn_id).aggregate(
results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate(
mongo_collection=self.mongo_collection,
aggregate_query=cast(list, self.mongo_query),
mongo_db=self.mongo_db,
Expand All @@ -109,6 +111,7 @@ def execute(self, context: Context):
query=cast(dict, self.mongo_query),
projection=self.mongo_projection,
mongo_db=self.mongo_db,
find_one=False,
)

# Performs transform then stringifies the docs results into json format
Expand Down
30 changes: 28 additions & 2 deletions airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from ssl import CERT_NONE
from types import TracebackType
from typing import Any, overload
from urllib.parse import quote_plus, urlunsplit

import pymongo
from pymongo import MongoClient, ReplaceOne

from airflow.hooks.base import BaseHook
from airflow.typing_compat import Literal


class MongoHook(BaseHook):
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
self.mongo_conn_id = conn_id
self.connection = self.get_connection(conn_id)
self.extras = self.connection.extra_dejson.copy()
self.client = None
self.client: MongoClient | None = None
self.uri = self._create_uri()

def __enter__(self):
Expand Down Expand Up @@ -134,15 +136,39 @@ def aggregate(

return collection.aggregate(aggregate_query, **kwargs)

@overload
def find(
self,
mongo_collection: str,
query: dict,
find_one: bool = False,
find_one: Literal[False],
mongo_db: str | None = None,
projection: list | dict | None = None,
**kwargs,
) -> pymongo.cursor.Cursor:
...

@overload
def find(
self,
mongo_collection: str,
query: dict,
find_one: Literal[True],
mongo_db: str | None = None,
projection: list | dict | None = None,
**kwargs,
) -> Any | None:
...

def find(
self,
mongo_collection: str,
query: dict,
find_one: bool = False,
mongo_db: str | None = None,
projection: list | dict | None = None,
**kwargs,
) -> pymongo.cursor.Cursor | Any | None:
"""
Runs a mongo find query and returns the results
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
Expand Down
12 changes: 10 additions & 2 deletions tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def test_execute(self, mock_s3_hook, mock_mongo_hook):
operator.execute(None)

mock_mongo_hook.return_value.find.assert_called_once_with(
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
mongo_collection=MONGO_COLLECTION,
query=MONGO_QUERY,
find_one=False,
mongo_db=None,
projection=None,
)

op_stringify = self.mock_operator._stringify
Expand All @@ -119,7 +123,11 @@ def test_execute_compress(self, mock_s3_hook, mock_mongo_hook):
operator.execute(None)

mock_mongo_hook.return_value.find.assert_called_once_with(
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
mongo_collection=MONGO_COLLECTION,
query=MONGO_QUERY,
find_one=False,
mongo_db=None,
projection=None,
)

op_stringify = self.mock_operator._stringify
Expand Down