From bb60ce0dd52c370e3a1253ffc4333af6ae359902 Mon Sep 17 00:00:00 2001 From: Linchin Date: Tue, 6 Aug 2024 16:53:33 -0700 Subject: [PATCH] aggregation related type annotation --- google/cloud/firestore_v1/aggregation.py | 6 +++--- google/cloud/firestore_v1/base_aggregation.py | 19 +++++++++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py index 183c869a7..d3182d21a 100644 --- a/google/cloud/firestore_v1/aggregation.py +++ b/google/cloud/firestore_v1/aggregation.py @@ -20,7 +20,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union from google.api_core import exceptions, gapic_v1 from google.api_core import retry as retries @@ -133,7 +133,7 @@ def _make_stream( retry: Optional[retries.Retry] = gapic_v1.method.DEFAULT, timeout: Optional[float] = None, explain_options: Optional[ExplainOptions] = None, - ) -> Generator[Tuple[Optional[DocumentSnapshot], Optional[ExplainMetrics]]]: + ) -> Generator[Tuple[Optional[List[AggregationResult]], Optional[ExplainMetrics]]]: """Internal method for stream(). Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and then returns a generator @@ -158,7 +158,7 @@ def _make_stream( explain_metrics will be available on the returned generator. Yields: - Tuple[Optional[DocumentSnapshot], Optional[ExplainMetrics]]: + Tuple[Optional[List[AggregationResult]], Optional[ExplainMetrics]]: The result of aggregations of this query. """ diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 277addf37..e414a153e 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -27,9 +27,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, Coroutine, - Generator, List, Optional, Tuple, @@ -49,7 +47,12 @@ # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1 import transaction + from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.query_profile import ExplainOptions + from google.cloud.firestore_v1.stream_generator import ( + QueryResultsList, + StreamGenerator, + ) class AggregationResult(object): @@ -236,7 +239,10 @@ def get( timeout: float | None = None, *, explain_options: Optional[ExplainOptions] = None, - ) -> List[AggregationResult] | Coroutine[Any, Any, List[AggregationResult]]: + ) -> ( + QueryResultsList[AggregationResult] + | Coroutine[Any, Any, List[AggregationResult]] + ): """Runs the aggregation query. This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. @@ -258,7 +264,8 @@ def get( explain_metrics will be available on the returned generator. Returns: - list: The aggregation query results + QueryResultsList[AggregationResult] | Coroutine[Any, Any, List[AggregationResult]]: + The aggregation query results. """ @@ -271,8 +278,8 @@ def stream( *, explain_options: Optional[ExplainOptions] = None, ) -> ( - Generator[List[AggregationResult], Any, None] - | AsyncGenerator[List[AggregationResult], None] + StreamGenerator[List[AggregationResult], Any, None] + | AsyncStreamGenerator[List[AggregationResult], Any, None] ): """Runs the aggregation query.