Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

errctx, types, sessionctx: avoid memory allocation in HandleError and reduce allocation in creation of statement context #49280

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/errctx/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/errno",
"//pkg/util/context",
"//pkg/util/intest",
"@com_github_pingcap_errors//:errors",
],
Expand All @@ -20,6 +21,7 @@ go_test(
deps = [
":errctx",
"//pkg/types",
"//pkg/util/context",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//require",
"@org_uber_go_multierr//:multierr",
Expand Down
35 changes: 17 additions & 18 deletions pkg/errctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package errctx
import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/errno"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/pingcap/tidb/pkg/util/intest"
)

Expand All @@ -34,14 +35,14 @@ const (

// Context defines how to handle an error
type Context struct {
levelMap [errGroupCount]Level
appendWarningFn func(err error)
levelMap [errGroupCount]Level
warnHandler contextutil.WarnHandler
}

// WithStrictErrGroupLevel makes the context to return the error directly for any kinds of errors.
func (ctx *Context) WithStrictErrGroupLevel() Context {
newCtx := Context{
appendWarningFn: ctx.appendWarningFn,
warnHandler: ctx.warnHandler,
}

return newCtx
Expand All @@ -50,24 +51,24 @@ func (ctx *Context) WithStrictErrGroupLevel() Context {
// WithErrGroupLevel sets a `Level` for an `ErrGroup`
func (ctx *Context) WithErrGroupLevel(eg ErrGroup, l Level) Context {
newCtx := Context{
levelMap: ctx.levelMap,
appendWarningFn: ctx.appendWarningFn,
levelMap: ctx.levelMap,
warnHandler: ctx.warnHandler,
}
newCtx.levelMap[eg] = l

return newCtx
}

// appendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing.
// appendWarning appends the error to warning. If the inner `warnHandler` is nil, do nothing.
func (ctx *Context) appendWarning(err error) {
intest.Assert(ctx.appendWarningFn != nil)
if fn := ctx.appendWarningFn; fn != nil {
// appendWarningFn should always not be nil, check fn != nil here to just make code safe.
fn(err)
intest.Assert(ctx.warnHandler != nil)
if w := ctx.warnHandler; w != nil {
// warnHandler should always not be nil, check fn != nil here to just make code safe.
w.AppendWarning(err)
}
}

// HandleError handles the error according to the context. See the comment of `HandleErrorWithAlias` for detailed logic.
// HandleError handles the error according to the contextutil. See the comment of `HandleErrorWithAlias` for detailed logic.
//
// It also allows using `errors.ErrorGroup`, in this case, it'll handle each error in order, and return the first error
// it founds.
Expand All @@ -92,7 +93,7 @@ func (ctx *Context) HandleError(err error) error {
return ctx.HandleErrorWithAlias(err, err, err)
}

// HandleErrorWithAlias handles the error according to the context.
// HandleErrorWithAlias handles the error according to the contextutil.
// 1. If the `internalErr` is not `"pingcap/errors".Error`, or the error code is not defined in the `errGroupMap`, or the error
// level is set to `LevelError`(0), the `err` will be returned directly.
// 2. If the error level is set to `LevelWarn`, the `warnErr` will be appended as a warning.
Expand Down Expand Up @@ -134,17 +135,15 @@ func (ctx *Context) HandleErrorWithAlias(internalErr error, err error, warnErr e
}

// NewContext creates an error context to handle the errors and warnings
func NewContext(appendWarningFn func(err error)) Context {
intest.Assert(appendWarningFn != nil)
func NewContext(handler contextutil.WarnHandler) Context {
intest.Assert(handler != nil)
return Context{
appendWarningFn: appendWarningFn,
warnHandler: handler,
}
}

// StrictNoWarningContext returns all errors directly, and ignore all errors
var StrictNoWarningContext = NewContext(func(_ error) {
// the error is ignored
})
var StrictNoWarningContext = NewContext(contextutil.IgnoreWarn)

var errGroupMap = make(map[errors.ErrCode]ErrGroup)

Expand Down
5 changes: 3 additions & 2 deletions pkg/errctx/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/errctx"
"github.com/pingcap/tidb/pkg/types"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
"go.uber.org/multierr"
)

func TestContext(t *testing.T) {
var warn error
ctx := errctx.NewContext(func(err error) {
ctx := errctx.NewContext(contextutil.NewFuncWarnHandlerForTest(func(err error) {
warn = err
})
}))

testInternalErr := types.ErrOverflow
testErr := errors.New("error")
Expand Down
2 changes: 2 additions & 0 deletions pkg/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ go_library(
"//pkg/util",
"//pkg/util/arena",
"//pkg/util/chunk",
"//pkg/util/context",
"//pkg/util/cpuprofile",
"//pkg/util/dbterror",
"//pkg/util/dbterror/exeerrors",
Expand Down Expand Up @@ -181,6 +182,7 @@ go_test(
"//pkg/util",
"//pkg/util/arena",
"//pkg/util/chunk",
"//pkg/util/context",
"//pkg/util/dbterror/exeerrors",
"//pkg/util/replayer",
"//pkg/util/sqlkiller",
Expand Down
5 changes: 3 additions & 2 deletions pkg/server/conn_stmt_params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -268,9 +269,9 @@ func TestParseExecArgs(t *testing.T) {
}
for _, tt := range tests {
var warn error
typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, func(err error) {
typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, contextutil.NewFuncWarnHandlerForTest(func(err error) {
warn = err
})
}))
err := decodeAndParse(typectx, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil)
require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err)
require.Truef(t, terror.ErrorEqual(warn, tt.warn), "warn %v", warn)
Expand Down
5 changes: 2 additions & 3 deletions pkg/server/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
contextutil "github.com/pingcap/tidb/pkg/util/context"
)

func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) {
Expand Down Expand Up @@ -89,9 +90,7 @@ func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, er
// TODO: the `BinaryParam` is parsed two times: one in the `Execute` method and one here. It would be better to
// eliminate one of them by storing the parsed result.
typectx := ctx.GetSessionVars().StmtCtx.TypeCtx()
typectx = types.NewContext(typectx.Flags(), typectx.Location(), func(_ error) {
// ignore all warnings
})
typectx = types.NewContext(typectx.Flags(), typectx.Location(), contextutil.IgnoreWarn)
params, _ := param.ExecArgs(typectx, args)
info.executeStmt = &ast.ExecuteStmt{
PrepStmt: prepared,
Expand Down
1 change: 1 addition & 0 deletions pkg/server/internal/column/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ go_test(
"//pkg/server/internal/util",
"//pkg/types",
"//pkg/util/chunk",
"//pkg/util/context",
"@com_github_stretchr_testify//require",
],
)
3 changes: 2 additions & 1 deletion pkg/server/internal/column/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/pingcap/tidb/pkg/server/internal/util"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -182,7 +183,7 @@ func TestDumpTextValue(t *testing.T) {

losAngelesTz, err := time.LoadLocation("America/Los_Angeles")
require.NoError(t, err)
typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, func(err error) {})
typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, contextutil.IgnoreWarn)

time, err := types.ParseTime(typeCtx, "2017-01-05 23:59:59.575601", mysql.TypeDatetime, 0)
require.NoError(t, err)
Expand Down
8 changes: 4 additions & 4 deletions pkg/sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,22 +427,22 @@ type StatementContext struct {
// NewStmtCtx creates a new statement context
func NewStmtCtx() *StatementContext {
sc := &StatementContext{}
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning)
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc)
return sc
}

// NewStmtCtxWithTimeZone creates a new StatementContext with the given timezone
func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext {
intest.AssertNotNil(tz)
sc := &StatementContext{}
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc.AppendWarning)
sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc)
return sc
}

// Reset resets a statement context
func (sc *StatementContext) Reset() {
*sc = StatementContext{
typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning),
typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc),
}
}

Expand Down Expand Up @@ -470,7 +470,7 @@ func (sc *StatementContext) TypeCtx() types.Context {
// ErrCtx returns the error context
// TODO: add a cache to the `ErrCtx` if needed, though it's not a big burden to generate `ErrCtx` everytime.
func (sc *StatementContext) ErrCtx() errctx.Context {
ctx := errctx.NewContext(sc.AppendWarning)
ctx := errctx.NewContext(sc)

if sc.TypeFlags().IgnoreTruncateErr() {
ctx = ctx.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelIgnore)
Expand Down
8 changes: 8 additions & 0 deletions pkg/sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,11 @@ func TestResetStmtCtx(t *testing.T) {
require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level)
require.Equal(t, "err2", warnings[0].Err.Error())
}

func BenchmarkErrCtx(b *testing.B) {
sc := stmtctx.NewStmtCtx()

for i := 0; i < b.N; i++ {
sc.ErrCtx()
}
}
2 changes: 2 additions & 0 deletions pkg/types/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ go_library(
"//pkg/parser/terror",
"//pkg/parser/types",
"//pkg/util/collate",
"//pkg/util/context",
"//pkg/util/dbterror",
"//pkg/util/hack",
"//pkg/util/intest",
Expand Down Expand Up @@ -106,6 +107,7 @@ go_test(
"//pkg/parser/terror",
"//pkg/testkit/testsetup",
"//pkg/util/collate",
"//pkg/util/context",
"//pkg/util/hack",
"@com_github_pingcap_errors//:errors",
"@com_github_stretchr_testify//assert",
Expand Down
44 changes: 19 additions & 25 deletions pkg/types/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package types
import (
"time"

contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/pingcap/tidb/pkg/util/intest"
)

Expand Down Expand Up @@ -197,18 +198,18 @@ func (f Flags) WithCastTimeToYearThroughConcat(flag bool) Flags {

// Context provides the information when converting between different types.
type Context struct {
flags Flags
loc *time.Location
appendWarningFn func(err error)
flags Flags
loc *time.Location
warnHandler contextutil.WarnHandler
}

// NewContext creates a new `Context`
func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context {
intest.Assert(loc != nil && appendWarningFn != nil)
func NewContext(flags Flags, loc *time.Location, handler contextutil.WarnHandler) Context {
intest.Assert(loc != nil && handler != nil)
return Context{
flags: flags,
loc: loc,
appendWarningFn: appendWarningFn,
flags: flags,
loc: loc,
warnHandler: handler,
}
}

Expand Down Expand Up @@ -242,32 +243,25 @@ func (c *Context) Location() *time.Location {
return c.loc
}

// AppendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing.
// AppendWarning appends the error to warning. If the inner `warnHandler` is nil, do nothing.
func (c *Context) AppendWarning(err error) {
intest.Assert(c.appendWarningFn != nil)
if fn := c.appendWarningFn; fn != nil {
// appendWarningFn should always not be nil, check fn != nil here to just make code safe.
fn(err)
intest.Assert(c.warnHandler != nil)
if w := c.warnHandler; w != nil {
// warnHandler should always not be nil, check fn != nil here to just make code safe.
w.AppendWarning(err)
}
}

// AppendWarningFunc returns the inner `appendWarningFn`
func (c *Context) AppendWarningFunc() func(err error) {
return c.appendWarningFn
}

// DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set.
// TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned`
// is only for make the code to be equivalent with the old implement during refactoring.
const DefaultStmtFlags = StrictFlags | FlagAllowNegativeToUnsigned | FlagIgnoreZeroDateErr

// DefaultStmtNoWarningContext is the context with default statement flags without any other special configuration
var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, func(_ error) {
// the error is ignored
})
var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, contextutil.IgnoreWarn)

// StrictContext is the most strict context which returns every error it meets
var StrictContext = NewContext(StrictFlags, time.UTC, func(_ error) {
// this context should never append warnings
// However, the implementation of `types` may still append some warnings. TODO: remove them in the future.
})
//
// this context should never append warnings
// However, the implementation of `types` may still append some warnings. TODO: remove them in the future.
var StrictContext = NewContext(StrictFlags, time.UTC, contextutil.IgnoreWarn)
3 changes: 2 additions & 1 deletion pkg/types/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
"testing"
"time"

contextutil "github.com/pingcap/tidb/pkg/util/context"
"github.com/stretchr/testify/require"
)

func TestWithNewFlags(t *testing.T) {
ctx := NewContext(FlagSkipASCIICheck, time.UTC, func(_ error) {})
ctx := NewContext(FlagSkipASCIICheck, time.UTC, contextutil.IgnoreWarn)
ctx2 := ctx.WithFlags(FlagSkipUTF8Check)
require.Equal(t, FlagSkipASCIICheck, ctx.Flags())
require.Equal(t, FlagSkipUTF8Check, ctx2.Flags())
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ func TestGetValidInt(t *testing.T) {
{"123de", "123", true, true},
}
warnings := &warnStore{}
ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings.AppendWarning)
ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings)
warningCount := 0
for i, tt := range tests {
prefix, err := getValidIntPrefix(ctx, tt.origin, false)
Expand Down
2 changes: 1 addition & 1 deletion pkg/types/datum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ func TestProduceDecWithSpecifiedTp(t *testing.T) {
{"-99.9999", 6, 3, "-100.000", false, true},
}
warnings := &warnStore{}
ctx := NewContext(DefaultStmtFlags, time.UTC, warnings.AppendWarning)
ctx := NewContext(DefaultStmtFlags, time.UTC, warnings)
for _, tt := range tests {
tp := NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(tt.flen).SetDecimal(tt.frac).BuildP()
dec := NewDecFromStringForTest(tt.dec)
Expand Down
Loading