diff --git a/pkg/template/mockery.templ b/pkg/mockery.templ similarity index 100% rename from pkg/template/mockery.templ rename to pkg/mockery.templ diff --git a/pkg/template/moq.templ b/pkg/moq.templ similarity index 100% rename from pkg/template/moq.templ rename to pkg/moq.templ diff --git a/pkg/template/template.go b/pkg/template/template.go index 7e3e5a12..db4ada63 100644 --- a/pkg/template/template.go +++ b/pkg/template/template.go @@ -12,7 +12,6 @@ import ( "github.com/huandu/xstrings" "github.com/vektra/mockery/v3/pkg/registry" - "github.com/vektra/mockery/v3/pkg/stackerr" ) // Template is the Moq template. It is capable of generating the Moq @@ -21,26 +20,9 @@ type Template struct { tmpl *template.Template } -var ( - //go:embed moq.templ - templateMoq string - //go:embed mockery.templ - templateMockery string -) - -var styleTemplates = map[string]string{ - "moq": templateMoq, - "mockery": templateMockery, -} - // New returns a new instance of Template. -func New(style string) (Template, error) { - templateString, styleExists := styleTemplates[style] - if !styleExists { - return Template{}, stackerr.NewStackErrf(nil, "style %s does not exist", style) - } - - tmpl, err := template.New(style).Funcs(templateFuncs).Parse(templateString) +func New(templateString string, name string) (Template, error) { + tmpl, err := template.New(name).Funcs(templateFuncs).Parse(templateString) if err != nil { return Template{}, err } diff --git a/pkg/template_generator.go b/pkg/template_generator.go index b3ab1ba8..84565554 100644 --- a/pkg/template_generator.go +++ b/pkg/template_generator.go @@ -29,6 +29,18 @@ const ( FORMAT_NOOP Formatter = "noop" ) +var ( + //go:embed moq.templ + templateMoq string + //go:embed mockery.templ + templateMockery string +) + +var styleTemplates = map[string]string{ + "moq": templateMoq, + "mockery": templateMockery, +} + // findPkgPath returns the fully-qualified go import path of a given dir. The // dir must be relative to a go.mod file. In the case it isn't, an error is returned. func findPkgPath(dirPath *pathlib.Path) (string, error) { @@ -297,7 +309,24 @@ func (g *TemplateGenerator) Generate( } data.Imports = g.registry.Imports() - templ, err := template.New(g.templateName) + var templateString string + if strings.HasPrefix(g.templateName, "file://") { + templatePath := pathlib.NewPath(strings.SplitAfterN(g.templateName, "file://", 2)[1]) + templateBytes, err := templatePath.ReadFile() + if err != nil { + log.Err(err).Str("template-path", g.templateName).Msg("Failed to read template") + return nil, err + } + templateString = string(templateBytes) + } else { + var styleExists bool + templateString, styleExists = styleTemplates[g.templateName] + if !styleExists { + return nil, stackerr.NewStackErrf(nil, "style %s does not exist", g.templateName) + } + } + + templ, err := template.New(templateString, g.templateName) if err != nil { return []byte{}, fmt.Errorf("creating new template: %w", err) }