diff --git a/.gitignore b/.gitignore index fa871d2..1f20212 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,5 @@ _testmain.go dist/ .DS_Store + +**/vendor \ No newline at end of file diff --git a/go.mod b/go.mod index c3c8c86..8af045a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/matryer/moq -go 1.18 +go 1.19 require ( github.com/pmezard/go-difflib v1.0.0 diff --git a/internal/registry/generic_type_params.go b/internal/registry/generic_type_params.go new file mode 100644 index 0000000..5dac976 --- /dev/null +++ b/internal/registry/generic_type_params.go @@ -0,0 +1,80 @@ +package registry + +import ( + "fmt" + "go/types" + "regexp" + "strings" +) + +// GenericConstraint is used as a wrapper to instantiate a new type +// for use with the registry type params +type GenericConstraint struct { + Pkg string + Path string + Name string +} + +// NewGenericConstraint returns a pointer to a new GenericContstraint instance +func NewGenericConstraint(constraint string) *GenericConstraint { + return &GenericConstraint{ + Pkg: getPkgName(constraint), + Path: getPackagePath(constraint), + Name: getName(constraint), + } +} + +// Underlying satisfies types.Type Underlying method +func (g GenericConstraint) Underlying() types.Type { + return g +} + +// String statisfies types.Type String method +func (g GenericConstraint) String() string { + return g.Name +} + +var appearsImportedRegex = regexp.MustCompile(`.+\/.+\.[^\.\/]`) + +// ConstraintAppearsImported checks a constraints against a regular expression +// to loosely tell if it follows an imported type pattern +func ConstraintAppearsImported(constraint string) bool { + return appearsImportedRegex.Match([]byte(constraint)) +} + +func getPkgName(constraint string) string { + if i := strings.LastIndexByte(constraint, '/'); i != -1 { + constraint = strings.TrimLeft(constraint[i:], "/") + } + + if i := strings.LastIndexByte(constraint, '.'); i != -1 { + constraint = constraint[:i] + } + + return constraint +} + +func getPackagePath(constraint string) string { + if i := strings.LastIndexByte(constraint, '.'); i != -1 { + constraint = constraint[:i] + } + + return strings.TrimLeft(constraint, "*") +} + +func getName(constraint string) string { + var ptr bool + if constraint[0] == '*' { + ptr = true + } + + if i := strings.LastIndexByte(constraint, '/'); i != -1 { + constraint = strings.TrimPrefix(constraint[i:], "/") + } + + if ptr { + constraint = fmt.Sprintf("*%s", constraint) + } + + return constraint +} diff --git a/internal/registry/generic_type_params_test.go b/internal/registry/generic_type_params_test.go new file mode 100644 index 0000000..ab903e9 --- /dev/null +++ b/internal/registry/generic_type_params_test.go @@ -0,0 +1,78 @@ +package registry + +import "testing" + +func TestPackageUtilities(t *testing.T) { + testCases := []struct { + Name string + Input string + ExpectedPkgName string + ExpectedPkgPath string + ExpectedTypeName string + }{ + { + Name: "external package", + Input: "github.com/matryer/moq/fakepkg.MyType", + ExpectedPkgName: "fakepkg", + ExpectedPkgPath: "github.com/matryer/moq/fakepkg", + ExpectedTypeName: "fakepkg.MyType", + }, + { + Name: "internal package with ptr", + Input: "*os/fs.FileInfo", + ExpectedPkgName: "fs", + ExpectedPkgPath: "os/fs", + ExpectedTypeName: "*fs.FileInfo", + }, + { + Name: "internal package", + Input: "os/fs.FileInfo", + ExpectedPkgName: "fs", + ExpectedPkgPath: "os/fs", + ExpectedTypeName: "fs.FileInfo", + }, + { + Name: "external package with ptr", + Input: "*github.com/matryer/moq/fakepkg.MyType", + ExpectedPkgName: "fakepkg", + ExpectedPkgPath: "github.com/matryer/moq/fakepkg", + ExpectedTypeName: "*fakepkg.MyType", + }, + } + + for _, test := range testCases { + t.Run(test.Name, func(t *testing.T) { + if res := getPkgName(test.Input); res != test.ExpectedPkgName { + t.Fatalf("Got unexpected package name, Expected: '%s' Got: '%s'\n", test.ExpectedPkgName, res) + } + + if res := getPackagePath(test.Input); res != test.ExpectedPkgPath { + t.Fatalf("Got unexpected package path, Expected: '%s' Got: '%s'\n", test.ExpectedPkgPath, res) + } + + if res := getName(test.Input); res != test.ExpectedTypeName { + t.Fatalf("Got unexpected type name, Expected: '%s' Got: '%s'\n", test.ExpectedTypeName, res) + } + }) + } +} + +func TestConstraintContainsPkg(t *testing.T) { + testCases := []struct { + Input string + Expected bool + }{ + {Input: "github.com/matryer/moq/pkg.SomeType", Expected: true}, + {Input: "*os/fs.T", Expected: true}, + {Input: "os.T", Expected: false}, + {Input: "os/fs", Expected: false}, + {Input: "os/fs.", Expected: false}, + {Input: "os/fs./os", Expected: false}, + } + + for _, test := range testCases { + if res := ConstraintAppearsImported(test.Input); res != test.Expected { + t.Fatalf("Got unexpected result, Expected: '%v' Got: '%v' for string: '%s'\n", test.Expected, res, test.Input) + } + } +} diff --git a/pkg/moq/moq.go b/pkg/moq/moq.go index e8a2975..e13c6f9 100644 --- a/pkg/moq/moq.go +++ b/pkg/moq/moq.go @@ -88,6 +88,7 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error { if data.MocksSomeMethod() { m.registry.AddImport(types.NewPackage("sync", "sync")) } + if m.registry.SrcPkgName() != m.mockPkgName() { data.SrcPkgQualifier = m.registry.SrcPkgName() + "." if !m.cfg.SkipEnsure { @@ -126,9 +127,22 @@ func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamDa for i := 0; i < len(tpd); i++ { tp := tparams.At(i) typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint()) + + constraint := explicitConstraintType(typeParam) + if constraint != nil && registry.ConstraintAppearsImported(constraint.String()) { + // generate a new type + t := registry.NewGenericConstraint(constraint.String()) + + // since our constraint is from a package, we need to add it to the registry + m.registry.AddImport( + types.NewPackage(t.Path, t.Pkg), + ) + constraint = t + } + tpd[i] = template.TypeParamData{ ParamData: template.ParamData{Var: scope.AddVar(typeParam, "")}, - Constraint: explicitConstraintType(typeParam), + Constraint: constraint, } } @@ -137,6 +151,7 @@ func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamDa func explicitConstraintType(typeParam *types.Var) (t types.Type) { underlying := typeParam.Type().Underlying().(*types.Interface) + // check if any of the embedded types is either a basic type or a union, // because the generic type has to be an alias for one of those types then for j := 0; j < underlying.NumEmbeddeds(); j++ { diff --git a/pkg/moq/moq_test.go b/pkg/moq/moq_test.go index 7dceb8f..44545d4 100644 --- a/pkg/moq/moq_test.go +++ b/pkg/moq/moq_test.go @@ -401,6 +401,12 @@ func TestMockGolden(t *testing.T) { interfaces: []string{"ResetStore"}, goldenFile: filepath.Join("testpackages/withresets", "withresets_moq.golden.go"), }, + { + name: "GenericsImportedConstraint", + cfg: Config{SrcDir: "testpackages/generics_imported_constraint"}, + interfaces: []string{"GenericStore1"}, + goldenFile: filepath.Join("testpackages/generics_imported_constraint", "generics_imported_constraint_moq.golden.go"), + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { diff --git a/pkg/moq/testpackages/generics_imported_constraint/extern/extern_generic.go b/pkg/moq/testpackages/generics_imported_constraint/extern/extern_generic.go new file mode 100644 index 0000000..9c1b585 --- /dev/null +++ b/pkg/moq/testpackages/generics_imported_constraint/extern/extern_generic.go @@ -0,0 +1,35 @@ +package extern + +import ( + "io/fs" + "os" + + "github.com/matryer/moq/pkg/moq/testpackages/generics_imported_constraint/extern2" +) + +// validate gh package works +type Foo1 interface { + *extern2.SomeType | *fs.PathError | os.FileMode +} + +// validate with ptr works +type Foo2 interface { + *fs.PathError | os.FileMode +} + +// validate without ptr works +type Foo3 interface { + fs.PathError | os.FileMode +} + +type Local struct{} + +// validate works with local extern, how deep can we go? +type Foo4 interface { + Local | os.File +} + +// validate works with extern extern tilde pointer +type Foo5 interface { + ~*extern2.SomeType | os.File +} diff --git a/pkg/moq/testpackages/generics_imported_constraint/extern2/extern2_generic.go b/pkg/moq/testpackages/generics_imported_constraint/extern2/extern2_generic.go new file mode 100644 index 0000000..c9764e7 --- /dev/null +++ b/pkg/moq/testpackages/generics_imported_constraint/extern2/extern2_generic.go @@ -0,0 +1,3 @@ +package extern2 + +type SomeType string diff --git a/pkg/moq/testpackages/generics_imported_constraint/generics.go b/pkg/moq/testpackages/generics_imported_constraint/generics.go new file mode 100644 index 0000000..501768b --- /dev/null +++ b/pkg/moq/testpackages/generics_imported_constraint/generics.go @@ -0,0 +1,17 @@ +package generics_imported_constraint + +import ( + "context" + + "github.com/matryer/moq/pkg/moq/testpackages/generics_imported_constraint/extern" +) + +//go:generate moq -out generics_moq_test.go -pkg generics_moq_test . GenericStore1 + +type GenericStore1[T extern.Foo1, J extern.Foo2, L extern.Foo3, F extern.Foo4, E extern.Foo5] interface { + Tet(ctx context.Context, handler T) error + Jet(ctx context.Context, handler J) error + Let(ctx context.Context, handler L) error + Fet(ctx context.Context, handler F) error + Eet(ctx context.Context, handler E) error +} diff --git a/pkg/moq/testpackages/generics_imported_constraint/generics_imported_constraint_moq.golden.go b/pkg/moq/testpackages/generics_imported_constraint/generics_imported_constraint_moq.golden.go new file mode 100644 index 0000000..0bb2046 --- /dev/null +++ b/pkg/moq/testpackages/generics_imported_constraint/generics_imported_constraint_moq.golden.go @@ -0,0 +1,284 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package generics_imported_constraint + +import ( + "context" + "github.com/matryer/moq/pkg/moq/testpackages/generics_imported_constraint/extern" + "github.com/matryer/moq/pkg/moq/testpackages/generics_imported_constraint/extern2" + "io/fs" + "sync" +) + +// Ensure, that GenericStore1Mock does implement GenericStore1. +// If this is not the case, regenerate this file with moq. +var _ GenericStore1[*extern2.SomeType, *fs.PathError, fs.PathError, extern.Local, *extern2.SomeType] = &GenericStore1Mock[*extern2.SomeType, *fs.PathError, fs.PathError, extern.Local, *extern2.SomeType]{} + +// GenericStore1Mock is a mock implementation of GenericStore1. +// +// func TestSomethingThatUsesGenericStore1(t *testing.T) { +// +// // make and configure a mocked GenericStore1 +// mockedGenericStore1 := &GenericStore1Mock{ +// EetFunc: func(ctx context.Context, handler E) error { +// panic("mock out the Eet method") +// }, +// FetFunc: func(ctx context.Context, handler F) error { +// panic("mock out the Fet method") +// }, +// JetFunc: func(ctx context.Context, handler J) error { +// panic("mock out the Jet method") +// }, +// LetFunc: func(ctx context.Context, handler L) error { +// panic("mock out the Let method") +// }, +// TetFunc: func(ctx context.Context, handler T) error { +// panic("mock out the Tet method") +// }, +// } +// +// // use mockedGenericStore1 in code that requires GenericStore1 +// // and then make assertions. +// +// } +type GenericStore1Mock[T extern.Foo1, J extern.Foo2, L extern.Foo3, F extern.Foo4, E extern.Foo5] struct { + // EetFunc mocks the Eet method. + EetFunc func(ctx context.Context, handler E) error + + // FetFunc mocks the Fet method. + FetFunc func(ctx context.Context, handler F) error + + // JetFunc mocks the Jet method. + JetFunc func(ctx context.Context, handler J) error + + // LetFunc mocks the Let method. + LetFunc func(ctx context.Context, handler L) error + + // TetFunc mocks the Tet method. + TetFunc func(ctx context.Context, handler T) error + + // calls tracks calls to the methods. + calls struct { + // Eet holds details about calls to the Eet method. + Eet []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Handler is the handler argument value. + Handler E + } + // Fet holds details about calls to the Fet method. + Fet []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Handler is the handler argument value. + Handler F + } + // Jet holds details about calls to the Jet method. + Jet []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Handler is the handler argument value. + Handler J + } + // Let holds details about calls to the Let method. + Let []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Handler is the handler argument value. + Handler L + } + // Tet holds details about calls to the Tet method. + Tet []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Handler is the handler argument value. + Handler T + } + } + lockEet sync.RWMutex + lockFet sync.RWMutex + lockJet sync.RWMutex + lockLet sync.RWMutex + lockTet sync.RWMutex +} + +// Eet calls EetFunc. +func (mock *GenericStore1Mock[T, J, L, F, E]) Eet(ctx context.Context, handler E) error { + if mock.EetFunc == nil { + panic("GenericStore1Mock.EetFunc: method is nil but GenericStore1.Eet was just called") + } + callInfo := struct { + Ctx context.Context + Handler E + }{ + Ctx: ctx, + Handler: handler, + } + mock.lockEet.Lock() + mock.calls.Eet = append(mock.calls.Eet, callInfo) + mock.lockEet.Unlock() + return mock.EetFunc(ctx, handler) +} + +// EetCalls gets all the calls that were made to Eet. +// Check the length with: +// +// len(mockedGenericStore1.EetCalls()) +func (mock *GenericStore1Mock[T, J, L, F, E]) EetCalls() []struct { + Ctx context.Context + Handler E +} { + var calls []struct { + Ctx context.Context + Handler E + } + mock.lockEet.RLock() + calls = mock.calls.Eet + mock.lockEet.RUnlock() + return calls +} + +// Fet calls FetFunc. +func (mock *GenericStore1Mock[T, J, L, F, E]) Fet(ctx context.Context, handler F) error { + if mock.FetFunc == nil { + panic("GenericStore1Mock.FetFunc: method is nil but GenericStore1.Fet was just called") + } + callInfo := struct { + Ctx context.Context + Handler F + }{ + Ctx: ctx, + Handler: handler, + } + mock.lockFet.Lock() + mock.calls.Fet = append(mock.calls.Fet, callInfo) + mock.lockFet.Unlock() + return mock.FetFunc(ctx, handler) +} + +// FetCalls gets all the calls that were made to Fet. +// Check the length with: +// +// len(mockedGenericStore1.FetCalls()) +func (mock *GenericStore1Mock[T, J, L, F, E]) FetCalls() []struct { + Ctx context.Context + Handler F +} { + var calls []struct { + Ctx context.Context + Handler F + } + mock.lockFet.RLock() + calls = mock.calls.Fet + mock.lockFet.RUnlock() + return calls +} + +// Jet calls JetFunc. +func (mock *GenericStore1Mock[T, J, L, F, E]) Jet(ctx context.Context, handler J) error { + if mock.JetFunc == nil { + panic("GenericStore1Mock.JetFunc: method is nil but GenericStore1.Jet was just called") + } + callInfo := struct { + Ctx context.Context + Handler J + }{ + Ctx: ctx, + Handler: handler, + } + mock.lockJet.Lock() + mock.calls.Jet = append(mock.calls.Jet, callInfo) + mock.lockJet.Unlock() + return mock.JetFunc(ctx, handler) +} + +// JetCalls gets all the calls that were made to Jet. +// Check the length with: +// +// len(mockedGenericStore1.JetCalls()) +func (mock *GenericStore1Mock[T, J, L, F, E]) JetCalls() []struct { + Ctx context.Context + Handler J +} { + var calls []struct { + Ctx context.Context + Handler J + } + mock.lockJet.RLock() + calls = mock.calls.Jet + mock.lockJet.RUnlock() + return calls +} + +// Let calls LetFunc. +func (mock *GenericStore1Mock[T, J, L, F, E]) Let(ctx context.Context, handler L) error { + if mock.LetFunc == nil { + panic("GenericStore1Mock.LetFunc: method is nil but GenericStore1.Let was just called") + } + callInfo := struct { + Ctx context.Context + Handler L + }{ + Ctx: ctx, + Handler: handler, + } + mock.lockLet.Lock() + mock.calls.Let = append(mock.calls.Let, callInfo) + mock.lockLet.Unlock() + return mock.LetFunc(ctx, handler) +} + +// LetCalls gets all the calls that were made to Let. +// Check the length with: +// +// len(mockedGenericStore1.LetCalls()) +func (mock *GenericStore1Mock[T, J, L, F, E]) LetCalls() []struct { + Ctx context.Context + Handler L +} { + var calls []struct { + Ctx context.Context + Handler L + } + mock.lockLet.RLock() + calls = mock.calls.Let + mock.lockLet.RUnlock() + return calls +} + +// Tet calls TetFunc. +func (mock *GenericStore1Mock[T, J, L, F, E]) Tet(ctx context.Context, handler T) error { + if mock.TetFunc == nil { + panic("GenericStore1Mock.TetFunc: method is nil but GenericStore1.Tet was just called") + } + callInfo := struct { + Ctx context.Context + Handler T + }{ + Ctx: ctx, + Handler: handler, + } + mock.lockTet.Lock() + mock.calls.Tet = append(mock.calls.Tet, callInfo) + mock.lockTet.Unlock() + return mock.TetFunc(ctx, handler) +} + +// TetCalls gets all the calls that were made to Tet. +// Check the length with: +// +// len(mockedGenericStore1.TetCalls()) +func (mock *GenericStore1Mock[T, J, L, F, E]) TetCalls() []struct { + Ctx context.Context + Handler T +} { + var calls []struct { + Ctx context.Context + Handler T + } + mock.lockTet.RLock() + calls = mock.calls.Tet + mock.lockTet.RUnlock() + return calls +} diff --git a/pkg/moq/testpackages/go.mod b/pkg/moq/testpackages/go.mod index b65ffdf..8ea56f8 100644 --- a/pkg/moq/testpackages/go.mod +++ b/pkg/moq/testpackages/go.mod @@ -1,5 +1,5 @@ module github.com/matryer/moq/pkg/moq/testpackages -go 1.18 +go 1.19 require github.com/sudo-suhas/moq-test-pkgs/somerepo v0.0.0-20200816045313-d2f573eea6c7