From 18d3f5618645fbebec703867364c3f2ac1102f1d Mon Sep 17 00:00:00 2001 From: crazybolillo Date: Fri, 8 Mar 2024 12:44:38 -0300 Subject: [PATCH] modfile: fix crash on AddGoStmt in empty File AddGoStmt uses File.Syntax without checking whether it is nil or not. This causes crashes when using it on empty files that have not had their Syntax member initialized to a valid pointer. This change fixes it by ensuring File.Syntax is a valid pointer before proceeding. Fixes golang/go#62457. Change-Id: Iab02039f79e73d939ca5d3e48b29faa5e0a9a5ec Reviewed-on: https://go-review.googlesource.com/c/mod/+/570115 Reviewed-by: Michael Knyszek Auto-Submit: Bryan Mills Reviewed-by: Bryan Mills LUCI-TryBot-Result: Go LUCI --- modfile/rule.go | 2 ++ modfile/rule_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/modfile/rule.go b/modfile/rule.go index 26acaa5..0e7b7e2 100644 --- a/modfile/rule.go +++ b/modfile/rule.go @@ -975,6 +975,8 @@ func (f *File) AddGoStmt(version string) error { var hint Expr if f.Module != nil && f.Module.Syntax != nil { hint = f.Module.Syntax + } else if f.Syntax == nil { + f.Syntax = new(FileSyntax) } f.Go = &Go{ Version: version, diff --git a/modfile/rule_test.go b/modfile/rule_test.go index 57c8be6..96e0bfe 100644 --- a/modfile/rule_test.go +++ b/modfile/rule_test.go @@ -1549,6 +1549,20 @@ var fixVersionTests = []struct { }, } +var modifyEmptyFilesTests = []struct { + desc string + operations func(f *File) + want string +}{ + { + desc: `addGoStmt`, + operations: func(f *File) { + f.AddGoStmt("1.20") + }, + want: `go 1.20`, + }, +} + func fixV(path, version string) (string, error) { if path != "example.com/m" { return "", fmt.Errorf("module path must be example.com/m") @@ -1846,3 +1860,29 @@ func TestFixVersion(t *testing.T) { }) } } + +func TestAddOnEmptyFile(t *testing.T) { + for _, tt := range modifyEmptyFilesTests { + t.Run(tt.desc, func(t *testing.T) { + f := &File{} + tt.operations(f) + + expect, err := Parse("out", []byte(tt.want), nil) + if err != nil { + t.Fatal(err) + } + golden, err := expect.Format() + if err != nil { + t.Fatal(err) + } + got, err := f.Format() + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, golden) { + t.Fatalf("got:\n%s\nwant:\n%s", got, golden) + } + }) + } +}