diff --git a/pkg/errctx/BUILD.bazel b/pkg/errctx/BUILD.bazel index ef0f7368ccd79..6fe6f78f5ee1a 100644 --- a/pkg/errctx/BUILD.bazel +++ b/pkg/errctx/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/errno", + "//pkg/util/context", "//pkg/util/intest", "@com_github_pingcap_errors//:errors", ], @@ -20,6 +21,7 @@ go_test( deps = [ ":errctx", "//pkg/types", + "//pkg/util/context", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@org_uber_go_multierr//:multierr", diff --git a/pkg/errctx/context.go b/pkg/errctx/context.go index 18dd3e5b22c1a..d0b3ab0ab3260 100644 --- a/pkg/errctx/context.go +++ b/pkg/errctx/context.go @@ -17,6 +17,7 @@ package errctx import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/errno" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/pingcap/tidb/pkg/util/intest" ) @@ -34,14 +35,14 @@ const ( // Context defines how to handle an error type Context struct { - levelMap [errGroupCount]Level - appendWarningFn func(err error) + levelMap [errGroupCount]Level + warnHandler contextutil.WarnHandler } // WithStrictErrGroupLevel makes the context to return the error directly for any kinds of errors. func (ctx *Context) WithStrictErrGroupLevel() Context { newCtx := Context{ - appendWarningFn: ctx.appendWarningFn, + warnHandler: ctx.warnHandler, } return newCtx @@ -50,24 +51,24 @@ func (ctx *Context) WithStrictErrGroupLevel() Context { // WithErrGroupLevel sets a `Level` for an `ErrGroup` func (ctx *Context) WithErrGroupLevel(eg ErrGroup, l Level) Context { newCtx := Context{ - levelMap: ctx.levelMap, - appendWarningFn: ctx.appendWarningFn, + levelMap: ctx.levelMap, + warnHandler: ctx.warnHandler, } newCtx.levelMap[eg] = l return newCtx } -// appendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing. +// appendWarning appends the error to warning. If the inner `warnHandler` is nil, do nothing. func (ctx *Context) appendWarning(err error) { - intest.Assert(ctx.appendWarningFn != nil) - if fn := ctx.appendWarningFn; fn != nil { - // appendWarningFn should always not be nil, check fn != nil here to just make code safe. - fn(err) + intest.Assert(ctx.warnHandler != nil) + if w := ctx.warnHandler; w != nil { + // warnHandler should always not be nil, check fn != nil here to just make code safe. + w.AppendWarning(err) } } -// HandleError handles the error according to the context. See the comment of `HandleErrorWithAlias` for detailed logic. +// HandleError handles the error according to the contextutil. See the comment of `HandleErrorWithAlias` for detailed logic. // // It also allows using `errors.ErrorGroup`, in this case, it'll handle each error in order, and return the first error // it founds. @@ -92,7 +93,7 @@ func (ctx *Context) HandleError(err error) error { return ctx.HandleErrorWithAlias(err, err, err) } -// HandleErrorWithAlias handles the error according to the context. +// HandleErrorWithAlias handles the error according to the contextutil. // 1. If the `internalErr` is not `"pingcap/errors".Error`, or the error code is not defined in the `errGroupMap`, or the error // level is set to `LevelError`(0), the `err` will be returned directly. // 2. If the error level is set to `LevelWarn`, the `warnErr` will be appended as a warning. @@ -134,17 +135,15 @@ func (ctx *Context) HandleErrorWithAlias(internalErr error, err error, warnErr e } // NewContext creates an error context to handle the errors and warnings -func NewContext(appendWarningFn func(err error)) Context { - intest.Assert(appendWarningFn != nil) +func NewContext(handler contextutil.WarnHandler) Context { + intest.Assert(handler != nil) return Context{ - appendWarningFn: appendWarningFn, + warnHandler: handler, } } // StrictNoWarningContext returns all errors directly, and ignore all errors -var StrictNoWarningContext = NewContext(func(_ error) { - // the error is ignored -}) +var StrictNoWarningContext = NewContext(contextutil.IgnoreWarn) var errGroupMap = make(map[errors.ErrCode]ErrGroup) diff --git a/pkg/errctx/context_test.go b/pkg/errctx/context_test.go index e78cbb09c7640..0162f8e5e2d73 100644 --- a/pkg/errctx/context_test.go +++ b/pkg/errctx/context_test.go @@ -20,15 +20,16 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/errctx" "github.com/pingcap/tidb/pkg/types" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/stretchr/testify/require" "go.uber.org/multierr" ) func TestContext(t *testing.T) { var warn error - ctx := errctx.NewContext(func(err error) { + ctx := errctx.NewContext(contextutil.NewFuncWarnHandlerForTest(func(err error) { warn = err - }) + })) testInternalErr := types.ErrOverflow testErr := errors.New("error") diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index c6eb16b528ae5..51431ecfbe979 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -78,6 +78,7 @@ go_library( "//pkg/util", "//pkg/util/arena", "//pkg/util/chunk", + "//pkg/util/context", "//pkg/util/cpuprofile", "//pkg/util/dbterror", "//pkg/util/dbterror/exeerrors", @@ -181,6 +182,7 @@ go_test( "//pkg/util", "//pkg/util/arena", "//pkg/util/chunk", + "//pkg/util/context", "//pkg/util/dbterror/exeerrors", "//pkg/util/replayer", "//pkg/util/sqlkiller", diff --git a/pkg/server/conn_stmt_params_test.go b/pkg/server/conn_stmt_params_test.go index 61dd33ffcc935..df6190c6193b7 100644 --- a/pkg/server/conn_stmt_params_test.go +++ b/pkg/server/conn_stmt_params_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/stretchr/testify/require" ) @@ -268,9 +269,9 @@ func TestParseExecArgs(t *testing.T) { } for _, tt := range tests { var warn error - typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, func(err error) { + typectx := types.NewContext(types.DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, contextutil.NewFuncWarnHandlerForTest(func(err error) { warn = err - }) + })) err := decodeAndParse(typectx, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil) require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err) require.Truef(t, terror.ErrorEqual(warn, tt.warn), "warn %v", warn) diff --git a/pkg/server/extension.go b/pkg/server/extension.go index 407c4fa979e8b..4339e08c19d45 100644 --- a/pkg/server/extension.go +++ b/pkg/server/extension.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/types" + contextutil "github.com/pingcap/tidb/pkg/util/context" ) func (cc *clientConn) onExtensionConnEvent(tp extension.ConnEventTp, err error) { @@ -89,9 +90,7 @@ func (cc *clientConn) onExtensionStmtEnd(node interface{}, stmtCtxValid bool, er // TODO: the `BinaryParam` is parsed two times: one in the `Execute` method and one here. It would be better to // eliminate one of them by storing the parsed result. typectx := ctx.GetSessionVars().StmtCtx.TypeCtx() - typectx = types.NewContext(typectx.Flags(), typectx.Location(), func(_ error) { - // ignore all warnings - }) + typectx = types.NewContext(typectx.Flags(), typectx.Location(), contextutil.IgnoreWarn) params, _ := param.ExecArgs(typectx, args) info.executeStmt = &ast.ExecuteStmt{ PrepStmt: prepared, diff --git a/pkg/server/internal/column/BUILD.bazel b/pkg/server/internal/column/BUILD.bazel index fc758df83711f..88fb2d7c97b35 100644 --- a/pkg/server/internal/column/BUILD.bazel +++ b/pkg/server/internal/column/BUILD.bazel @@ -40,6 +40,7 @@ go_test( "//pkg/server/internal/util", "//pkg/types", "//pkg/util/chunk", + "//pkg/util/context", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/server/internal/column/column_test.go b/pkg/server/internal/column/column_test.go index f4fdceab40962..7c224fbf50679 100644 --- a/pkg/server/internal/column/column_test.go +++ b/pkg/server/internal/column/column_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/server/internal/util" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/stretchr/testify/require" ) @@ -182,7 +183,7 @@ func TestDumpTextValue(t *testing.T) { losAngelesTz, err := time.LoadLocation("America/Los_Angeles") require.NoError(t, err) - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, contextutil.IgnoreWarn) time, err := types.ParseTime(typeCtx, "2017-01-05 23:59:59.575601", mysql.TypeDatetime, 0) require.NoError(t, err) diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go index 5e3a929c051b5..4cbacb77afd55 100644 --- a/pkg/sessionctx/stmtctx/stmtctx.go +++ b/pkg/sessionctx/stmtctx/stmtctx.go @@ -427,7 +427,7 @@ type StatementContext struct { // NewStmtCtx creates a new statement context func NewStmtCtx() *StatementContext { sc := &StatementContext{} - sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning) + sc.typeCtx = types.NewContext(types.DefaultStmtFlags, time.UTC, sc) return sc } @@ -435,14 +435,14 @@ func NewStmtCtx() *StatementContext { func NewStmtCtxWithTimeZone(tz *time.Location) *StatementContext { intest.AssertNotNil(tz) sc := &StatementContext{} - sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc.AppendWarning) + sc.typeCtx = types.NewContext(types.DefaultStmtFlags, tz, sc) return sc } // Reset resets a statement context func (sc *StatementContext) Reset() { *sc = StatementContext{ - typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc.AppendWarning), + typeCtx: types.NewContext(types.DefaultStmtFlags, time.UTC, sc), } } @@ -470,7 +470,7 @@ func (sc *StatementContext) TypeCtx() types.Context { // ErrCtx returns the error context // TODO: add a cache to the `ErrCtx` if needed, though it's not a big burden to generate `ErrCtx` everytime. func (sc *StatementContext) ErrCtx() errctx.Context { - ctx := errctx.NewContext(sc.AppendWarning) + ctx := errctx.NewContext(sc) if sc.TypeFlags().IgnoreTruncateErr() { ctx = ctx.WithErrGroupLevel(errctx.ErrGroupTruncate, errctx.LevelIgnore) diff --git a/pkg/sessionctx/stmtctx/stmtctx_test.go b/pkg/sessionctx/stmtctx/stmtctx_test.go index 60b5bdad3ff84..26ec689e304ea 100644 --- a/pkg/sessionctx/stmtctx/stmtctx_test.go +++ b/pkg/sessionctx/stmtctx/stmtctx_test.go @@ -386,3 +386,11 @@ func TestResetStmtCtx(t *testing.T) { require.Equal(t, stmtctx.WarnLevelWarning, warnings[0].Level) require.Equal(t, "err2", warnings[0].Err.Error()) } + +func BenchmarkErrCtx(b *testing.B) { + sc := stmtctx.NewStmtCtx() + + for i := 0; i < b.N; i++ { + sc.ErrCtx() + } +} diff --git a/pkg/types/BUILD.bazel b/pkg/types/BUILD.bazel index 416a8f81070f8..bd29720a792d6 100644 --- a/pkg/types/BUILD.bazel +++ b/pkg/types/BUILD.bazel @@ -52,6 +52,7 @@ go_library( "//pkg/parser/terror", "//pkg/parser/types", "//pkg/util/collate", + "//pkg/util/context", "//pkg/util/dbterror", "//pkg/util/hack", "//pkg/util/intest", @@ -106,6 +107,7 @@ go_test( "//pkg/parser/terror", "//pkg/testkit/testsetup", "//pkg/util/collate", + "//pkg/util/context", "//pkg/util/hack", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//assert", diff --git a/pkg/types/context.go b/pkg/types/context.go index 8f894d35942c2..b515c78fddc72 100644 --- a/pkg/types/context.go +++ b/pkg/types/context.go @@ -17,6 +17,7 @@ package types import ( "time" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/pingcap/tidb/pkg/util/intest" ) @@ -197,18 +198,18 @@ func (f Flags) WithCastTimeToYearThroughConcat(flag bool) Flags { // Context provides the information when converting between different types. type Context struct { - flags Flags - loc *time.Location - appendWarningFn func(err error) + flags Flags + loc *time.Location + warnHandler contextutil.WarnHandler } // NewContext creates a new `Context` -func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context { - intest.Assert(loc != nil && appendWarningFn != nil) +func NewContext(flags Flags, loc *time.Location, handler contextutil.WarnHandler) Context { + intest.Assert(loc != nil && handler != nil) return Context{ - flags: flags, - loc: loc, - appendWarningFn: appendWarningFn, + flags: flags, + loc: loc, + warnHandler: handler, } } @@ -242,32 +243,25 @@ func (c *Context) Location() *time.Location { return c.loc } -// AppendWarning appends the error to warning. If the inner `appendWarningFn` is nil, do nothing. +// AppendWarning appends the error to warning. If the inner `warnHandler` is nil, do nothing. func (c *Context) AppendWarning(err error) { - intest.Assert(c.appendWarningFn != nil) - if fn := c.appendWarningFn; fn != nil { - // appendWarningFn should always not be nil, check fn != nil here to just make code safe. - fn(err) + intest.Assert(c.warnHandler != nil) + if w := c.warnHandler; w != nil { + // warnHandler should always not be nil, check fn != nil here to just make code safe. + w.AppendWarning(err) } } -// AppendWarningFunc returns the inner `appendWarningFn` -func (c *Context) AppendWarningFunc() func(err error) { - return c.appendWarningFn -} - // DefaultStmtFlags is the default flags for statement context with the flag `FlagAllowNegativeToUnsigned` set. // TODO: make DefaultStmtFlags to be equal with StrictFlags, and setting flag `FlagAllowNegativeToUnsigned` // is only for make the code to be equivalent with the old implement during refactoring. const DefaultStmtFlags = StrictFlags | FlagAllowNegativeToUnsigned | FlagIgnoreZeroDateErr // DefaultStmtNoWarningContext is the context with default statement flags without any other special configuration -var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, func(_ error) { - // the error is ignored -}) +var DefaultStmtNoWarningContext = NewContext(DefaultStmtFlags, time.UTC, contextutil.IgnoreWarn) // StrictContext is the most strict context which returns every error it meets -var StrictContext = NewContext(StrictFlags, time.UTC, func(_ error) { - // this context should never append warnings - // However, the implementation of `types` may still append some warnings. TODO: remove them in the future. -}) +// +// this context should never append warnings +// However, the implementation of `types` may still append some warnings. TODO: remove them in the future. +var StrictContext = NewContext(StrictFlags, time.UTC, contextutil.IgnoreWarn) diff --git a/pkg/types/context_test.go b/pkg/types/context_test.go index 0bd8d5cbcc811..38d1874ab5024 100644 --- a/pkg/types/context_test.go +++ b/pkg/types/context_test.go @@ -20,11 +20,12 @@ import ( "testing" "time" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/stretchr/testify/require" ) func TestWithNewFlags(t *testing.T) { - ctx := NewContext(FlagSkipASCIICheck, time.UTC, func(_ error) {}) + ctx := NewContext(FlagSkipASCIICheck, time.UTC, contextutil.IgnoreWarn) ctx2 := ctx.WithFlags(FlagSkipUTF8Check) require.Equal(t, FlagSkipASCIICheck, ctx.Flags()) require.Equal(t, FlagSkipUTF8Check, ctx2.Flags()) diff --git a/pkg/types/convert_test.go b/pkg/types/convert_test.go index 86d9893edff21..41dd4753d315c 100644 --- a/pkg/types/convert_test.go +++ b/pkg/types/convert_test.go @@ -864,7 +864,7 @@ func TestGetValidInt(t *testing.T) { {"123de", "123", true, true}, } warnings := &warnStore{} - ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings.AppendWarning) + ctx := NewContext(DefaultStmtFlags.WithTruncateAsWarning(true), time.UTC, warnings) warningCount := 0 for i, tt := range tests { prefix, err := getValidIntPrefix(ctx, tt.origin, false) diff --git a/pkg/types/datum_test.go b/pkg/types/datum_test.go index 53dc48ef05c69..5abc09e2b9c56 100644 --- a/pkg/types/datum_test.go +++ b/pkg/types/datum_test.go @@ -638,7 +638,7 @@ func TestProduceDecWithSpecifiedTp(t *testing.T) { {"-99.9999", 6, 3, "-100.000", false, true}, } warnings := &warnStore{} - ctx := NewContext(DefaultStmtFlags, time.UTC, warnings.AppendWarning) + ctx := NewContext(DefaultStmtFlags, time.UTC, warnings) for _, tt := range tests { tp := NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(tt.flen).SetDecimal(tt.frac).BuildP() dec := NewDecFromStringForTest(tt.dec) diff --git a/pkg/types/format_test.go b/pkg/types/format_test.go index a6caefdbe4911..8db9d57d858f4 100644 --- a/pkg/types/format_test.go +++ b/pkg/types/format_test.go @@ -20,11 +20,12 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/types" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/stretchr/testify/require" ) func TestTimeFormatMethod(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) tblDate := []struct { Input string Format string @@ -78,7 +79,7 @@ func TestTimeFormatMethod(t *testing.T) { } func TestStrToDate(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) tests := []struct { input string format string diff --git a/pkg/types/time_test.go b/pkg/types/time_test.go index bff3fd3f8c8e5..4910ae075b655 100644 --- a/pkg/types/time_test.go +++ b/pkg/types/time_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" "github.com/pingcap/tidb/pkg/types" + contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/stretchr/testify/require" ) @@ -60,9 +61,9 @@ func TestTimeEncoding(t *testing.T) { func TestDateTime(t *testing.T) { var warnings []error - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) { + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.NewFuncWarnHandlerForTest(func(err error) { warnings = append(warnings, err) - }) + })) table := []struct { Input string Expect string @@ -209,7 +210,7 @@ func TestTimestamp(t *testing.T) { } func TestDate(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) table := []struct { Input string Expect string @@ -303,7 +304,7 @@ func TestDate(t *testing.T) { } func TestTime(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) table := []struct { Input string Expect string @@ -449,7 +450,7 @@ func TestDurationAdd(t *testing.T) { } func TestDurationSub(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) table := []struct { Input string Fsp int @@ -472,7 +473,7 @@ func TestDurationSub(t *testing.T) { } func TestTimeFsp(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) table := []struct { Input string Fsp int @@ -701,7 +702,7 @@ func TestParseTimeFromNum(t *testing.T) { func TestToNumber(t *testing.T) { losAngelesTz, err := time.LoadLocation("America/Los_Angeles") require.NoError(t, err) - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, contextutil.IgnoreWarn) tblDateTime := []struct { Input string Fsp int @@ -773,7 +774,7 @@ func TestToNumber(t *testing.T) { } func TestParseTimeFromFloatString(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) table := []struct { Input string Fsp int @@ -840,7 +841,7 @@ func TestParseFrac(t *testing.T) { } func TestRoundFrac(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) tbl := []struct { Input string Fsp int @@ -931,7 +932,7 @@ func TestRoundFrac(t *testing.T) { func TestConvert(t *testing.T) { losAngelesTz, _ := time.LoadLocation("America/Los_Angeles") - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), losAngelesTz, contextutil.IgnoreWarn) tbl := []struct { Input string Fsp int @@ -1774,7 +1775,7 @@ func TestIsDateFormat(t *testing.T) { } func TestParseTimeFromInt64(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) input := int64(20190412140000) output, err := types.ParseTimeFromInt64(typeCtx, input) @@ -1791,7 +1792,7 @@ func TestParseTimeFromInt64(t *testing.T) { } func TestParseTimeFromFloat64(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) cases := []struct { f float64 @@ -1834,7 +1835,7 @@ func TestParseTimeFromFloat64(t *testing.T) { } func TestParseTimeFromDecimal(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) cases := []struct { d *types.MyDecimal @@ -1909,7 +1910,7 @@ func TestGetFracIndex(t *testing.T) { } func TestTimeOverflow(t *testing.T) { - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreZeroInDate(true), time.UTC, contextutil.IgnoreWarn) table := []struct { Input string Output bool @@ -2001,7 +2002,7 @@ func TestCheckMonthDay(t *testing.T) { {types.FromDate(3200, 2, 29, 0, 0, 0, 0), true}, } - typeCtx := types.NewContext(types.StrictFlags.WithIgnoreInvalidDateErr(false), time.UTC, func(err error) {}) + typeCtx := types.NewContext(types.StrictFlags.WithIgnoreInvalidDateErr(false), time.UTC, contextutil.IgnoreWarn) for _, tt := range dates { v := types.NewTime(tt.date, mysql.TypeDate, types.DefaultFsp) @@ -2166,7 +2167,7 @@ func TestParseWithTimezone(t *testing.T) { }, } for ith, ca := range cases { - v, err := types.ParseTime(types.NewContext(types.StrictFlags, ca.sysTZ, func(err error) {}), ca.lit, mysql.TypeTimestamp, ca.fsp) + v, err := types.ParseTime(types.NewContext(types.StrictFlags, ca.sysTZ, contextutil.IgnoreWarn), ca.lit, mysql.TypeTimestamp, ca.fsp) require.NoErrorf(t, err, "tidb time parse misbehaved on %d", ith) if err != nil { continue @@ -2209,9 +2210,9 @@ func TestDurationConvertToYearFromNow(t *testing.T) { } for _, c := range cases { - ctx := types.NewContext(types.StrictFlags.WithCastTimeToYearThroughConcat(c.throughStr), c.sysTZ, func(_ error) { + ctx := types.NewContext(types.StrictFlags.WithCastTimeToYearThroughConcat(c.throughStr), c.sysTZ, contextutil.NewFuncWarnHandlerForTest(func(_ error) { require.Fail(t, "shouldn't append warninng") - }) + })) now, err := time.Parse(time.RFC3339, c.nowLit) require.NoError(t, err) diff --git a/pkg/util/context/BUILD.bazel b/pkg/util/context/BUILD.bazel new file mode 100644 index 0000000000000..a879bb3b0ac9e --- /dev/null +++ b/pkg/util/context/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "context", + srcs = ["warn.go"], + importpath = "github.com/pingcap/tidb/pkg/util/context", + visibility = ["//visibility:public"], +) diff --git a/pkg/util/context/warn.go b/pkg/util/context/warn.go new file mode 100644 index 0000000000000..9b4b53d600d7b --- /dev/null +++ b/pkg/util/context/warn.go @@ -0,0 +1,44 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +// WarnHandler provides a function to add a warning. +// Using interface rather than a simple function/closure can avoid memory allocation in some cases. +// See https://github.com/pingcap/tidb/issues/49277 +type WarnHandler interface { + // AppendWarning appends a warning + AppendWarning(err error) +} + +type ignoreWarn struct{} + +func (*ignoreWarn) AppendWarning(_ error) {} + +// IgnoreWarn is WarnHandler which does nothing +var IgnoreWarn WarnHandler = &ignoreWarn{} + +type funcWarnHandler struct { + fn func(err error) +} + +func (r *funcWarnHandler) AppendWarning(err error) { + r.fn(err) +} + +// NewFuncWarnHandlerForTest creates a `WarnHandler` which will use the function to handle warn +// To have a better performance, it's not suggested to use this function in production. +func NewFuncWarnHandlerForTest(fn func(err error)) WarnHandler { + return &funcWarnHandler{fn} +}