diff --git a/modfile/rule.go b/modfile/rule.go index 26acaa5..0e7b7e2 100644 --- a/modfile/rule.go +++ b/modfile/rule.go @@ -975,6 +975,8 @@ func (f *File) AddGoStmt(version string) error { var hint Expr if f.Module != nil && f.Module.Syntax != nil { hint = f.Module.Syntax + } else if f.Syntax == nil { + f.Syntax = new(FileSyntax) } f.Go = &Go{ Version: version, diff --git a/modfile/rule_test.go b/modfile/rule_test.go index 57c8be6..96e0bfe 100644 --- a/modfile/rule_test.go +++ b/modfile/rule_test.go @@ -1549,6 +1549,20 @@ var fixVersionTests = []struct { }, } +var modifyEmptyFilesTests = []struct { + desc string + operations func(f *File) + want string +}{ + { + desc: `addGoStmt`, + operations: func(f *File) { + f.AddGoStmt("1.20") + }, + want: `go 1.20`, + }, +} + func fixV(path, version string) (string, error) { if path != "example.com/m" { return "", fmt.Errorf("module path must be example.com/m") @@ -1846,3 +1860,29 @@ func TestFixVersion(t *testing.T) { }) } } + +func TestAddOnEmptyFile(t *testing.T) { + for _, tt := range modifyEmptyFilesTests { + t.Run(tt.desc, func(t *testing.T) { + f := &File{} + tt.operations(f) + + expect, err := Parse("out", []byte(tt.want), nil) + if err != nil { + t.Fatal(err) + } + golden, err := expect.Format() + if err != nil { + t.Fatal(err) + } + got, err := f.Format() + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, golden) { + t.Fatalf("got:\n%s\nwant:\n%s", got, golden) + } + }) + } +}