Skip to content

Commit

Permalink
[CAPPL-270/271] Fix Consensus bugs (#934)
Browse files Browse the repository at this point in the history
- Fix "result is not a pointer" error in the reduce aggregator
- Continue instead of returning an error when we encounter an aggregation error
  • Loading branch information
cedric-cordenier authored Nov 13, 2024
1 parent 8a7a997 commit cb37b93
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 28 deletions.
10 changes: 5 additions & 5 deletions pkg/capabilities/consensus/ocr3/aggregators/reduce_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ func (a *reduceAggregator) initializeCurrentState(lggr logger.Logger, previousOu

if previousOutcome != nil {
pb := &pb.Map{}
proto.Unmarshal(previousOutcome.Metadata, pb)
mv, err := values.FromMapValueProto(pb)
err := proto.Unmarshal(previousOutcome.Metadata, pb)
if err != nil {
return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %s", err.Error())
return nil, fmt.Errorf("initializeCurrentState Unmarshal error: %w", err)
}
err = mv.UnwrapTo(currentState)
mv, err := values.FromMapValueProto(pb)
if err != nil {
return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %s", err.Error())
return nil, fmt.Errorf("initializeCurrentState FromMapValueProto error: %w", err)
}
currentState = mv.Underlying
}

zeroValue := values.NewDecimal(decimal.Zero)
Expand Down
85 changes: 67 additions & 18 deletions pkg/capabilities/consensus/ocr3/aggregators/reduce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func TestReduceAggregator_Aggregate(t *testing.T) {
shouldReport bool
expectedState any
expectedOutcome map[string]any
previousOutcome func(t *testing.T) *types.AggregationOutcome
}{
{
name: "aggregate on int64 median",
Expand Down Expand Up @@ -211,6 +212,63 @@ func TestReduceAggregator_Aggregate(t *testing.T) {
"Price": big.NewInt(100),
},
},
{
name: "aggregate with previous outcome",
fields: []aggregators.AggregationField{
{
InputKey: "FeedID",
OutputKey: "FeedID",
Method: "mode",
},
{
InputKey: "BenchmarkPrice",
OutputKey: "Price",
Method: "median",
DeviationString: "10",
DeviationType: "percent",
},
{
InputKey: "Timestamp",
OutputKey: "Timestamp",
Method: "median",
DeviationString: "100",
DeviationType: "absolute",
},
},
extraConfig: map[string]any{},
observationsFactory: func() map[commontypes.OracleID][]values.Value {
mockValue, err := values.WrapMap(map[string]any{
"FeedID": idABytes[:],
"BenchmarkPrice": int64(100),
"Timestamp": 12341414929,
})
require.NoError(t, err)
return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}}
},
shouldReport: true,
expectedOutcome: map[string]any{
"Reports": []any{
map[string]any{
"FeedID": idABytes[:],
"Timestamp": int64(12341414929),
"Price": int64(100),
},
},
},
expectedState: map[string]any{
"FeedID": idABytes[:],
"Timestamp": int64(12341414929),
"Price": int64(100),
},
previousOutcome: func(t *testing.T) *types.AggregationOutcome {
m, err := values.NewMap(map[string]any{})
require.NoError(t, err)
pm := values.Proto(m)
bm, err := proto.Marshal(pm)
require.NoError(t, err)
return &types.AggregationOutcome{Metadata: bm}
},
},
{
name: "aggregate on bytes mode",
fields: []aggregators.AggregationField{
Expand Down Expand Up @@ -468,12 +526,19 @@ func TestReduceAggregator_Aggregate(t *testing.T) {
require.NoError(t, err)

pb := &pb.Map{}
outcome, err := agg.Aggregate(logger.Nop(), nil, tt.observationsFactory(), 1)

var po *types.AggregationOutcome
if tt.previousOutcome != nil {
po = tt.previousOutcome(t)
}

outcome, err := agg.Aggregate(logger.Nop(), po, tt.observationsFactory(), 1)
require.NoError(t, err)
require.Equal(t, tt.shouldReport, outcome.ShouldReport)

// validate metadata
proto.Unmarshal(outcome.Metadata, pb)
err = proto.Unmarshal(outcome.Metadata, pb)
require.NoError(t, err)
vmap, err := values.FromMapValueProto(pb)
require.NoError(t, err)
state, err := vmap.Unwrap()
Expand Down Expand Up @@ -517,22 +582,6 @@ func TestReduceAggregator_Aggregate(t *testing.T) {
return map[commontypes.OracleID][]values.Value{}
},
},
{
name: "empty previous outcome",
previousOutcome: &types.AggregationOutcome{},
fields: []aggregators.AggregationField{
{
Method: "median",
OutputKey: "Price",
},
},
extraConfig: map[string]any{},
observationsFactory: func() map[commontypes.OracleID][]values.Value {
mockValue, err := values.Wrap(int64(100))
require.NoError(t, err)
return map[commontypes.OracleID][]values.Value{1: {mockValue}, 2: {mockValue}, 3: {mockValue}}
},
},
{
name: "invalid previous outcome not pb",
previousOutcome: &types.AggregationOutcome{
Expand Down
2 changes: 1 addition & 1 deletion pkg/capabilities/consensus/ocr3/reporting_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (r *reportingPlugin) Outcome(ctx context.Context, outctx ocr3types.OutcomeC
outcome, err2 := agg.Aggregate(lggr, workflowOutcome, obs, r.config.F)
if err2 != nil {
lggr.Errorw("error aggregating outcome", "error", err2)
return nil, err
continue
}

// Only if the previous outcome exists:
Expand Down
166 changes: 162 additions & 4 deletions pkg/capabilities/consensus/ocr3/reporting_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ocr3

import (
"context"
"errors"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -75,7 +76,7 @@ func TestReportingPlugin_Query(t *testing.T) {

type mockCapability struct {
t *testing.T
aggregator *aggregator
aggregator pbtypes.Aggregator
encoder *enc
registeredWorkflows map[string]bool
expectedEncoderName string
Expand All @@ -102,6 +103,20 @@ func (a *aggregator) Aggregate(lggr logger.Logger, pout *pbtypes.AggregationOutc
return a.outcome, nil
}

// erroringAggregator is a test double that wraps the plain test aggregator and
// keeps a running count of how many times its Aggregate method has been called.
type erroringAggregator struct {
// Embedded so that fields and behaviour of the plain test aggregator are reused.
aggregator
// count is the number of Aggregate calls made on this value so far.
count int
}

// Aggregate fails on the very first invocation and, on every subsequent
// invocation, hands the call over to the embedded aggregator unchanged.
// The invocation counter is bumped on every call, whether it fails or not.
func (a *erroringAggregator) Aggregate(lggr logger.Logger, pout *pbtypes.AggregationOutcome, observations map[commontypes.OracleID][]values.Value, i int) (*pbtypes.AggregationOutcome, error) {
	// Capture the number of calls seen before this one, then record this call.
	previousCalls := a.count
	a.count++

	if previousCalls != 0 {
		return a.aggregator.Aggregate(lggr, pout, observations, i)
	}

	return nil, errors.New("failed to aggregate")
}

type enc struct {
gotInput values.Map
}
Expand Down Expand Up @@ -258,8 +273,9 @@ func TestReportingPlugin_Observation_NoResults(t *testing.T) {
func TestReportingPlugin_Outcome(t *testing.T) {
lggr := logger.Test(t)
s := requests.NewStore()
aggregator := &aggregator{}
mcap := &mockCapability{
aggregator: &aggregator{},
aggregator: aggregator,
encoder: &enc{},
}
rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr)
Expand Down Expand Up @@ -310,8 +326,83 @@ func TestReportingPlugin_Outcome(t *testing.T) {

cr := opb.CurrentReports[0]
assert.EqualExportedValues(t, cr.Id, id)
assert.EqualExportedValues(t, cr.Outcome, mcap.aggregator.outcome)
assert.EqualExportedValues(t, opb.Outcomes[workflowTestID], mcap.aggregator.outcome)
assert.EqualExportedValues(t, cr.Outcome, aggregator.outcome)
assert.EqualExportedValues(t, opb.Outcomes[workflowTestID], aggregator.outcome)
}

// TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows checks
// that Outcome returns no error, and still emits a report, when the aggregator
// fails for one of the two requests in the query.
//
// The erroringAggregator used here returns an error on its first Aggregate call
// and delegates to the embedded aggregator on later calls.
func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherWorkflows(t *testing.T) {
lggr := logger.Test(t)
s := requests.NewStore()
aggregator := &erroringAggregator{}
mcap := &mockCapability{
aggregator: aggregator,
encoder: &enc{},
}
rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr)
require.NoError(t, err)

// Two request IDs that share the same workflow ID, owner, name and report ID
// and differ only in their workflow execution ID.
weid := uuid.New().String()
wowner := uuid.New().String()
id := &pbtypes.Id{
WorkflowExecutionId: weid,
WorkflowId: workflowTestID,
WorkflowOwner: wowner,
WorkflowName: workflowTestName,
ReportId: reportTestID,
}
weid2 := uuid.New().String()
id2 := &pbtypes.Id{
WorkflowExecutionId: weid2,
WorkflowId: workflowTestID,
WorkflowOwner: wowner,
WorkflowName: workflowTestName,
ReportId: reportTestID,
}
q := &pbtypes.Query{
Ids: []*pbtypes.Id{id, id2},
}
qb, err := proto.Marshal(q)
require.NoError(t, err)
o, err := values.NewList([]any{"hello"})
require.NoError(t, err)

o2, err := values.NewList([]any{"world"})
require.NoError(t, err)
// One observation per request ID, both reported by the same oracle below.
obs := &pbtypes.Observations{
Observations: []*pbtypes.Observation{
{
Id: id,
Observations: values.Proto(o).GetListValue(),
},
{
Id: id2,
Observations: values.Proto(o2).GetListValue(),
},
},
}

rawObs, err := proto.Marshal(obs)
require.NoError(t, err)
aos := []types.AttributedObservation{
{
Observation: rawObs,
Observer: commontypes.OracleID(1),
},
}

// The aggregation failure must not surface as an error from Outcome.
outcome, err := rp.Outcome(tests.Context(t), ocr3types.OutcomeContext{}, qb, aos)
require.NoError(t, err)

opb := &pbtypes.Outcome{}
err = proto.Unmarshal(outcome, opb)
require.NoError(t, err)

// Only one of the two requests is expected to yield a report.
assert.Len(t, opb.CurrentReports, 1)

// NOTE(review): expecting id2 here presumes Outcome aggregates id before id2,
// so that the single injected failure lands on id — confirm against Outcome.
cr := opb.CurrentReports[0]
assert.EqualExportedValues(t, cr.Id, id2)
assert.EqualExportedValues(t, cr.Outcome, aggregator.outcome)
assert.EqualExportedValues(t, opb.Outcomes[workflowTestID], aggregator.outcome)
}

func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) {
Expand Down Expand Up @@ -372,6 +463,73 @@ func TestReportingPlugin_Outcome_NilDerefs(t *testing.T) {
require.NoError(t, err)
}

// TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs calls Outcome
// twice with a two-ID query and degenerate attributed observations, and requires
// that neither call returns an error.
//
// NOTE(review): despite the name, this test wires in a plain &aggregator{} rather
// than an erroringAggregator, so no aggregation error appears to be injected
// here — confirm whether erroringAggregator was intended.
func TestReportingPlugin_Outcome_AggregatorErrorDoesntInterruptOtherIDs(t *testing.T) {
ctx := tests.Context(t)
lggr := logger.Test(t)
s := requests.NewStore()
mcap := &mockCapability{
aggregator: &aggregator{},
encoder: &enc{},
}
rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr)
require.NoError(t, err)

// Two request IDs that share workflow ID, owner, name and report ID and differ
// only in their workflow execution ID.
weid := uuid.New().String()
wowner := uuid.New().String()
id1 := &pbtypes.Id{
WorkflowExecutionId: weid,
WorkflowId: workflowTestID,
WorkflowOwner: wowner,
WorkflowName: workflowTestName,
ReportId: reportTestID,
}

weid2 := uuid.New().String()
id2 := &pbtypes.Id{
WorkflowExecutionId: weid2,
WorkflowId: workflowTestID,
WorkflowOwner: wowner,
WorkflowName: workflowTestName,
ReportId: reportTestID,
}
q := &pbtypes.Query{
Ids: []*pbtypes.Id{
id1,
id2,
},
}
qb, err := proto.Marshal(q)
require.NoError(t, err)
// First call: one attributed observation with no observation bytes set, and one
// that is entirely zero-valued.
aos := []types.AttributedObservation{
{
Observer: commontypes.OracleID(1),
},
{},
}

_, err = rp.Outcome(ctx, ocr3types.OutcomeContext{}, qb, aos)
require.NoError(t, err)

// Second call: a marshalled Observations message whose entries are a nil
// observation and an empty observation, with no registered workflow IDs.
obs := &pbtypes.Observations{
Observations: []*pbtypes.Observation{
nil,
{},
},
RegisteredWorkflowIds: nil,
}
obsb, err := proto.Marshal(obs)
require.NoError(t, err)

aos = []types.AttributedObservation{
{
Observation: obsb,
Observer: commontypes.OracleID(1),
},
}
_, err = rp.Outcome(ctx, ocr3types.OutcomeContext{}, qb, aos)
require.NoError(t, err)
}

func TestReportingPlugin_Reports_ShouldReportFalse(t *testing.T) {
lggr := logger.Test(t)
s := requests.NewStore()
Expand Down

0 comments on commit cb37b93

Please sign in to comment.