From 11b8f1108c8ba36011dc2a7db17eee417b6054b3 Mon Sep 17 00:00:00 2001 From: testinginprod <98415576+testinginprod@users.noreply.github.com> Date: Thu, 20 Jun 2024 15:45:26 +0200 Subject: [PATCH] feat(stf): change router service to extract the router at runtime rather than build time (#20724) Co-authored-by: unknown unknown --- runtime/v2/builder.go | 22 +++--- runtime/v2/module.go | 4 +- server/v2/stf/core_branch_service_test.go | 13 ++-- server/v2/stf/core_router_service.go | 92 ++++++----------------- server/v2/stf/stf.go | 57 +++++++++----- server/v2/stf/stf_router.go | 55 ++++++++++---- server/v2/stf/stf_test.go | 57 +++++++++++--- store/rootmulti/store_test.go | 1 - 8 files changed, 164 insertions(+), 137 deletions(-) diff --git a/runtime/v2/builder.go b/runtime/v2/builder.go index 0898f29bed40..e3728a2ec0cb 100644 --- a/runtime/v2/builder.go +++ b/runtime/v2/builder.go @@ -96,21 +96,12 @@ func (a *AppBuilder) Build(opts ...AppBuilderOption) (*App, error) { return nil, err } - stfMsgHandler, err := a.app.msgRouterBuilder.Build() - if err != nil { - return nil, fmt.Errorf("failed to build STF message handler: %w", err) - } - - stfQueryHandler, err := a.app.queryRouterBuilder.Build() - if err != nil { - return nil, fmt.Errorf("failed to build query handler: %w", err) - } - endBlocker, valUpdate := a.app.moduleManager.EndBlock() - a.app.stf = stf.NewSTF[transaction.Tx]( - stfMsgHandler, - stfQueryHandler, + stf, err := stf.NewSTF[transaction.Tx]( + a.app.logger.With("module", "stf"), + a.app.msgRouterBuilder, + a.app.queryRouterBuilder, a.app.moduleManager.PreBlocker(), a.app.moduleManager.BeginBlock(), endBlocker, @@ -119,6 +110,11 @@ func (a *AppBuilder) Build(opts ...AppBuilderOption) (*App, error) { a.postTxExec, a.branch, ) + if err != nil { + return nil, fmt.Errorf("failed to create STF: %w", err) + } + + a.app.stf = stf rs, err := rootstore.CreateRootStore(a.storeOptions) if err != nil { diff --git a/runtime/v2/module.go b/runtime/v2/module.go index 3c18d07f6979..85f739954d8c 100644 --- a/runtime/v2/module.go +++ b/runtime/v2/module.go @@ -212,8 +212,8 @@ func ProvideEnvironment(logger log.Logger, config *runtimev2.Module, key depinje EventService: stf.NewEventService(), GasService: stf.NewGasMeterService(), HeaderService: stf.HeaderService{}, - QueryRouterService: stf.NewQueryRouterService(appBuilder.app.queryRouterBuilder), - MsgRouterService: stf.NewMsgRouterService(appBuilder.app.msgRouterBuilder), + QueryRouterService: stf.NewQueryRouterService(), + MsgRouterService: stf.NewMsgRouterService([]byte(key.Name())), TransactionService: services.NewContextAwareTransactionService(), KVStoreService: kvService, MemStoreService: memKvService, diff --git a/server/v2/stf/core_branch_service_test.go b/server/v2/stf/core_branch_service_test.go index 722f2ac7f314..ce63c984cc70 100644 --- a/server/v2/stf/core_branch_service_test.go +++ b/server/v2/stf/core_branch_service_test.go @@ -6,9 +6,9 @@ import ( "testing" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" appmodulev2 "cosmossdk.io/core/appmodule/v2" - "cosmossdk.io/core/transaction" "cosmossdk.io/server/v2/stf/branch" "cosmossdk.io/server/v2/stf/gas" "cosmossdk.io/server/v2/stf/mock" @@ -16,12 +16,7 @@ import ( func TestBranchService(t *testing.T) { s := &STF[mock.Tx]{ - handleMsg: func(ctx context.Context, msg transaction.Msg) (msgResp transaction.Msg, err error) { - kvSet(t, ctx, "exec") - return nil, nil - }, - handleQuery: nil, - doPreBlock: func(ctx context.Context, txs []mock.Tx) error { return nil }, + doPreBlock: func(ctx context.Context, txs []mock.Tx) error { return nil }, doBeginBlock: func(ctx context.Context) error { kvSet(t, ctx, "begin-block") return nil @@ -43,6 +38,10 @@ func TestBranchService(t *testing.T) { makeGasMeter: gas.DefaultGasMeter, makeGasMeteredState: gas.DefaultWrapWithGasMeter, } + addMsgHandlerToSTF(t, s, func(ctx context.Context, msg *wrapperspb.BoolValue) (*wrapperspb.BoolValue, error) { + kvSet(t, ctx, "exec") + return nil, nil + }) makeContext := func() *executionContext { state := mock.DB() diff --git a/server/v2/stf/core_router_service.go b/server/v2/stf/core_router_service.go index 15da47e87cdc..dd41469f01c2 100644 --- a/server/v2/stf/core_router_service.go +++ b/server/v2/stf/core_router_service.go @@ -2,120 +2,72 @@ package stf import ( "context" - "errors" - "fmt" - "reflect" - "strings" "google.golang.org/protobuf/runtime/protoiface" - appmodulev2 "cosmossdk.io/core/appmodule/v2" "cosmossdk.io/core/router" + "cosmossdk.io/core/transaction" ) // NewMsgRouterService implements router.Service. -func NewMsgRouterService(msgRouterBuilder *MsgRouterBuilder) router.Service { - msgRouter, err := msgRouterBuilder.Build() - if err != nil { - panic(fmt.Errorf("cannot create msgRouter: %w", err)) - } - - return &msgRouterService{ - builder: msgRouterBuilder, - handler: msgRouter, - } +func NewMsgRouterService(identity transaction.Identity) router.Service { + return msgRouterService{identity: identity} } var _ router.Service = (*msgRouterService)(nil) type msgRouterService struct { - builder *MsgRouterBuilder - handler appmodulev2.Handler + // TODO(tip): the identity sits here for the purpose of disallowing modules to impersonate others (sudo). + // right now this is not used, but it serves the reminder of something that we should be eventually + // looking into. + identity []byte } // CanInvoke returns an error if the given message cannot be invoked. -func (m *msgRouterService) CanInvoke(ctx context.Context, typeURL string) error { - if typeURL == "" { - return errors.New("missing type url") - } - - typeURL = strings.TrimPrefix(typeURL, "/") - if exists := m.builder.HandlerExists(typeURL); exists { - return fmt.Errorf("unknown request: %s", typeURL) - } - - return nil +func (m msgRouterService) CanInvoke(ctx context.Context, typeURL string) error { + return ctx.(*executionContext).msgRouter.CanInvoke(ctx, typeURL) } // InvokeTyped execute a message and fill-in a response. // The response must be known and passed as a parameter. // Use InvokeUntyped if the response type is not known. -func (m *msgRouterService) InvokeTyped(ctx context.Context, msg, resp protoiface.MessageV1) error { - // see https://github.com/cosmos/cosmos-sdk/pull/20349 - panic("not implemented") +func (m msgRouterService) InvokeTyped(ctx context.Context, msg, resp protoiface.MessageV1) error { + return ctx.(*executionContext).msgRouter.InvokeTyped(ctx, msg, resp) } // InvokeUntyped execute a message and returns a response. -func (m *msgRouterService) InvokeUntyped(ctx context.Context, msg protoiface.MessageV1) (protoiface.MessageV1, error) { - return m.handler(ctx, msg) +func (m msgRouterService) InvokeUntyped(ctx context.Context, msg protoiface.MessageV1) (protoiface.MessageV1, error) { + return ctx.(*executionContext).msgRouter.InvokeUntyped(ctx, msg) } // NewQueryRouterService implements router.Service. -func NewQueryRouterService(queryRouterBuilder *MsgRouterBuilder) router.Service { - return &queryRouterService{ - builder: queryRouterBuilder, - } +func NewQueryRouterService() router.Service { + return queryRouterService{} } var _ router.Service = (*queryRouterService)(nil) -type queryRouterService struct { - builder *MsgRouterBuilder - handler appmodulev2.Handler -} +type queryRouterService struct{} // CanInvoke returns an error if the given request cannot be invoked. -func (m *queryRouterService) CanInvoke(ctx context.Context, typeURL string) error { - if typeURL == "" { - return errors.New("missing type url") - } - - typeURL = strings.TrimPrefix(typeURL, "/") - if exists := m.builder.HandlerExists(typeURL); exists { - return fmt.Errorf("unknown request: %s", typeURL) - } - - return nil +func (m queryRouterService) CanInvoke(ctx context.Context, typeURL string) error { + return ctx.(*executionContext).queryRouter.CanInvoke(ctx, typeURL) } // InvokeTyped execute a message and fill-in a response. // The response must be known and passed as a parameter. // Use InvokeUntyped if the response type is not known. -func (m *queryRouterService) InvokeTyped( +func (m queryRouterService) InvokeTyped( ctx context.Context, req, resp protoiface.MessageV1, ) error { - // TODO lazy initialization is ugly and not thread safe. we don't want to check a mutex on every InvokeTyped either. - if m.handler == nil { - var err error - m.handler, err = m.builder.Build() - if err != nil { - return fmt.Errorf("cannot create queryRouter: %w", err) - } - } - // reflection is required, see https://github.com/cosmos/cosmos-sdk/pull/20349 - res, err := m.handler(ctx, req) - if err != nil { - return err - } - reflect.Indirect(reflect.ValueOf(resp)).Set(reflect.Indirect(reflect.ValueOf(res))) - return nil + return ctx.(*executionContext).queryRouter.InvokeTyped(ctx, req, resp) } // InvokeUntyped execute a message and returns a response. -func (m *queryRouterService) InvokeUntyped( +func (m queryRouterService) InvokeUntyped( ctx context.Context, req protoiface.MessageV1, ) (protoiface.MessageV1, error) { - return m.handler(ctx, req) + return ctx.(*executionContext).queryRouter.InvokeUntyped(ctx, req) } diff --git a/server/v2/stf/stf.go b/server/v2/stf/stf.go index fd6d434c91c6..6707e1d8826d 100644 --- a/server/v2/stf/stf.go +++ b/server/v2/stf/stf.go @@ -12,6 +12,7 @@ import ( "cosmossdk.io/core/gas" "cosmossdk.io/core/header" "cosmossdk.io/core/log" + "cosmossdk.io/core/router" "cosmossdk.io/core/store" "cosmossdk.io/core/transaction" stfgas "cosmossdk.io/server/v2/stf/gas" @@ -23,9 +24,10 @@ var Identity = []byte("stf") // STF is a struct that manages the state transition component of the app. type STF[T transaction.Tx] struct { - logger log.Logger - handleMsg func(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) - handleQuery func(ctx context.Context, req transaction.Msg) (transaction.Msg, error) + logger log.Logger + + msgRouter Router + queryRouter Router doPreBlock func(ctx context.Context, txs []T) error doBeginBlock func(ctx context.Context) error @@ -42,8 +44,9 @@ type STF[T transaction.Tx] struct { // NewSTF returns a new STF instance. func NewSTF[T transaction.Tx]( - handleMsg func(ctx context.Context, msg transaction.Msg) (transaction.Msg, error), - handleQuery func(ctx context.Context, req transaction.Msg) (transaction.Msg, error), + logger log.Logger, + msgRouterBuilder *MsgRouterBuilder, + queryRouterBuilder *MsgRouterBuilder, doPreBlock func(ctx context.Context, txs []T) error, doBeginBlock func(ctx context.Context) error, doEndBlock func(ctx context.Context) error, @@ -51,20 +54,30 @@ func NewSTF[T transaction.Tx]( doValidatorUpdate func(ctx context.Context) ([]appmodulev2.ValidatorUpdate, error), postTxExec func(ctx context.Context, tx T, success bool) error, branch func(store store.ReaderMap) store.WriterMap, -) *STF[T] { +) (*STF[T], error) { + msgRouter, err := msgRouterBuilder.Build() + if err != nil { + return nil, fmt.Errorf("build msg router: %w", err) + } + queryRouter, err := queryRouterBuilder.Build() + if err != nil { + return nil, fmt.Errorf("build query router: %w", err) + } + return &STF[T]{ - handleMsg: handleMsg, - handleQuery: handleQuery, + logger: logger, + msgRouter: msgRouter, + queryRouter: queryRouter, doPreBlock: doPreBlock, doBeginBlock: doBeginBlock, doEndBlock: doEndBlock, - doTxValidation: doTxValidation, doValidatorUpdate: doValidatorUpdate, + doTxValidation: doTxValidation, postTxExec: postTxExec, // TODO branchFn: branch, makeGasMeter: stfgas.DefaultGasMeter, makeGasMeteredState: stfgas.DefaultWrapWithGasMeter, - } + }, nil } // DeliverBlock is our state transition function. @@ -310,7 +323,7 @@ func (s STF[T]) runTxMsgs( execCtx.setGasLimit(gasLimit) for i, msg := range msgs { execCtx.sender = txSenders[i] - resp, err := s.handleMsg(execCtx, msg) + resp, err := s.msgRouter.InvokeUntyped(execCtx, msg) if err != nil { return nil, 0, nil, fmt.Errorf("message execution at index %d failed: %w", i, err) } @@ -346,7 +359,7 @@ func (s STF[T]) runConsensusMessages( ) ([]transaction.Msg, error) { responses := make([]transaction.Msg, len(messages)) for i := range messages { - resp, err := s.handleMsg(ctx, messages[i]) + resp, err := s.msgRouter.InvokeUntyped(ctx, messages[i]) if err != nil { return nil, err } @@ -498,11 +511,7 @@ func (s STF[T]) Query( queryCtx := s.makeContext(ctx, nil, queryState, internal.ExecModeSimulate) queryCtx.setHeaderInfo(hi) queryCtx.setGasLimit(gasLimit) - return s.handleQuery(queryCtx, req) -} - -func (s STF[T]) Message(ctx context.Context, msg transaction.Msg) (response transaction.Msg, err error) { - return s.handleMsg(ctx, msg) + return s.queryRouter.InvokeUntyped(queryCtx, req) } // RunWithCtx is made to support genesis, if genesis was just the execution of messages instead @@ -521,8 +530,9 @@ func (s STF[T]) RunWithCtx( // clone clones STF. func (s STF[T]) clone() STF[T] { return STF[T]{ - handleMsg: s.handleMsg, - handleQuery: s.handleQuery, + logger: s.logger, + msgRouter: s.msgRouter, + queryRouter: s.queryRouter, doPreBlock: s.doPreBlock, doBeginBlock: s.doBeginBlock, doEndBlock: s.doEndBlock, @@ -558,6 +568,9 @@ type executionContext struct { branchFn branchFn makeGasMeter makeGasMeterFn makeGasMeteredStore makeGasMeteredStateFn + + msgRouter router.Service + queryRouter router.Service } // setHeaderInfo sets the header info in the state to be used by queries in the future. @@ -599,6 +612,8 @@ func (s STF[T]) makeContext( sender, store, execMode, + s.msgRouter, + s.queryRouter, ) } @@ -610,6 +625,8 @@ func newExecutionContext( sender transaction.Identity, state store.WriterMap, execMode transaction.ExecMode, + msgRouter Router, + queryRouter Router, ) *executionContext { meter := makeGasMeterFn(gas.NoGasLimit) meteredState := makeGasMeteredStoreFn(meter, state) @@ -626,6 +643,8 @@ func newExecutionContext( branchFn: branchFn, makeGasMeter: makeGasMeterFn, makeGasMeteredStore: makeGasMeteredStoreFn, + msgRouter: msgRouter, + queryRouter: queryRouter, } } diff --git a/server/v2/stf/stf_router.go b/server/v2/stf/stf_router.go index 57e8fbfb9ede..8489e16a261e 100644 --- a/server/v2/stf/stf_router.go +++ b/server/v2/stf/stf_router.go @@ -4,11 +4,13 @@ import ( "context" "errors" "fmt" + "reflect" gogoproto "github.com/cosmos/gogoproto/proto" - "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/runtime/protoiface" appmodulev2 "cosmossdk.io/core/appmodule/v2" + "cosmossdk.io/core/router" ) var ErrNoHandler = errors.New("no handler") @@ -60,7 +62,7 @@ func (b *MsgRouterBuilder) HandlerExists(msgType string) bool { return ok } -func (b *MsgRouterBuilder) Build() (appmodulev2.Handler, error) { +func (b *MsgRouterBuilder) Build() (Router, error) { handlers := make(map[string]appmodulev2.Handler) globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error { @@ -92,14 +94,8 @@ func (b *MsgRouterBuilder) Build() (appmodulev2.Handler, error) { handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler) } - // return handler as function - return func(ctx context.Context, msg appmodulev2.Message) (appmodulev2.Message, error) { - typeName := msgTypeURL(msg) - handler, exists := handlers[typeName] - if !exists { - return nil, fmt.Errorf("%w: %s", ErrNoHandler, typeName) - } - return handler(ctx, msg) + return Router{ + handlers: handlers, }, nil } @@ -141,9 +137,42 @@ func buildHandler( // msgTypeURL returns the TypeURL of a proto message. func msgTypeURL(msg gogoproto.Message) string { - if m, ok := msg.(proto.Message); ok { - return string(m.ProtoReflect().Descriptor().FullName()) + return gogoproto.MessageName(msg) +} + +var _ router.Service = (*Router)(nil) + +// Router implements the STF router for msg and query handlers. +type Router struct { + handlers map[string]appmodulev2.Handler +} + +func (r Router) CanInvoke(_ context.Context, typeURL string) error { + _, exists := r.handlers[typeURL] + if !exists { + return fmt.Errorf("%w: %s", ErrNoHandler, typeURL) } + return nil +} - return gogoproto.MessageName(msg) +func (r Router) InvokeTyped(ctx context.Context, req, resp protoiface.MessageV1) error { + handlerResp, err := r.InvokeUntyped(ctx, req) + if err != nil { + return err + } + merge(resp, handlerResp) + return nil +} + +func merge(src, dst protoiface.MessageV1) { + reflect.Indirect(reflect.ValueOf(dst)).Set(reflect.Indirect(reflect.ValueOf(src))) +} + +func (r Router) InvokeUntyped(ctx context.Context, req protoiface.MessageV1) (res protoiface.MessageV1, err error) { + typeName := msgTypeURL(req) + handler, exists := r.handlers[typeName] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrNoHandler, typeName) + } + return handler(ctx, req) } diff --git a/server/v2/stf/stf_test.go b/server/v2/stf/stf_test.go index 9e030dd52c0a..bdd057313bd1 100644 --- a/server/v2/stf/stf_test.go +++ b/server/v2/stf/stf_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/cosmos/gogoproto/proto" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/wrapperspb" @@ -14,29 +15,56 @@ import ( appmodulev2 "cosmossdk.io/core/appmodule/v2" coregas "cosmossdk.io/core/gas" "cosmossdk.io/core/store" - "cosmossdk.io/core/transaction" "cosmossdk.io/server/v2/stf/branch" "cosmossdk.io/server/v2/stf/gas" "cosmossdk.io/server/v2/stf/mock" ) +func addMsgHandlerToSTF[T any, PT interface { + *T + proto.Message +}, + U any, UT interface { + *U + proto.Message + }]( + t *testing.T, + stf *STF[mock.Tx], + handler func(ctx context.Context, msg PT) (UT, error), +) { + t.Helper() + msgRouterBuilder := NewMsgRouterBuilder() + err := msgRouterBuilder.RegisterHandler( + msgTypeURL(PT(new(T))), + func(ctx context.Context, msg appmodulev2.Message) (msgResp appmodulev2.Message, err error) { + typedReq := msg.(PT) + typedResp, err := handler(ctx, typedReq) + if err != nil { + return nil, err + } + + return typedResp, nil + }, + ) + require.NoError(t, err) + + msgRouter, err := msgRouterBuilder.Build() + require.NoError(t, err) + stf.msgRouter = msgRouter +} + func TestSTF(t *testing.T) { state := mock.DB() mockTx := mock.Tx{ Sender: []byte("sender"), - Msg: wrapperspb.Bool(true), // msg does not matter at all because our handler does nothing. + Msg: wrapperspb.Bool(true), GasLimit: 100_000, } sum := sha256.Sum256([]byte("test-hash")) s := &STF[mock.Tx]{ - handleMsg: func(ctx context.Context, msg transaction.Msg) (msgResp transaction.Msg, err error) { - kvSet(t, ctx, "exec") - return nil, nil - }, - handleQuery: nil, - doPreBlock: func(ctx context.Context, txs []mock.Tx) error { return nil }, + doPreBlock: func(ctx context.Context, txs []mock.Tx) error { return nil }, doBeginBlock: func(ctx context.Context) error { kvSet(t, ctx, "begin-block") return nil @@ -59,6 +87,11 @@ func TestSTF(t *testing.T) { makeGasMeteredState: gas.DefaultWrapWithGasMeter, } + addMsgHandlerToSTF(t, s, func(ctx context.Context, msg *wrapperspb.BoolValue) (*wrapperspb.BoolValue, error) { + kvSet(t, ctx, "exec") + return nil, nil + }) + t.Run("begin and end block", func(t *testing.T) { _, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ Height: uint64(1), @@ -124,9 +157,9 @@ func TestSTF(t *testing.T) { t.Run("fail exec tx", func(t *testing.T) { // update the stf to fail on the handler s := s.clone() - s.handleMsg = func(ctx context.Context, msg transaction.Msg) (msgResp transaction.Msg, err error) { + addMsgHandlerToSTF(t, &s, func(ctx context.Context, msg *wrapperspb.BoolValue) (*wrapperspb.BoolValue, error) { return nil, fmt.Errorf("failure") - } + }) blockResult, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ Height: uint64(1), @@ -167,9 +200,9 @@ func TestSTF(t *testing.T) { t.Run("tx failed and post tx failed", func(t *testing.T) { s := s.clone() - s.handleMsg = func(ctx context.Context, msg transaction.Msg) (msgResp transaction.Msg, err error) { + addMsgHandlerToSTF(t, &s, func(ctx context.Context, msg *wrapperspb.BoolValue) (*wrapperspb.BoolValue, error) { return nil, fmt.Errorf("exec failure") - } + }) s.postTxExec = func(ctx context.Context, tx mock.Tx, success bool) error { return fmt.Errorf("post tx failure") } blockResult, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ Height: uint64(1), diff --git a/store/rootmulti/store_test.go b/store/rootmulti/store_test.go index 0e305eeefffe..baa24a0625d9 100644 --- a/store/rootmulti/store_test.go +++ b/store/rootmulti/store_test.go @@ -547,7 +547,6 @@ func TestMultiStore_Pruning(t *testing.T) { _, err := ms.CacheMultiStoreWithVersion(v) require.NoError(t, err, "expected no error when loading height: %d", v) } - }) } }