Skip to content

Commit

Permalink
avoid memory allocation in context
Browse files Browse the repository at this point in the history
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
  • Loading branch information
YangKeao committed Dec 8, 2023
1 parent 90e272a commit 3780ce7
Show file tree
Hide file tree
Showing 19 changed files with 143 additions and 78 deletions.
2 changes: 2 additions & 0 deletions pkg/errctx/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/errno",
"//pkg/util/context",
"//pkg/util/intest",
"@com_github_pingcap_errors//:errors",
],
Expand All @@ -20,6 +21,7 @@ go_test(
deps = [
":errctx",
"//pkg/types",
"//pkg/util/context",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//require",
"@org_uber_go_multierr//:multierr",
Expand Down
35 changes: 17 additions & 18 deletions pkg/errctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package errctx
import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/errno"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/pingcap/tidb/pkg/util/intest"
)

Expand All @@ -34,14 +35,14 @@ const (

// Context defines how to handle an error
type Context struct {
levelMap [errGroupCount]Level
appendWarningFn func(err error)
levelMap [errGroupCount]Level
warnHandler contextutil.WarnHandler
}

// WithStrictErrGroupLevel makes the context to return the error directly for any kinds of errors.
func (ctx *Context) WithStrictErrGroupLevel() Context {
newCtx := Context{
appendWarningFn: ctx.appendWarningFn,
warnHandler: ctx.warnHandler,
}

return newCtx
Expand All @@ -50,24 +51,24 @@ func (ctx *Context) WithStrictErrGroupLevel() Context {
// WithErrGroupLevel sets a `Level` for an `ErrGroup`
func (ctx *Context) WithErrGroupLevel(eg ErrGroup, l Level) Context {
newCtx := Context{
levelMap: ctx.levelMap,
appendWarningFn: ctx.appendWarningFn,
levelMap: ctx.levelMap,
warnHandler: ctx.warnHandler,
}
newCtx.levelMap[eg] = l

return newCtx
}

// appendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing.
// appendWarning appends the error to warning. If the inner `warnHandler` is nil, do nothing.
func (ctx *Context) appendWarning(err error) {
intest.Assert(ctx.appendWarningFn != nil)
if fn := ctx.appendWarningFn; fn != nil {
// appendWarningFn should always not be nil, check fn != nil here to just make code safe.
fn(err)
intest.Assert(ctx.warnHandler != nil)
if w := ctx.warnHandler; w != nil {
// warnHandler should always not be nil, check fn != nil here to just make code safe.
w.AppendWarning(err)
}
}

// HandleError handles the error according to the context. See the comment of `HandleErrorWithAlias` for detailed logic.
// HandleError handles the error according to the contextutil. See the comment of `HandleErrorWithAlias` for detailed logic.
//
// It also allows using `errors.ErrorGroup`, in this case, it'll handle each error in order, and return the first error
// it founds.
Expand All @@ -92,7 +93,7 @@ func (ctx *Context) HandleError(err error) error {
return ctx.HandleErrorWithAlias(err, err, err)
}

// HandleErrorWithAlias handles the error according to the context.
// HandleErrorWithAlias handles the error according to the contextutil.
// 1. If the `internalErr` is not `"pingcap/errors".Error`, or the error code is not defined in the `errGroupMap`, or the error
// level is set to `LevelError`(0), the `err` will be returned directly.
// 2. If the error level is set to `LevelWarn`, the `warnErr` will be appended as a warning.
Expand Down Expand Up @@ -134,17 +135,15 @@ func (ctx *Context) HandleErrorWithAlias(internalErr error, err error, warnErr e
}

// NewContext creates an error context to handle the errors and warnings
func NewContext(appendWarningFn func(err error)) Context {
intest.Assert(appendWarningFn != nil)
func NewContext(handler contextutil.WarnHandler) Context {
intest.Assert(handler != nil)
return Context{
appendWarningFn: appendWarningFn,
warnHandler: handler,
}
}

// StrictNoWarningContext returns all errors directly, and ignore all errors
var StrictNoWarningContext = NewContext(func(_ error) {
// the error is ignored
})
var StrictNoWarningContext = NewContext(contextutil.IgnoreWarn)

var errGroupMap = make(map[errors.ErrCode]ErrGroup)

Expand Down
5 changes: 3 additions & 2 deletions pkg/errctx/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/types"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
"go.uber.org/multierr"
)

func TestContext(t *testing.T) {
var warn error
ctx := errctx.NewContext(func(err error) {
ctx := errctx.NewContext(contextutil.NewFuncWarnHandlerForTest(func(err error) {
warn = err
})
}))

testInternalErr := types.ErrOverflow
testErr := errors.New("error")
Expand Down
2 changes: 2 additions & 0 deletions pkg/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ go_library(
"//pkg/util",
"//pkg/util/arena",
"//pkg/util/chunk",
"//pkg/util/context",
"//pkg/util/cpuprofile",
"//pkg/util/dbterror",
"//pkg/util/dbterror/exeerrors",
Expand Down Expand Up @@ -181,6 +182,7 @@ go_test(
"//pkg/util",
"//pkg/util/arena",
"//pkg/util/chunk",
"//pkg/util/context",
"//pkg/util/dbterror/exeerrors",
"//pkg/util/replayer",
"//pkg/util/sqlkiller",
Expand Down
5 changes: 3 additions & 2 deletions pkg/server/conn_stmt_params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -268,9 +269,9 @@ func TestParseExecArgs(t *testing.T) {
}
for _, tt := range tests {
var warn error
typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, func(err error) {
typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, contextutil.NewFuncWarnHandlerForTest(func(err error) {
warn = err
})
}))
err := decodeAndParse(typectx, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil)
require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err)
require.Truef(t, terror.ErrorEqual(warn, tt.warn), "warn %v", warn)
Expand Down
5 changes: 2 additions & 3 deletions pkg/server/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
contextutil "github.com/pingcap/tidb/pkg/util/context"
)

func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) {
Expand Down Expand Up @@ -89,9 +90,7 @@ func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, er
// TODO: the `BinaryParam` is parsed two times: one in the `Execute` method and one here. It would be better to
// eliminate one of them by storing the parsed result.
typectx := ctx.GetSessionVars().StmtCtx.TypeCtx()
typectx = types.NewContext(typectx.Flags(), typectx.Location(), func(_ error) {
// ignore all warnings
})
typectx = types.NewContext(typectx.Flags(), typectx.Location(), contextutil.IgnoreWarn)
params, _ := param.ExecArgs(typectx, args)
info.executeStmt = &ast.ExecuteStmt{
PrepStmt: prepared,
Expand Down
1 change: 1 addition & 0 deletions pkg/server/internal/column/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ go_test(
"//pkg/server/internal/util",
"//pkg/types",
"//pkg/util/chunk",
"//pkg/util/context",
"@com_github_stretchr_testify//require",
],
)
3 changes: 2 additions & 1 deletion pkg/server/internal/column/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/pingcap/tidb/pkg/server/internal/util"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -182,7 +183,7 @@ func TestDumpTextValue(t *testing.T) {

losAngelesTz, err := time.LoadLocation("America/Los_Angeles")
require.NoError(t, err)
typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, func(err error) {})
typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, contextutil.IgnoreWarn)

time, err := types.ParseTime(typeCtx, "2017-01-05 23:59:59.575601", mysql.TypeDatetime, 0)
require.NoError(t, err)
Expand Down
8 changes: 4 additions & 4 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,22 +427,22 @@ type StatementContext struct {
// NewStmtCtx creates a new statement context
func NewStmtCtx() *StatementContext {
sc := &StatementContext{}
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning)
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc)
return sc
}

// NewStmtCtxWithTimeZone creates a new StatementContext with the given timezone
func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext {
intest.AssertNotNil(tz)
sc := &StatementContext{}
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc.AppendWarning)
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc)
return sc
}

// Reset resets a statement context
func (sc *StatementContext) Reset() {
*sc = StatementContext{
typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning),
typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc),
}
}

Expand Down Expand Up @@ -470,7 +470,7 @@ func (sc *StatementContext) TypeCtx() types.Context {
// ErrCtx returns the error context
// TODO: add a cache to the `ErrCtx` if needed, though it's not a big burden to generate `ErrCtx` everytime.
func (sc *StatementContext) ErrCtx() errctx.Context {
ctx := errctx.NewContext(sc.AppendWarning)
ctx := errctx.NewContext(sc)

if sc.TypeFlags().IgnoreTruncateErr() {
ctx = ctx.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelIgnore)
Expand Down
8 changes: 8 additions & 0 deletions pkg/sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,11 @@ func TestResetStmtCtx(t *testing.T) {
require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level)
require.Equal(t, "err2", warnings[0].Err.Error())
}

func BenchmarkErrCtx(b *testing.B) {
sc := stmtctx.NewStmtCtx()

for i := 0; i < b.N; i++ {
sc.ErrCtx()
}
}
2 changes: 2 additions & 0 deletions pkg/types/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ go_library(
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/util/collate",
"//pkg/util/context",
"//pkg/util/dbterror",
"//pkg/util/hack",
"//pkg/util/intest",
Expand Down Expand Up @@ -106,6 +107,7 @@ go_test(
"//pkg/parser/terror",
"//pkg/testkit/testsetup",
"//pkg/util/collate",
"//pkg/util/context",
"//pkg/util/hack",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//assert",
Expand Down
44 changes: 19 additions & 25 deletions pkg/types/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package types
import (
"time"

contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/pingcap/tidb/pkg/util/intest"
)

Expand Down Expand Up @@ -197,18 +198,18 @@ func (f Flags) WithCastTimeToYearThroughConcat(flag bool) Flags {

// Context provides the information when converting between different types.
type Context struct {
flags Flags
loc *time.Location
appendWarningFn func(err error)
flags Flags
loc *time.Location
warnHandler contextutil.WarnHandler
}

// NewContext creates a new `Context`
func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context {
intest.Assert(loc != nil && appendWarningFn != nil)
func NewContext(flags Flags, loc *time.Location, handler contextutil.WarnHandler) Context {
intest.Assert(loc != nil && handler != nil)
return Context{
flags: flags,
loc: loc,
appendWarningFn: appendWarningFn,
flags: flags,
loc: loc,
warnHandler: handler,
}
}

Expand Down Expand Up @@ -242,32 +243,25 @@ func (c *Context) Location() *time.Location {
return c.loc
}

// AppendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing.
// AppendWarning appends the error to warning. If the inner `warnHandler` is nil, do nothing.
func (c *Context) AppendWarning(err error) {
intest.Assert(c.appendWarningFn != nil)
if fn := c.appendWarningFn; fn != nil {
// appendWarningFn should always not be nil, check fn != nil here to just make code safe.
fn(err)
intest.Assert(c.warnHandler != nil)
if w := c.warnHandler; w != nil {
// warnHandler should always not be nil, check fn != nil here to just make code safe.
w.AppendWarning(err)
}
}

// AppendWarningFunc returns the inner `appendWarningFn`
func (c *Context) AppendWarningFunc() func(err error) {
return c.appendWarningFn
}

// DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set.
// TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned`
// is only for make the code to be equivalent with the old implement during refactoring.
const DefaultStmtFlags = StrictFlags | FlagAllowNegativeToUnsigned | FlagIgnoreZeroDateErr

// DefaultStmtNoWarningContext is the context with default statement flags without any other special configuration
var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, func(_ error) {
// the error is ignored
})
var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, contextutil.IgnoreWarn)

// StrictContext is the most strict context which returns every error it meets
var StrictContext = NewContext(StrictFlags, time.UTC, func(_ error) {
// this context should never append warnings
// However, the implementation of `types` may still append some warnings. TODO: remove them in the future.
})
//
// this context should never append warnings
// However, the implementation of `types` may still append some warnings. TODO: remove them in the future.
var StrictContext = NewContext(StrictFlags, time.UTC, contextutil.IgnoreWarn)
3 changes: 2 additions & 1 deletion pkg/types/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
"testing"
"time"

contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)

func TestWithNewFlags(t *testing.T) {
ctx := NewContext(FlagSkipASCIICheck, time.UTC, func(_ error) {})
ctx := NewContext(FlagSkipASCIICheck, time.UTC, contextutil.IgnoreWarn)
ctx2 := ctx.WithFlags(FlagSkipUTF8Check)
require.Equal(t, FlagSkipASCIICheck, ctx.Flags())
require.Equal(t, FlagSkipUTF8Check, ctx2.Flags())
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ func TestGetValidInt(t *testing.T) {
{"123de", "123", true, true},
}
warnings := &warnStore{}
ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings.AppendWarning)
ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings)
warningCount := 0
for i, tt := range tests {
prefix, err := getValidIntPrefix(ctx, tt.origin, false)
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/datum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ func TestProduceDecWithSpecifiedTp(t *testing.T) {
{"-99.9999", 6, 3, "-100.000", false, true},
}
warnings := &warnStore{}
ctx := NewContext(DefaultStmtFlags, time.UTC, warnings.AppendWarning)
ctx := NewContext(DefaultStmtFlags, time.UTC, warnings)
for _, tt := range tests {
tp := NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(tt.flen).SetDecimal(tt.frac).BuildP()
dec := NewDecFromStringForTest(tt.dec)
Expand Down
Loading

0 comments on commit 3780ce7

Please sign in to comment.