Skip to content

Commit

Permalink
feat: Add support for Sum and Avg aggregation query
Browse files Browse the repository at this point in the history
Add .sum() and .avg() functions to aggregation

Refactor limit to be passed in to the nested query's limit

Unit tests
  • Loading branch information
Mariatta committed May 9, 2023
1 parent 8036071 commit db1f820
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 13 deletions.
85 changes: 82 additions & 3 deletions google/cloud/datastore/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class BaseAggregation(ABC):
Base class representing an Aggregation operation in Datastore
"""

def __init__(self, alias=None):
self.alias = alias

@abc.abstractmethod
def _to_pb(self):
"""
Expand All @@ -59,7 +62,7 @@ class CountAggregation(BaseAggregation):
"""

def __init__(self, alias=None):
self.alias = alias
super(CountAggregation, self).__init__(alias=alias)

def _to_pb(self):
"""
Expand All @@ -71,6 +74,61 @@ def _to_pb(self):
return aggregation_pb


class SumAggregation(BaseAggregation):
"""
Representation of a "Sum" aggregation query.
:type property_ref: str
:param property_ref: The property_ref for the aggregation.
:type value: int
:param value: The resulting value from the aggregation.
"""

def __init__(self, property_ref, alias=None):
self.property_ref = property_ref
super(SumAggregation, self).__init__(alias=alias)

def _to_pb(self):
"""
Convert this instance to the protobuf representation
"""
aggregation_pb = query_pb2.AggregationQuery.Aggregation()
aggregation_pb.alias = self.alias
aggregation_pb.sum_ = query_pb2.AggregationQuery.Aggregation.Sum()
aggregation_pb.sum_.property.name = self.property_ref
aggregation_pb.alias = self.alias
return aggregation_pb


class AvgAggregation(BaseAggregation):
"""
Representation of a "Avg" aggregation query.
:type property_ref: str
:param property_ref: The property_ref for the aggregation.
:type value: int
:param value: The resulting value from the aggregation.
"""

def __init__(self, property_ref, alias=None):
self.property_ref = property_ref
super(AvgAggregation, self).__init__(alias=alias)

def _to_pb(self):
"""
Convert this instance to the protobuf representation
"""
aggregation_pb = query_pb2.AggregationQuery.Aggregation()
aggregation_pb.avg = query_pb2.AggregationQuery.Aggregation.Avg()
aggregation_pb.avg.property.name = self.property_ref
aggregation_pb.alias = self.alias
return aggregation_pb


class AggregationResult(object):
"""
A class representing result from Aggregation Query
Expand Down Expand Up @@ -154,6 +212,28 @@ def count(self, alias=None):
self._aggregations.append(count_aggregation)
return self

def sum(self, property_ref, alias=None):
"""
Adds a sum over the nested query
:type property_ref: str
:param property_ref: The property_ref for the sum
"""
sum_aggregation = SumAggregation(property_ref=property_ref, alias=alias)
self._aggregations.append(sum_aggregation)
return self

def avg(self, property_ref, alias=None):
"""
Adds a avg over the nested query
:type property_ref: str
:param property_ref: The property_ref for the sum
"""
avg_aggregation = AvgAggregation(property_ref=property_ref, alias=alias)
self._aggregations.append(avg_aggregation)
return self

def add_aggregation(self, aggregation):
"""
Adds an aggregation operation to the nested query
Expand Down Expand Up @@ -327,8 +407,7 @@ def _build_protobuf(self):
"""
pb = self._aggregation_query._to_pb()
if self._limit is not None and self._limit > 0:
for aggregation in pb.aggregations:
aggregation.count.up_to = self._limit
pb.nested_query.limit = self._limit
return pb

def _process_query_results(self, response_pb):
Expand Down
85 changes: 78 additions & 7 deletions tests/system/test_aggregation_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,54 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
assert r.value > 0


def test_sum_query_default(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.sum("person")
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "property_1"
assert r.value == 8


def test_sum_query_with_alias(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.sum("person", alias="sum_person")
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "sum_person"
assert r.value > 0


def test_avg_query_default(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.avg("person")
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "property_1"
assert r.value == 8


def test_avg_query_with_alias(aggregation_query_client, nested_query):
query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.avg("person", alias="avg_person")
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "avg_person"
assert r.value > 0


def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
query = nested_query

Expand Down Expand Up @@ -121,41 +169,60 @@ def test_aggregation_query_multiple_aggregations(
aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
aggregation_query.count(alias="all")
aggregation_query.sum("person", alias="sum_person")
aggregation_query.avg("person", alias="avg_person")
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias in ["all", "total"]
assert r.alias in ["all", "total", "sum_person", "avg_person"]
assert r.value > 0


def test_aggregation_query_add_aggregation(aggregation_query_client, nested_query):
from google.cloud.datastore.aggregation import CountAggregation
from google.cloud.datastore.aggregation import SumAggregation
from google.cloud.datastore.aggregation import AvgAggregation

query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
count_aggregation = CountAggregation(alias="total")
aggregation_query.add_aggregation(count_aggregation)

sum_aggregation = SumAggregation("person", alias="sum_person")
aggregation_query.add_aggregation(sum_aggregation)

avg_aggregation = AvgAggregation("person", alias="avg_person")
aggregation_query.add_aggregation(avg_aggregation)

result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias == "total"
assert r.alias in ["total", "sum_person", "avg_person"]
assert r.value > 0


def test_aggregation_query_add_aggregations(aggregation_query_client, nested_query):
from google.cloud.datastore.aggregation import CountAggregation
from google.cloud.datastore.aggregation import (
CountAggregation,
SumAggregation,
AvgAggregation,
)

query = nested_query

aggregation_query = aggregation_query_client.aggregation_query(query)
count_aggregation_1 = CountAggregation(alias="total")
count_aggregation_2 = CountAggregation(alias="all")
aggregation_query.add_aggregations([count_aggregation_1, count_aggregation_2])
sum_aggregation = SumAggregation("person", alias="sum_person")
avg_aggregation = AvgAggregation("person", alias="avg_person")
aggregation_query.add_aggregations(
[count_aggregation_1, count_aggregation_2, sum_aggregation, avg_aggregation]
)
result = _do_fetch(aggregation_query)
assert len(result) == 1
for r in result[0]:
assert r.alias in ["total", "all"]
assert r.alias in ["total", "all", "sum_person", "avg_person"]
assert r.value > 0


Expand Down Expand Up @@ -202,11 +269,13 @@ def test_aggregation_query_with_nested_query_filtered(

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
aggregation_query.sum("person", alias="sum_person")
aggregation_query.avg("person", alias="avg_person")
result = _do_fetch(aggregation_query)
assert len(result) == 1

for r in result[0]:
assert r.alias == "total"
assert r.alias in ["total", "sum_person", "avg_person"]
assert r.value == 6


Expand All @@ -226,9 +295,11 @@ def test_aggregation_query_with_nested_query_multiple_filters(

aggregation_query = aggregation_query_client.aggregation_query(query)
aggregation_query.count(alias="total")
aggregation_query.sum("person", alias="sum_person")
aggregation_query.avg("person", alias="avg_person")
result = _do_fetch(aggregation_query)
assert len(result) == 1

for r in result[0]:
assert r.alias == "total"
assert r.alias in ["total", "sum_person", "avg_person"]
assert r.value == 4
Loading

0 comments on commit db1f820

Please sign in to comment.