feat(eap-api): support conditional aggregation for TimeSeriesEndpoint
Rachel Chen authored and Rachel Chen committed Feb 20, 2025
1 parent 671f1f2 commit 11057ce
Showing 3 changed files with 140 additions and 4 deletions.
30 changes: 30 additions & 0 deletions snuba/web/rpc/v1/endpoint_time_series.py
@@ -2,6 +2,9 @@
import uuid
from typing import Type

from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import (
    AttributeConditionalAggregation,
)
from sentry_protos.snuba.v1.endpoint_time_series_pb2 import (
    Expression,
    TimeSeriesRequest,
@@ -89,6 +92,32 @@ def _convert_aggregations_to_expressions(
    return request


def _convert_to_conditional_aggregation(in_msg: TimeSeriesRequest) -> None:
    def _add_conditional_aggregation(
        input: Expression,
    ) -> None:
        aggregation = input.aggregation
        input.ClearField("aggregation")
        input.conditional_aggregation.CopyFrom(
            AttributeConditionalAggregation(
                aggregate=aggregation.aggregate,
                key=aggregation.key,
                label=aggregation.label,
                extrapolation_mode=aggregation.extrapolation_mode,
            )
        )

    def _convert(input: Expression) -> None:
        if input.HasField("aggregation"):
            _add_conditional_aggregation(input)
        if input.HasField("formula"):
            _convert(input.formula.left)
            _convert(input.formula.right)

    for expression in in_msg.expressions:
        _convert(expression)


class EndpointTimeSeries(RPCEndpoint[TimeSeriesRequest, TimeSeriesResponse]):
    @classmethod
    def version(cls) -> str:
@@ -122,5 +151,6 @@ def _execute(self, in_msg: TimeSeriesRequest) -> TimeSeriesResponse:
"This endpoint requires meta.trace_item_type to be set (are you requesting spans? logs?)"
)
in_msg = _convert_aggregations_to_expressions(in_msg)
_convert_to_conditional_aggregation(in_msg)
resolver = self.get_resolver(in_msg.meta.trace_item_type)
return resolver.resolve(in_msg)
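
The new helper rewrites every plain aggregation expression into the conditional_aggregation oneof member, copying aggregate, key, label, and extrapolation mode and attaching no filter, so downstream resolvers only have to handle one case. A minimal sketch of that behaviour, assuming AttributeAggregation, AttributeKey, and Function are importable from sentry_protos.snuba.v1.trace_item_attribute_pb2 (that module is not shown in this diff):

# Minimal sketch (not part of the commit): the import path for
# AttributeAggregation, AttributeKey, and Function is an assumption.
from sentry_protos.snuba.v1.endpoint_time_series_pb2 import Expression, TimeSeriesRequest
from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
    AttributeAggregation,
    AttributeKey,
    Function,
)

from snuba.web.rpc.v1.endpoint_time_series import _convert_to_conditional_aggregation

request = TimeSeriesRequest(
    expressions=[
        Expression(
            aggregation=AttributeAggregation(
                aggregate=Function.FUNCTION_SUM,
                key=AttributeKey(type=AttributeKey.TYPE_FLOAT, name="my_metric"),
                label="sum(my_metric)",
            )
        )
    ]
)

_convert_to_conditional_aggregation(request)

expr = request.expressions[0]
# The aggregation now lives in the conditional_aggregation field; with no
# filter set it behaves exactly like the original unconditional aggregate.
assert expr.WhichOneof("expression") == "conditional_aggregation"
assert expr.conditional_aggregation.label == "sum(my_metric)"
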
@@ -190,9 +190,9 @@ def _get_reliability_context_columns(

    aggregates = []
    for e in expressions:
-       if e.WhichOneof("expression") == "aggregation":
+       if e.WhichOneof("expression") == "conditional_aggregation":
            # ignore formulas
-           aggregates.append(e.aggregation)
+           aggregates.append(e.conditional_aggregation)

    for aggregation in aggregates:
        if (
@@ -225,8 +225,8 @@ def _get_reliability_context_columns(

def _proto_expression_to_ast_expression(expr: ProtoExpression) -> Expression:
    match expr.WhichOneof("expression"):
-       case "aggregation":
-           return aggregation_to_expression(expr.aggregation)
+       case "conditional_aggregation":
+           return aggregation_to_expression(expr.conditional_aggregation)
        case "formula":
            formula_expr = OP_TO_EXPR[expr.formula.op](
                _proto_expression_to_ast_expression(expr.formula.left),
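
Both call sites above dispatch on the Expression oneof via WhichOneof("expression"), so a converted request now reports "conditional_aggregation" where it previously reported "aggregation"; plain aggregations that skip the endpoint-level conversion would fall through unmatched. A small sketch of that dispatch, reusing only messages already imported in this commit:

# Sketch only: demonstrates the oneof dispatch the resolver relies on.
from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import (
    AttributeConditionalAggregation,
)
from sentry_protos.snuba.v1.endpoint_time_series_pb2 import Expression

expr = Expression(
    conditional_aggregation=AttributeConditionalAggregation(label="avg")
)

match expr.WhichOneof("expression"):
    case "conditional_aggregation":
        # The branch converted requests take after this commit.
        print("aggregate:", expr.conditional_aggregation.label)
    case "formula":
        print("formula")
    case _:
        print("no expression set")
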
@@ -7,6 +7,9 @@
import pytest
from clickhouse_driver.errors import ServerException
from google.protobuf.timestamp_pb2 import Timestamp
from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import (
    AttributeConditionalAggregation,
)
from sentry_protos.snuba.v1.endpoint_time_series_pb2 import (
    DataPoint,
    Expression,
@@ -237,6 +240,109 @@ def test_fails_for_logs(self) -> None:
        error.ParseFromString(response.data)
        assert response.status_code == 400, (error.message, error.details)

    def test_rachel(self) -> None:
        # store a test metric with a value of 1, for every even second of one hour
        granularity_secs = 300
        query_duration = 60 * 30
        store_spans_timeseries(
            BASE_TIME,
            1,
            3600,
            metrics=[DummyMetric("test_metric", get_value=lambda x: int(x % 2 == 0))],
        )

        test_metric_attribute_key = AttributeKey(
            type=AttributeKey.TYPE_FLOAT, name="test_metric"
        )
        test_metric_is_one_filter = TraceItemFilter(
            comparison_filter=ComparisonFilter(
                key=test_metric_attribute_key,
                op=ComparisonFilter.OP_EQUALS,
                value=AttributeValue(val_int=1),
            )
        )

        message = TimeSeriesRequest(
            meta=RequestMeta(
                project_ids=[1, 2, 3],
                organization_id=1,
                cogs_category="something",
                referrer="something",
                start_timestamp=Timestamp(seconds=int(BASE_TIME.timestamp())),
                end_timestamp=Timestamp(
                    seconds=int(BASE_TIME.timestamp() + query_duration)
                ),
                trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
            ),
            expressions=[
                Expression(
                    formula=Expression.BinaryFormula(
                        op=Expression.BinaryFormula.OP_ADD,
                        left=Expression(
                            conditional_aggregation=AttributeConditionalAggregation(
                                aggregate=Function.FUNCTION_SUM,
                                key=test_metric_attribute_key,
                                label="sum",
                                extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
                                filter=test_metric_is_one_filter,
                            )
                        ),
                        right=Expression(
                            conditional_aggregation=AttributeConditionalAggregation(
                                aggregate=Function.FUNCTION_AVG,
                                key=test_metric_attribute_key,
                                label="avg",
                                extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
                                filter=test_metric_is_one_filter,
                            )
                        ),
                    ),
                    label="sum + avg",
                ),
            ],
            granularity_secs=granularity_secs,
        )
        response = EndpointTimeSeries().execute(message)
        expected_buckets = [
            Timestamp(seconds=int(BASE_TIME.timestamp()) + secs)
            for secs in range(0, query_duration, granularity_secs)
        ]

        expected_avg_timeseries = TimeSeries(
            label="avg",
            buckets=expected_buckets,
            data_points=[
                DataPoint(data=1, data_present=True, sample_count=150)
                for _ in range(len(expected_buckets))
            ],
        )
        expected_sum_timeseries = TimeSeries(
            label="sum",
            buckets=expected_buckets,
            data_points=[
                DataPoint(data=150, data_present=True)
                for _ in range(len(expected_buckets))
            ],
        )
        expected_formula_timeseries = TimeSeries(
            label="sum + avg",
            buckets=expected_buckets,
            data_points=[
                DataPoint(
                    data=sum_datapoint.data + avg_datapoint.data,
                    data_present=True,
                    sample_count=sum_datapoint.sample_count,
                )
                for sum_datapoint, avg_datapoint in zip(
                    expected_sum_timeseries.data_points,
                    expected_avg_timeseries.data_points,
                )
            ],
        )
        assert sorted(response.result_timeseries, key=lambda x: x.label) == [
            expected_formula_timeseries
        ]

    def test_sum(self) -> None:
        # store a test metric with a value of 1, every second of one hour
        granularity_secs = 300
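
The expected values in test_rachel follow from the fixture: the stored metric is 1 on even seconds and 0 on odd seconds, and both conditional aggregations filter on test_metric == 1, so each 300-second bucket admits 150 rows. A worked check of that arithmetic (a sketch, not part of the test file):

# Per-bucket expectations for test_rachel, derived from the fixture above.
granularity_secs = 300
matching_rows_per_bucket = granularity_secs // 2         # 150 even-second samples pass the filter
expected_sum = 1 * matching_rows_per_bucket              # 150
expected_avg = expected_sum / matching_rows_per_bucket   # 1.0
expected_formula = expected_sum + expected_avg           # 151.0 for the "sum + avg" series
assert (expected_sum, expected_avg, expected_formula) == (150, 1.0, 151.0)
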
