Skip to content

Commit

Permalink
Refactor Wasm Plugin Framework (#586)
Browse files Browse the repository at this point in the history
* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: update wasm plugin

* feat: remove skip_filter

* feat: fix testcase

* feat: use gauge instead of counter for wasm memory size
  • Loading branch information
earayu authored Nov 26, 2024
1 parent c1c5091 commit 455763a
Show file tree
Hide file tree
Showing 17 changed files with 124 additions and 717 deletions.
4 changes: 0 additions & 4 deletions go/vt/vttablet/tabletserver/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ type ActionInterface interface {
SetParams(args ActionArgs) error

GetRule() *rules.Rule

GetSkipFlag() bool

SetSkipFlag(skip bool)
}

type ActionArgs interface {
Expand Down
146 changes: 3 additions & 143 deletions go/vt/vttablet/tabletserver/action_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package tabletserver
import (
"context"
"fmt"
"regexp"
"time"

"github.com/BurntSushi/toml"
Expand All @@ -20,8 +19,6 @@ type ContinueAction struct {

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *ContinueAction) BeforeExecution(_ *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -55,16 +52,6 @@ type FailAction struct {

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *ContinueAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *ContinueAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

func (p *FailAction) BeforeExecution(_ *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -93,21 +80,11 @@ func (p *FailAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *FailAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *FailAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type FailRetryAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *FailRetryAction) BeforeExecution(_ *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -136,21 +113,11 @@ func (p *FailRetryAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *FailRetryAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *FailRetryAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type BufferAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

skipFlag bool
}

func (p *BufferAction) BeforeExecution(qre *QueryExecutor) *ActionExecutionResponse {
Expand Down Expand Up @@ -198,23 +165,13 @@ func (p *BufferAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *BufferAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *BufferAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type ConcurrencyControlAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

Args *ConcurrencyControlActionArgs

skipFlag bool
}

type ConcurrencyControlActionArgs struct {
Expand Down Expand Up @@ -292,23 +249,13 @@ func (p *ConcurrencyControlAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *ConcurrencyControlAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *ConcurrencyControlAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type WasmPluginAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

Args *WasmPluginActionArgs

skipFlag bool
}

type WasmPluginActionArgs struct {
Expand Down Expand Up @@ -337,20 +284,21 @@ func (args *WasmPluginActionArgs) Parse(stringParams string) (ActionArgs, error)
func (p *WasmPluginAction) BeforeExecution(qre *QueryExecutor) *ActionExecutionResponse {
controller := qre.tsv.qe.wasmPluginController

ok, module := controller.VM.GetWasmModule(p.Args.WasmBinaryName)
ok, module := controller.VM.GetWasmModule(p.GetRule().Name)
if !ok {
wasmBytes, err := controller.GetWasmBytesByBinaryName(qre.ctx, p.Args.WasmBinaryName)
if err != nil {
return &ActionExecutionResponse{Err: err}
}
module, err = controller.VM.InitWasmModule(p.Args.WasmBinaryName, wasmBytes)
module, err = controller.VM.InitWasmModule(p.GetRule().Name, wasmBytes)
if err != nil {
return &ActionExecutionResponse{Err: err}
}
}

instance, err := module.NewInstance(qre)
if err != nil {
//todo wasm: if instance is nil, we will not be able to get the it in AfterExecution. We need to handle this case
return &ActionExecutionResponse{Err: err}
}

Expand Down Expand Up @@ -399,91 +347,3 @@ func (p *WasmPluginAction) SetParams(args ActionArgs) error {
func (p *WasmPluginAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *WasmPluginAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *WasmPluginAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}

type SkipFilterAction struct {
Rule *rules.Rule

// Action is the action to take if the rule matches
Action rules.Action

Args *SkipFilterActionArgs

skipFlag bool
}

type SkipFilterActionArgs struct {
AllowRegexString string `toml:"skip_filter_regex"`
AllowRegex *regexp.Regexp
}

func (args *SkipFilterActionArgs) Parse(stringParams string) (ActionArgs, error) {
s := &SkipFilterActionArgs{}
if stringParams == "" {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "stringParams is empty when parsing skip filter action args")
}

userInputTOML := ConvertUserInputToTOML(stringParams)
err := toml.Unmarshal([]byte(userInputTOML), s)
if err != nil {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error when parsing skip filter action args: %v", err)
}
s.AllowRegex, err = regexp.Compile(fmt.Sprintf("^%s$", s.AllowRegexString))
if err != nil {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error when compiling skip filter action args: %v", err)
}

return s, nil
}

func (p *SkipFilterAction) BeforeExecution(qre *QueryExecutor) *ActionExecutionResponse {
findSelf := false
for _, a := range qre.matchedActionList {
if a.GetRule().Name == p.GetRule().Name {
findSelf = true
continue
}
if findSelf {
if p.Args.AllowRegex.MatchString(a.GetRule().Name) {
a.SetSkipFlag(true)
}
}
}
return &ActionExecutionResponse{Err: nil}
}

func (p *SkipFilterAction) AfterExecution(qre *QueryExecutor, reply *sqltypes.Result, err error) *ActionExecutionResponse {
return &ActionExecutionResponse{Reply: reply, Err: err}
}

func (p *SkipFilterAction) ParseParams(argsStr string) (ActionArgs, error) {
return p.Args.Parse(argsStr)
}

func (p *SkipFilterAction) SetParams(args ActionArgs) error {
skipFilterArgs, ok := args.(*SkipFilterActionArgs)
if !ok {
return fmt.Errorf("args :%v is not a valid SkipFilterActionArgs)", args)
}
p.Args = skipFilterArgs
return nil
}

func (p *SkipFilterAction) GetRule() *rules.Rule {
return p.Rule
}

func (p *SkipFilterAction) GetSkipFlag() bool {
return p.skipFlag
}

func (p *SkipFilterAction) SetSkipFlag(skip bool) {
p.skipFlag = skip
}
2 changes: 0 additions & 2 deletions go/vt/vttablet/tabletserver/action_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ func CreateActionInstance(action rules.Action, rule *rules.Rule) (ActionInterfac
actInst, err = &ConcurrencyControlAction{Rule: rule, Action: action}, nil
case rules.QRWasmPlugin:
actInst, err = &WasmPluginAction{Rule: rule, Action: action}, nil
case rules.QRSkipFilter:
actInst, err = &SkipFilterAction{Rule: rule, Action: action}, nil
default:
log.Errorf("unknown action: %v", action)
actInst, err = nil, fmt.Errorf("unknown action: %v", action)
Expand Down
26 changes: 26 additions & 0 deletions go/vt/vttablet/tabletserver/action_stats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package tabletserver

import (
"time"
"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/vt/servenv"
)

type ActionStats struct {
FilterBeforeExecutionTiming *servenv.TimingsWrapper
FilterAfterExecutionTiming *servenv.TimingsWrapper
FilterErrorCounts *stats.CountersWithSingleLabel
FilterQPSRates *stats.Rates
FilterWasmMemorySize *stats.GaugesWithMultiLabels
}

func NewActionStats(exporter *servenv.Exporter) *ActionStats {
stats := &ActionStats{
FilterBeforeExecutionTiming: exporter.NewTimings("FilterBeforeExecution", "Filter before execution timings", "Name"),
FilterAfterExecutionTiming: exporter.NewTimings("FilterAfterExecution", "Filter before execution timings", "Name"),
FilterErrorCounts: exporter.NewCountersWithSingleLabel("FilterErrorCounts", "filter error counts", "Name"),
FilterWasmMemorySize: exporter.NewGaugesWithMultiLabels("FilterWasmMemorySize", "Wasm memory size", []string{"Name", "BeforeOrAfter"}),
}
stats.FilterQPSRates = exporter.NewRates("FilterQps", stats.FilterBeforeExecutionTiming, 15*60/5, 5*time.Second)
return stats
}
4 changes: 4 additions & 0 deletions go/vt/vttablet/tabletserver/query_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ type QueryEngine struct {

// stats
queryCounts, queryTimes, queryErrorCounts, queryRowsAffected, queryRowsReturned *stats.CountersWithMultiLabels
// actionStats for filters
actionStats *ActionStats

// Loggers
accessCheckerLogger *logutil.ThrottledLogger
Expand Down Expand Up @@ -281,6 +283,8 @@ func NewQueryEngine(env tabletenv.Env, se *schema.Engine) *QueryEngine {
qe.queryRowsReturned = env.Exporter().NewCountersWithMultiLabels("QueryRowsReturned", "query rows returned", []string{"Table", "Plan"})
qe.queryErrorCounts = env.Exporter().NewCountersWithMultiLabels("QueryErrorCounts", "query error counts", []string{"Table", "Plan"})

qe.actionStats = NewActionStats(env.Exporter())

env.Exporter().HandleFunc("/debug/ccl", qe.concurrencyController.ServeHTTP)
env.Exporter().HandleFunc("/debug/hotrows", qe.txSerializer.ServeHTTP)
env.Exporter().HandleFunc("/debug/tablet_plans", qe.handleHTTPQueryPlans)
Expand Down
22 changes: 16 additions & 6 deletions go/vt/vttablet/tabletserver/query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,16 @@ func (qre *QueryExecutor) runActionListBeforeExecution() (*sqltypes.Result, erro
return nil, nil
}
for _, a := range qre.matchedActionList {
if !a.GetSkipFlag() {
resp := a.BeforeExecution(qre)
qre.calledActionList = append(qre.calledActionList, a)
if resp.Reply != nil || resp.Err != nil {
return resp.Reply, resp.Err
}
startTime := time.Now()
// execute the filter action
resp := a.BeforeExecution(qre)
qre.tsv.qe.actionStats.FilterBeforeExecutionTiming.Add(a.GetRule().Name, time.Since(startTime))
qre.calledActionList = append(qre.calledActionList, a)
if resp.Err != nil {
qre.tsv.qe.actionStats.FilterErrorCounts.Add(a.GetRule().Name, 1)
}
if resp.Reply != nil || resp.Err != nil {
return resp.Reply, resp.Err
}
}
return nil, nil
Expand All @@ -631,7 +635,13 @@ func (qre *QueryExecutor) runActionListAfterExecution(reply *sqltypes.Result, er

for i := len(qre.calledActionList) - 1; i >= 0; i-- {
a := qre.calledActionList[i]
startTime := time.Now()
// execute the filter action
resp := a.AfterExecution(qre, newReply, newErr)
qre.tsv.qe.actionStats.FilterAfterExecutionTiming.Add(a.GetRule().Name, time.Since(startTime))
if resp.Err != nil {
qre.tsv.qe.actionStats.FilterErrorCounts.Add(a.GetRule().Name, 1)
}
newReply, newErr = resp.Reply, resp.Err
}
return newReply, newErr
Expand Down
15 changes: 12 additions & 3 deletions go/vt/vttablet/tabletserver/query_executor_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ func TestQueryExecutor_runActionListBeforeExecution(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
qre := &QueryExecutor{ctx: ctx}
db := setUpQueryExecutorTest(t)
defer db.Close()
tsv := newTestTabletServer(ctx, noFlags, db)
qre := newTestQueryExecutor(ctx, tsv, "select 1", 0)
qre.matchedActionList = tt.actionList
_, err := qre.runActionListBeforeExecution()
tt.wantErr(t, err, "runActionListBeforeExecution()")
Expand Down Expand Up @@ -129,7 +132,10 @@ func TestQueryExecutor_runActionListAfterExecution(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
qre := &QueryExecutor{ctx: ctx}
db := setUpQueryExecutorTest(t)
defer db.Close()
tsv := newTestTabletServer(ctx, noFlags, db)
qre := newTestQueryExecutor(ctx, tsv, "select 1", 0)
qre.matchedActionList = tt.actionList
qr := &sqltypes.Result{}
var err error
Expand Down Expand Up @@ -162,7 +168,10 @@ func TestQueryExecutor_actions_can_be_skipped(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
qre := &QueryExecutor{ctx: ctx}
db := setUpQueryExecutorTest(t)
defer db.Close()
tsv := newTestTabletServer(ctx, noFlags, db)
qre := newTestQueryExecutor(ctx, tsv, "select 1", 0)
qre.matchedActionList = tt.actionList
qr, err := qre.runActionListBeforeExecution()
tt.wantErr(t, err, "runActionListBeforeExecution()")
Expand Down
Loading

0 comments on commit 455763a

Please sign in to comment.