-
-
Notifications
You must be signed in to change notification settings - Fork 18
/
each.go
164 lines (133 loc) · 3.47 KB
/
each.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package got
import (
"reflect"
"runtime/debug"
)
// Marker types recognized by Each. Declare one as the first parameter after
// the receiver of a Ctx method to change how Each schedules that method; Each
// fills such parameters with the zero value when it calls the method.
type (
	// Only focuses a test method: when at least one method of Ctx takes Only
	// as its first parameter, Each runs only those methods.
	Only struct{}

	// Skip skips a test method: Each calls SkipNow on the method's subtest
	// before building its Ctx.
	Skip struct{}
)
// Each runs every exported method Fn of type Ctx as a subtest of t, using the
// Run method found on t via reflection.
//
// iteratee is either a Ctx struct value or a function of the form
//
//	iteratee(t Testable) (ctx Ctx)
//
// and every Fn is invoked on a fresh Ctx like:
//
//	ctx.Fn()
//
// When iteratee is a struct, each subtest receives a copy of it whose G field
// (if present and settable) is set to New called with the subtest's t.
// Methods whose names match a method of an embedded field are not run.
//
// The returned count is incremented once per subtest, after the Skip check
// and before the subtest's Ctx is built.
func Each(t Testable, iteratee interface{}) (count int) {
	t.Helper()

	iterFn := normalizeIteratee(t, iteratee)
	methods := filterMethods(iterFn.Type().Out(0))

	run := reflect.ValueOf(t).MethodByName("Run")
	callbackType := run.Type().In(1)

	for i := range methods {
		// Each closure below may run on another goroutine, so it must capture
		// its own copy of the method rather than a shared loop variable.
		method := methods[i]

		subtest := func(args []reflect.Value) []reflect.Value {
			sub := args[0].Interface().(Testable)
			doSkip(sub, method)
			count++
			ctx := iterFn.Call(args)[0]
			return callMethod(sub, method, ctx)
		}

		run.Call([]reflect.Value{
			reflect.ValueOf(method.Name),
			reflect.MakeFunc(callbackType, subtest),
		})
	}

	return count
}
// normalizeIteratee validates iteratee and returns it as a reflect.Value of a
// one-argument function that, given a subtest's Testable, produces a Ctx.
//
// Accepted forms:
//   - a func with exactly one parameter and one result, whose parameter either
//     accepts t's dynamic type (for example got.Testable or *testing.T) or is a
//     pointer type whose element, when re-pointered, implements Testable;
//   - a struct value; the returned function copies it on every call and, on a
//     best-effort basis, sets its G field to New(sub).
//
// Any other iteratee (including nil) is reported with t.Logf followed by
// t.FailNow.
func normalizeIteratee(t Testable, iteratee interface{}) reflect.Value {
	t.Helper()

	if iteratee == nil {
		t.Logf("iteratee shouldn't be nil")
		t.FailNow()
	}

	itVal := reflect.ValueOf(iteratee)
	itType := itVal.Type()
	fail := true

	switch itType.Kind() {
	case reflect.Func:
		if itType.NumIn() != 1 || itType.NumOut() != 1 {
			break
		}
		in := itType.In(0)
		// The documented form func(got.Testable) Ctx has an interface
		// parameter, and Type.Elem panics on interface types, so the probe
		// below alone would always reject it. Accept any parameter type that
		// t's dynamic type is assignable to (the same type the struct branch
		// uses for its generated function's parameter).
		if reflect.TypeOf(t).AssignableTo(in) {
			fail = false
			break
		}
		// Fallback kept for backward compatibility: any *X where *X
		// implements Testable.
		try(func() {
			_ = reflect.New(in.Elem()).Interface().(Testable)
			fail = false
		})
	case reflect.Struct:
		fnType := reflect.FuncOf([]reflect.Type{reflect.TypeOf(t)}, []reflect.Type{itType}, false)
		structVal := itVal
		itVal = reflect.MakeFunc(fnType, func(args []reflect.Value) []reflect.Value {
			sub := args[0].Interface().(Testable)
			as := reflect.ValueOf(New(sub))
			c := reflect.New(itType).Elem()
			c.Set(structVal)
			// Best effort: a missing or mistyped G field is silently ignored.
			try(func() { c.FieldByName("G").Set(as) })
			return []reflect.Value{c}
		})
		fail = false
	}

	if fail {
		t.Logf("iteratee <%v> should be a struct or <func(got.Testable) Ctx>", itType)
		t.FailNow()
	}

	return itVal
}
// callMethod invokes method with receiver as its receiver and reports any
// panic it raises through t instead of letting it escape.
//
// Every parameter after the receiver (for example the Only or Skip markers)
// receives the zero value of its type. A recovered panic is logged together
// with a stack trace from debug.Stack and marks t as failed via t.Fail. The
// returned slice never holds any values, which suits a callback with no
// results.
func callMethod(t Testable, method reflect.Method, receiver reflect.Value) []reflect.Value {
	sig := method.Type

	in := make([]reflect.Value, 0, sig.NumIn())
	in = append(in, receiver)
	for i := 1; i < sig.NumIn(); i++ {
		in = append(in, reflect.Zero(sig.In(i)))
	}

	defer func() {
		r := recover()
		if r == nil {
			return
		}
		t.Logf("[panic] %v\n%s", r, debug.Stack())
		t.Fail()
	}()

	method.Func.Call(in)
	return []reflect.Value{}
}
// filterMethods selects the methods of typ that Each should run as subtests.
//
// Methods whose names also appear in the method set of an embedded
// (anonymous) field are excluded, so methods promoted from an embedded type
// are not run as tests. If at least one remaining method declares Only as its
// first parameter after the receiver, only those methods are returned;
// otherwise all remaining methods are returned, in reflect's method order.
//
// typ is normally a struct type. A pointer to a struct is also accepted: its
// fields are read from the element type, and embedded value fields are matched
// against their pointer method sets, since a pointer Ctx promotes those too.
// Any other type has no fields to inspect and contributes all of its methods.
func filterMethods(typ reflect.Type) []reflect.Method {
	onlyType := reflect.TypeOf(Only{})

	// Fields live on the struct type; methods are still taken from typ itself
	// so a pointer Ctx keeps its pointer-receiver methods.
	structType := typ
	if structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}

	embedded := map[string]struct{}{}
	if structType.Kind() == reflect.Struct {
		for i := 0; i < structType.NumField(); i++ {
			field := structType.Field(i)
			if !field.Anonymous {
				continue
			}
			fieldType := field.Type
			if typ.Kind() == reflect.Ptr && fieldType.Kind() != reflect.Ptr && fieldType.Kind() != reflect.Interface {
				fieldType = reflect.PtrTo(fieldType)
			}
			for j := 0; j < fieldType.NumMethod(); j++ {
				embedded[fieldType.Method(j).Name] = struct{}{}
			}
		}
	}

	methods := make([]reflect.Method, 0, typ.NumMethod())
	var onlyList []reflect.Method
	for i := 0; i < typ.NumMethod(); i++ {
		method := typ.Method(i)
		if _, has := embedded[method.Name]; has {
			continue
		}
		if method.Type.NumIn() > 1 && method.Type.In(1) == onlyType {
			onlyList = append(onlyList, method)
		}
		methods = append(methods, method)
	}

	if len(onlyList) > 0 {
		return onlyList
	}
	return methods
}
// doSkip calls t.SkipNow when method declares Skip as its first parameter
// after the receiver; otherwise it does nothing.
func doSkip(t Testable, method reflect.Method) {
	sig := method.Type
	if sig.NumIn() <= 1 {
		return
	}
	if sig.In(1) != reflect.TypeOf(Skip{}) {
		return
	}
	t.SkipNow()
}
// try calls fn and silently discards any panic it raises. It lets callers
// probe reflect operations (type assertions, field sets) that panic when a
// value has an unexpected shape.
func try(fn func()) {
	defer func() { _ = recover() }()
	fn()
}