Skip to content

Commit

Permalink
*: Add Assert to do assertions in test (#47552)
Browse files Browse the repository at this point in the history
close #47551
  • Loading branch information
lcwangchao authored Oct 12, 2023
1 parent 8426ec5 commit 0fd232f
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 2 deletions.
1 change: 1 addition & 0 deletions types/context/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
srcs = ["context.go"],
importpath = "github.com/pingcap/tidb/types/context",
visibility = ["//visibility:public"],
deps = ["//util/intest"],
)

go_test(
Expand Down
12 changes: 11 additions & 1 deletion types/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

package context

import "time"
import (
"time"

"github.com/pingcap/tidb/util/intest"
)

// StrictFlags is a flags with a fields unset and has the most strict behavior.
const StrictFlags Flags = 0
Expand Down Expand Up @@ -109,6 +113,7 @@ type Context struct {

// NewContext creates a new `Context`
func NewContext(flags Flags, loc *time.Location, appendWarningFn func(err error)) Context {
intest.Assert(loc)
return Context{
flags: flags,
loc: loc,
Expand All @@ -130,6 +135,11 @@ func (c *Context) WithFlags(f Flags) Context {

// Location returns the location of the context
func (c *Context) Location() *time.Location {
intest.Assert(c.loc)
if c.loc == nil {
// this should never happen, just make the code safe here.
return time.UTC
}
return c.loc
}

Expand Down
14 changes: 13 additions & 1 deletion util/intest/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "intest",
srcs = [
"assert.go", #keep
"common.go",
"intest.go", #keep
],
importpath = "github.com/pingcap/tidb/util/intest",
visibility = ["//visibility:public"],
)

go_test(
name = "intest_test",
timeout = "short",
srcs = ["assert_test.go"],
flaky = True,
deps = [
":intest",
"@com_github_stretchr_testify//require",
],
)
83 changes: 83 additions & 0 deletions util/intest/assert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build intest

package intest

import (
"fmt"
"reflect"
)

// Assert asserts a condition. It only works in test (intest.InTest == true).
// You can assert a condition like this to assert a variable `foo` is not nil: `assert.Assert(foo != nil)`.
// Or you can pass foo as a parameter directly for simple: `assert.Assert(foo)`
// You can also assert a function that returns a bool: `intest.Assert(func() bool { return foo != nil })`
// If you pass a function without a signature `func() bool`, the function will always panic.
func Assert(cond any, msgAndArgs ...any) {
if InTest {
assert(cond, msgAndArgs...)
}
}

// AssertFunc asserts a function condition
func AssertFunc(fn func() bool, msgAndArgs ...any) {
if InTest {
assert(fn(), msgAndArgs...)
}
}

func assert(cond any, msgAndArgs ...any) {
if !checkAssertObject(cond) {
doPanic(msgAndArgs...)
}
}

func doPanic(msgAndArgs ...any) {
panic(assertionFailedMsg(msgAndArgs...))
}

func assertionFailedMsg(msgAndArgs ...any) string {
if len(msgAndArgs) == 0 {
return "assert failed"
}

msg, ok := msgAndArgs[0].(string)
if !ok {
msg = fmt.Sprintf("%+v", msgAndArgs[0])
}

msg = fmt.Sprintf("assert failed: %s", msg)
return fmt.Sprintf(msg, msgAndArgs[1:]...)
}

func checkAssertObject(obj any) bool {
if obj == nil {
return false
}

value := reflect.ValueOf(obj)
switch value.Kind() {
case reflect.Bool:
return obj.(bool)
case reflect.Func:
panic("you should use `intest.Assert(fn != nil)` to assert a function is not nil, " +
"or use `intest.AssertFunc(fn)` to assert a function's return value is true")
case reflect.Chan, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return !value.IsNil()
default:
return true
}
}
95 changes: 95 additions & 0 deletions util/intest/assert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package intest_test

import (
"testing"

"github.com/pingcap/tidb/util/intest"
"github.com/stretchr/testify/require"
)

type foo struct{}

func TestAssert(t *testing.T) {
require.True(t, intest.InTest)
checkAssert(t, true, true)
checkAssert(t, "", true)
checkAssert(t, "abc", true)
checkAssert(t, 0, true)
checkAssert(t, 123, true)
checkAssert(t, foo{}, true)
checkAssert(t, &foo{}, true)
checkAssert(t, false, false)
checkAssert(t, false, false, "assert failed: msg1", "msg1")
checkAssert(t, false, false, "assert failed: msg2 a b 1", "msg2 %s %s %d", "a", "b", 1)
checkAssert(t, false, false, "assert failed: 123", 123)
checkAssert(t, nil, false)
var f *foo
checkAssert(t, f, false)
checkAssert(t, func() bool { return true }, false, "you should use `intest.Assert(fn != nil)` to assert a function is not nil, or use `intest.AssertFunc(fn)` to assert a function's return value is true")
checkAssert(t, func(_ string) bool { return true }, false, "you should use `intest.Assert(fn != nil)` to assert a function is not nil, or use `intest.AssertFunc(fn)` to assert a function's return value is true")
checkFuncAssert(t, func() bool { panic("inner panic1") }, false, "inner panic1")
checkFuncAssert(t, func() bool { return true }, true)
checkFuncAssert(t, func() bool { return false }, false)
checkFuncAssert(t, func() bool { return false }, false, "assert failed: msg3", "msg3")
checkFuncAssert(t, func() bool { return false }, false, "assert failed: msg4 c d 2", "msg4 %s %s %d", "c", "d", 2)
checkFuncAssert(t, func() bool { panic("inner panic2") }, false, "inner panic2")
}

func checkFuncAssert(t *testing.T, fn func() bool, pass bool, msgAndArgs ...any) {
doCheckAssert(t, intest.AssertFunc, fn, pass, msgAndArgs...)
}

func checkAssert(t *testing.T, cond any, pass bool, msgAndArgs ...any) {
doCheckAssert(t, intest.Assert, cond, pass, msgAndArgs...)
}

func doCheckAssert(t *testing.T, fn any, cond any, pass bool, msgAndArgs ...any) {
expectMsg := "assert failed"
if len(msgAndArgs) > 0 {
expectMsg = msgAndArgs[0].(string)
msgAndArgs = msgAndArgs[1:]
}

if !pass {
defer func() {
r := recover()
require.NotNil(t, r)
require.Equal(t, expectMsg, r)
}()
}

testFn, ok := fn.(func(any, ...any))
if !ok {
if fnAssert, ok := fn.(func(func() bool, ...any)); ok {
testFn = func(any, ...any) {
fnAssert(cond.(func() bool), msgAndArgs...)
}
} else {
require.FailNow(t, "invalid assert function")
}
}

if len(msgAndArgs) == 0 {
testFn(cond)
}

if len(msgAndArgs) == 1 {
testFn(cond, msgAndArgs[0])
}

testFn(cond, msgAndArgs...)
}
8 changes: 8 additions & 0 deletions util/intest/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ package intest

// InTest checks if the code is running in test.
const InTest = false

// Assert is a stub function in release build.
// See the same function in `util/intest/assert.go` for the real implement in test.
func Assert(_ any, _ ...any) {}

// AssertFunc is a stub function in release build.
// See the same function `util/intest/assert.go` for the real implement in test.
func AssertFunc(_ func() bool, _ ...any) {}

0 comments on commit 0fd232f

Please sign in to comment.