diff --git a/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go b/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go index 58946f461..8afca2725 100644 --- a/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go +++ b/pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go @@ -190,15 +190,15 @@ func (a *reduceAggregator) initializeCurrentState(lggr logger.Logger, previousOu if previousOutcome != nil { pb := &pb.Map{} - proto.Unmarshal(previousOutcome.Metadata, pb) - mv, err := values.FromMapValueProto(pb) + err := proto.Unmarshal(previousOutcome.Metadata, pb) if err != nil { - return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %s", err.Error()) + return nil, fmt.Errorf("initializeCurrentState Unmarshal error: %w", err) } - err = mv.UnwrapTo(currentState) + mv, err := values.FromMapValueProto(pb) if err != nil { - return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %s", err.Error()) + return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %w", err) } + currentState = mv.Underlying } zeroValue := values.NewDecimal(decimal.Zero) diff --git a/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go b/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go index 66467dd62..5efedafbf 100644 --- a/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go +++ b/pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go @@ -37,6 +37,7 @@ func TestReduceAggregator_Aggregate(t *testing.T) { shouldReport bool expectedState any expectedOutcome map[string]any + previousOutcome func(t *testing.T) *types.AggregationOutcome }{ { name: "aggregate on int64 median", @@ -211,6 +212,63 @@ func TestReduceAggregator_Aggregate(t *testing.T) { "Price": big.NewInt(100), }, }, + { + name: "aggregate with previous outcome", + fields: []aggregators.AggregationField{ + { + InputKey: "FeedID", + OutputKey: "FeedID", + Method: "mode", + }, + { + InputKey: "BenchmarkPrice", + OutputKey: "Price", + Method: "median", + DeviationString: "10", + DeviationType: "percent", + }, + { + InputKey: "Timestamp", + OutputKey: "Timestamp", + Method: "median", + DeviationString: "100", + DeviationType: "absolute", + }, + }, + extraConfig: map[string]any{}, + observationsFactory: func() map[commontypes.OracleID][]values.Value { + mockValue, err := values.WrapMap(map[string]any{ + "FeedID": idABytes[:], + "BenchmarkPrice": int64(100), + "Timestamp": 12341414929, + }) + require.NoError(t, err) + return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} + }, + shouldReport: true, + expectedOutcome: map[string]any{ + "Reports": []any{ + map[string]any{ + "FeedID": idABytes[:], + "Timestamp": int64(12341414929), + "Price": int64(100), + }, + }, + }, + expectedState: map[string]any{ + "FeedID": idABytes[:], + "Timestamp": int64(12341414929), + "Price": int64(100), + }, + previousOutcome: func(t *testing.T) *types.AggregationOutcome { + m, err := values.NewMap(map[string]any{}) + require.NoError(t, err) + pm := values.Proto(m) + bm, err := proto.Marshal(pm) + require.NoError(t, err) + return &types.AggregationOutcome{Metadata: bm} + }, + }, { name: "aggregate on bytes mode", fields: []aggregators.AggregationField{ @@ -468,12 +526,19 @@ func TestReduceAggregator_Aggregate(t *testing.T) { require.NoError(t, err) pb := &pb.Map{} - outcome, err := agg.Aggregate(logger.Nop(), nil, tt.observationsFactory(), 1) + + var po *types.AggregationOutcome + if tt.previousOutcome != nil { + po = tt.previousOutcome(t) + } + + outcome, err := agg.Aggregate(logger.Nop(), po, tt.observationsFactory(), 1) require.NoError(t, err) require.Equal(t, tt.shouldReport, outcome.ShouldReport) // validate metadata - proto.Unmarshal(outcome.Metadata, pb) + err = proto.Unmarshal(outcome.Metadata, pb) + require.NoError(t, err) vmap, err := values.FromMapValueProto(pb) require.NoError(t, err) state, err := vmap.Unwrap() @@ -517,22 +582,6 @@ func TestReduceAggregator_Aggregate(t *testing.T) { return map[commontypes.OracleID][]values.Value{} }, }, - { - name: "empty previous outcome", - previousOutcome: &types.AggregationOutcome{}, - fields: []aggregators.AggregationField{ - { - Method: "median", - OutputKey: "Price", - }, - }, - extraConfig: map[string]any{}, - observationsFactory: func() map[commontypes.OracleID][]values.Value { - mockValue, err := values.Wrap(int64(100)) - require.NoError(t, err) - return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}} - }, - }, { name: "invalid previous outcome not pb", previousOutcome: &types.AggregationOutcome{ diff --git a/pkg/capabilities/consensus/ocr3/reporting_plugin.go b/pkg/capabilities/consensus/ocr3/reporting_plugin.go index 923fdb67b..e0eb72f53 100644 --- a/pkg/capabilities/consensus/ocr3/reporting_plugin.go +++ b/pkg/capabilities/consensus/ocr3/reporting_plugin.go @@ -339,7 +339,7 @@ func (r *reportingPlugin) Outcome(ctx context.Context, outctx ocr3types.OutcomeC outcome, err2 := agg.Aggregate(lggr, workflowOutcome, obs, r.config.F) if err2 != nil { lggr.Errorw("error aggregating outcome", "error", err2) - return nil, err + continue } // Only if the previous outcome exists: diff --git a/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go b/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go index d6f8326a8..18cc26e53 100644 --- a/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go +++ b/pkg/capabilities/consensus/ocr3/reporting_plugin_test.go @@ -2,6 +2,7 @@ package ocr3 import ( "context" + "errors" "sort" "testing" "time" @@ -75,7 +76,7 @@ func TestReportingPlugin_Query(t *testing.T) { type mockCapability struct { t *testing.T - aggregator *aggregator + aggregator pbtypes.Aggregator encoder *enc registeredWorkflows map[string]bool expectedEncoderName string @@ -102,6 +103,20 @@ func (a *aggregator) Aggregate(lggr logger.Logger, pout *pbtypes.AggregationOutc return a.outcome, nil } +type erroringAggregator struct { + aggregator + count int +} + +func (a *erroringAggregator) Aggregate(lggr logger.Logger, pout *pbtypes.AggregationOutcome, observations map[commontypes.OracleID][]values.Value, i int) (*pbtypes.AggregationOutcome, error) { + defer func() { a.count += 1 }() + if a.count == 0 { + return nil, errors.New("failed to aggregate") + } + + return a.aggregator.Aggregate(lggr, pout, observations, i) +} + type enc struct { gotInput values.Map } @@ -258,8 +273,9 @@ func TestReportingPlugin_Observation_NoResults(t *testing.T) { func TestReportingPlugin_Outcome(t *testing.T) { lggr := logger.Test(t) s := requests.NewStore() + aggregator := &aggregator{} mcap := &mockCapability{ - aggregator: &aggregator{}, + aggregator: aggregator, encoder: &enc{}, } rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) @@ -310,8 +326,83 @@ func TestReportingPlugin_Outcome(t *testing.T) { cr := opb.CurrentReports[0] assert.EqualExportedValues(t, cr.Id, id) - assert.EqualExportedValues(t, cr.Outcome, mcap.aggregator.outcome) - assert.EqualExportedValues(t, opb.Outcomes[workflowTestID], mcap.aggregator.outcome) + assert.EqualExportedValues(t, cr.Outcome, aggregator.outcome) + assert.EqualExportedValues(t, opb.Outcomes[workflowTestID], aggregator.outcome) +} + +func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows(t *testing.T) { + lggr := logger.Test(t) + s := requests.NewStore() + aggregator := &erroringAggregator{} + mcap := &mockCapability{ + aggregator: aggregator, + encoder: &enc{}, + } + rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) + require.NoError(t, err) + + weid := uuid.New().String() + wowner := uuid.New().String() + id := &pbtypes.Id{ + WorkflowExecutionId: weid, + WorkflowId: workflowTestID, + WorkflowOwner: wowner, + WorkflowName: workflowTestName, + ReportId: reportTestID, + } + weid2 := uuid.New().String() + id2 := &pbtypes.Id{ + WorkflowExecutionId: weid2, + WorkflowId: workflowTestID, + WorkflowOwner: wowner, + WorkflowName: workflowTestName, + ReportId: reportTestID, + } + q := &pbtypes.Query{ + Ids: []*pbtypes.Id{id, id2}, + } + qb, err := proto.Marshal(q) + require.NoError(t, err) + o, err := values.NewList([]any{"hello"}) + require.NoError(t, err) + + o2, err := values.NewList([]any{"world"}) + require.NoError(t, err) + obs := &pbtypes.Observations{ + Observations: []*pbtypes.Observation{ + { + Id: id, + Observations: values.Proto(o).GetListValue(), + }, + { + Id: id2, + Observations: values.Proto(o2).GetListValue(), + }, + }, + } + + rawObs, err := proto.Marshal(obs) + require.NoError(t, err) + aos := []types.AttributedObservation{ + { + Observation: rawObs, + Observer: commontypes.OracleID(1), + }, + } + + outcome, err := rp.Outcome(tests.Context(t), ocr3types.OutcomeContext{}, qb, aos) + require.NoError(t, err) + + opb := &pbtypes.Outcome{} + err = proto.Unmarshal(outcome, opb) + require.NoError(t, err) + + assert.Len(t, opb.CurrentReports, 1) + + cr := opb.CurrentReports[0] + assert.EqualExportedValues(t, cr.Id, id2) + assert.EqualExportedValues(t, cr.Outcome, aggregator.outcome) + assert.EqualExportedValues(t, opb.Outcomes[workflowTestID], aggregator.outcome) } func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) { @@ -372,6 +463,73 @@ func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) { require.NoError(t, err) } +func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs(t *testing.T) { + ctx := tests.Context(t) + lggr := logger.Test(t) + s := requests.NewStore() + mcap := &mockCapability{ + aggregator: &aggregator{}, + encoder: &enc{}, + } + rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr) + require.NoError(t, err) + + weid := uuid.New().String() + wowner := uuid.New().String() + id1 := &pbtypes.Id{ + WorkflowExecutionId: weid, + WorkflowId: workflowTestID, + WorkflowOwner: wowner, + WorkflowName: workflowTestName, + ReportId: reportTestID, + } + + weid2 := uuid.New().String() + id2 := &pbtypes.Id{ + WorkflowExecutionId: weid2, + WorkflowId: workflowTestID, + WorkflowOwner: wowner, + WorkflowName: workflowTestName, + ReportId: reportTestID, + } + q := &pbtypes.Query{ + Ids: []*pbtypes.Id{ + id1, + id2, + }, + } + qb, err := proto.Marshal(q) + require.NoError(t, err) + aos := []types.AttributedObservation{ + { + Observer: commontypes.OracleID(1), + }, + {}, + } + + _, err = rp.Outcome(ctx, ocr3types.OutcomeContext{}, qb, aos) + require.NoError(t, err) + + obs := &pbtypes.Observations{ + Observations: []*pbtypes.Observation{ + nil, + {}, + }, + RegisteredWorkflowIds: nil, + } + obsb, err := proto.Marshal(obs) + require.NoError(t, err) + + aos = []types.AttributedObservation{ + { + Observation: obsb, + Observer: commontypes.OracleID(1), + }, + } + _, err = rp.Outcome(ctx, ocr3types.OutcomeContext{}, qb, aos) + require.NoError(t, err) +} + func TestReportingPlugin_Reports_ShouldReportFalse(t *testing.T) { lggr := logger.Test(t) s := requests.NewStore()