Skip to content

Commit

Permalink
[CAPPL-60] Dynamic encoder selection in OCR consensus aggregator (#780)
Browse files Browse the repository at this point in the history
Co-authored-by: Cedric <cedric.cordenier@smartcontract.com>
  • Loading branch information
bolekk and cedric-cordenier authored Sep 26, 2024
1 parent 7a9a88a commit 84ed150
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 106 deletions.
8 changes: 6 additions & 2 deletions pkg/capabilities/consensus/ocr3/capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (o *capability) RegisterToWorkflow(ctx context.Context, request capabilitie
}
o.aggregators[request.Metadata.WorkflowID] = agg

encoder, err := o.encoderFactory(c.EncoderConfig)
encoder, err := o.encoderFactory(c.Encoder, c.EncoderConfig, o.lggr)
if err != nil {
return err
}
Expand All @@ -143,7 +143,7 @@ func (o *capability) getAggregator(workflowID string) (types.Aggregator, error)
return agg, nil
}

func (o *capability) getEncoder(workflowID string) (types.Encoder, error) {
func (o *capability) getEncoderByWorkflowID(workflowID string) (types.Encoder, error) {
enc, ok := o.encoders[workflowID]
if !ok {
return nil, fmt.Errorf("no aggregator found for workflowID %s", workflowID)
Expand All @@ -152,6 +152,10 @@ func (o *capability) getEncoder(workflowID string) (types.Encoder, error) {
return enc, nil
}

func (o *capability) getEncoderByName(encoderName string, config *values.Map) (types.Encoder, error) {
return o.encoderFactory(encoderName, config, o.lggr)
}

func (o *capability) getRegisteredWorkflowsIDs() []string {
o.mu.RLock()
defer o.mu.RUnlock()
Expand Down
2 changes: 1 addition & 1 deletion pkg/capabilities/consensus/ocr3/capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type encoder struct {
types.Encoder
}

func mockEncoderFactory(_ *values.Map) (types.Encoder, error) {
func mockEncoderFactory(_ string, _ *values.Map, _ logger.Logger) (types.Encoder, error) {
return &encoder{}, nil
}

Expand Down
33 changes: 25 additions & 8 deletions pkg/capabilities/consensus/ocr3/reporting_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ var _ ocr3types.ReportingPlugin[[]byte] = (*reportingPlugin)(nil)

type capabilityIface interface {
getAggregator(workflowID string) (pbtypes.Aggregator, error)
getEncoder(workflowID string) (pbtypes.Encoder, error)
getEncoderByWorkflowID(workflowID string) (pbtypes.Encoder, error)
getEncoderByName(encoderName string, config *values.Map) (pbtypes.Encoder, error)
getRegisteredWorkflowsIDs() []string
unregisterWorkflowID(workflowID string)
}
Expand Down Expand Up @@ -390,7 +391,7 @@ func (r *reportingPlugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]oc
ShouldReport: outcome.ShouldReport,
}

var report []byte
var rawReport []byte
if info.ShouldReport {
meta := &pbtypes.Metadata{
Version: 1,
Expand All @@ -409,10 +410,26 @@ func (r *reportingPlugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]oc
continue
}

enc, err := r.r.getEncoder(id.WorkflowId)
if err != nil {
lggr.Errorw("could not retrieve encoder for workflow", "error", err)
continue
var encoder pbtypes.Encoder
if newOutcome.EncoderName != "" {
lggr.Debugw("using encoder from outcome", "encoderName", newOutcome.EncoderName, "executionID", report.Id.WorkflowExecutionId)
encoderConfig, err2 := values.FromMapValueProto(newOutcome.EncoderConfig)
if err2 != nil {
lggr.Errorw("could not convert desired encoder config to values.Map", "error", err2, "executionID", report.Id.WorkflowExecutionId)
} else {
encoder, err2 = r.r.getEncoderByName(newOutcome.EncoderName, encoderConfig)
if err2 != nil {
lggr.Errorw("could not retrieve desired encoder, will use per-workflow default", "error", err2, "executionID", report.Id.WorkflowExecutionId)
}
}
}

if encoder == nil {
encoder, err = r.r.getEncoderByWorkflowID(id.WorkflowId)
if err != nil {
lggr.Errorw("could not retrieve encoder for workflow", "error", err)
continue
}
}

mv, err := values.FromMapValueProto(newOutcome.EncodableOutcome)
Expand All @@ -421,7 +438,7 @@ func (r *reportingPlugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]oc
continue
}

report, err = enc.Encode(context.Background(), *mv)
rawReport, err = encoder.Encode(context.Background(), *mv)
if err != nil {
r.lggr.Errorw("could not encode report for workflow", "error", err)
continue
Expand All @@ -436,7 +453,7 @@ func (r *reportingPlugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]oc

// Append every report, even if shouldReport = false, to let the transmitter mark the step as complete.
reports = append(reports, ocr3types.ReportWithInfo[[]byte]{
Report: report,
Report: rawReport,
Info: infob,
})
}
Expand Down
17 changes: 14 additions & 3 deletions pkg/capabilities/consensus/ocr3/reporting_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ func TestReportingPlugin_Query(t *testing.T) {
}

type mockCapability struct {
t *testing.T
aggregator *aggregator
encoder *enc
registeredWorkflows map[string]bool
expectedEncoderName string
}

type aggregator struct {
Expand Down Expand Up @@ -114,7 +116,12 @@ func (mc *mockCapability) getAggregator(workflowID string) (pbtypes.Aggregator,
return mc.aggregator, nil
}

func (mc *mockCapability) getEncoder(workflowID string) (pbtypes.Encoder, error) {
func (mc *mockCapability) getEncoderByWorkflowID(workflowID string) (pbtypes.Encoder, error) {
return mc.encoder, nil
}

func (mc *mockCapability) getEncoderByName(encoderName string, config *values.Map) (pbtypes.Encoder, error) {
require.Equal(mc.t, mc.expectedEncoderName, encoderName)
return mc.encoder, nil
}

Expand Down Expand Up @@ -460,10 +467,13 @@ func TestReportingPlugin_Reports_NilDerefs(t *testing.T) {

func TestReportingPlugin_Reports_ShouldReportTrue(t *testing.T) {
lggr := logger.Test(t)
dynamicEncoderName := "special_encoder"
s := requests.NewStore()
mcap := &mockCapability{
aggregator: &aggregator{},
encoder: &enc{},
t: t,
aggregator: &aggregator{},
encoder: &enc{},
expectedEncoderName: dynamicEncoderName,
}
rp, err := newReportingPlugin(s, mcap, defaultBatchSize, ocr3types.ReportingPluginConfig{}, defaultOutcomePruningThreshold, lggr)
require.NoError(t, err)
Expand Down Expand Up @@ -494,6 +504,7 @@ func TestReportingPlugin_Reports_ShouldReportTrue(t *testing.T) {
Outcome: &pbtypes.AggregationOutcome{
EncodableOutcome: nmp,
ShouldReport: true,
EncoderName: dynamicEncoderName,
},
},
},
Expand Down
8 changes: 6 additions & 2 deletions pkg/capabilities/consensus/ocr3/transmitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ func TestTransmitter(t *testing.T) {
clockwork.NewFakeClock(),
10*time.Second,
mockAggregatorFactory,
func(config *values.Map) (pbtypes.Encoder, error) { return &encoder{}, nil },
func(_ string, _ *values.Map, _ logger.Logger) (pbtypes.Encoder, error) {
return &encoder{}, nil
},
lggr,
10,
)
Expand Down Expand Up @@ -127,7 +129,9 @@ func TestTransmitter_ShouldReportFalse(t *testing.T) {
clockwork.NewFakeClock(),
10*time.Second,
mockAggregatorFactory,
func(config *values.Map) (pbtypes.Encoder, error) { return &encoder{}, nil },
func(_ string, _ *values.Map, _ logger.Logger) (pbtypes.Encoder, error) {
return &encoder{}, nil
},
lggr,
10,
)
Expand Down
3 changes: 2 additions & 1 deletion pkg/capabilities/consensus/ocr3/types/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package types
import (
"context"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/values"
)

type Encoder interface {
Encode(ctx context.Context, input values.Map) ([]byte, error)
}

type EncoderFactory func(config *values.Map) (Encoder, error)
type EncoderFactory func(name string, config *values.Map, lggr logger.Logger) (Encoder, error)

type SignedReport struct {
Report []byte
Expand Down
Loading

0 comments on commit 84ed150

Please sign in to comment.