From b4fa4384c28b739c42a9f8a91bf68cb1d568170d Mon Sep 17 00:00:00 2001 From: visualfc Date: Sun, 24 Mar 2024 20:41:23 +0800 Subject: [PATCH] inferFunc: support arguments with typeparam signature --- typeparams.go | 36 +++++++++++++++++++++++++++++------- typeparams_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/typeparams.go b/typeparams.go index ed5e8a98..83ce4255 100644 --- a/typeparams.go +++ b/typeparams.go @@ -318,6 +318,13 @@ func inferFunc(pkg *Package, fn *internal.Elem, sig *types.Signature, targs []ty return nil, err } xlist := make([]*operand, len(args)) + tp := sig.TypeParams() + n := tp.Len() + tparams := make([]*types.TypeParam, n) + for i := 0; i < n; i++ { + tparams[i] = tp.At(i) + } + var targList []*internal.Elem for i, arg := range args { xlist[i] = &operand{ mode: value, @@ -325,18 +332,33 @@ func inferFunc(pkg *Package, fn *internal.Elem, sig *types.Signature, targs []ty typ: arg.Type, val: arg.CVal, } - } - tp := sig.TypeParams() - n := tp.Len() - tparams := make([]*types.TypeParam, n) - for i := 0; i < n; i++ { - tparams[i] = tp.At(i) + if sig, ok := arg.Type.(*types.Signature); ok { + if tp := sig.TypeParams(); tp != nil { + for i := 0; i < n; i++ { + tparams = append(tparams, tp.At(i)) + } + targList = append(targList, arg) + } + } } targs, err = infer(pkg, fn.Val, tparams, targs, sig.Params(), xlist) if err != nil { return nil, err } - return types.Instantiate(pkg.cb.ctxt, sig, targs, true) + typ, err := types.Instantiate(pkg.cb.ctxt, sig, targs[:n], true) + if err == nil { + for _, targ := range targList { + tsig := targ.Type.(*types.Signature) + tp := tsig.TypeParams() + tn := tp.Len() + tt, err := types.Instantiate(pkg.cb.ctxt, tsig, targs[n:n+tn], true) + if err == nil { + targ.Type = tt + } + n += tn + } + } + return typ, err } func inferFuncTargs(pkg *Package, fn *internal.Elem, sig *types.Signature, targs []types.Type) (types.Type, error) { diff --git a/typeparams_test.go b/typeparams_test.go index 045a1c38..414f1648 100644 --- a/typeparams_test.go +++ b/typeparams_test.go @@ -1107,3 +1107,48 @@ func main() { } `) } + +func TestTypeParamsArgumentsSignature(t *testing.T) { + const src = `package foo + +import "fmt" + +func ListMap[T any](ar []T, fn func(v T) T, dump func(i int, v T)) { + for i, v := range ar { + ar[i] = fn(v) + dump(i,ar[i]) + } +} + +func Add[N ~int](x N) N { + return x+x +} + +func Dump[T any](i int, v T) { + fmt.Println(i,v) +} + +var Numbers = []int{1,2,3,4} +` + gt := newGoxTest() + _, err := gt.LoadGoPackage("foo", "foo.go", src) + if err != nil { + t.Fatal(err) + } + pkg := gt.NewPackage("", "main") + fooRef := pkg.Import("foo") + pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg). + Val(fooRef.Ref("ListMap")). + Val(fooRef.Ref("Numbers")).Val(fooRef.Ref("Add")).Val(fooRef.Ref("Dump")). + Call(3).EndStmt(). + End() + + domTest(t, pkg, `package main + +import "foo" + +func main() { + foo.ListMap(foo.Numbers, foo.Add, foo.Dump) +} +`) +}