diff --git a/cache/disk/casblob/BUILD.bazel b/cache/disk/casblob/BUILD.bazel index eb6f88f3e..64df55194 100644 --- a/cache/disk/casblob/BUILD.bazel +++ b/cache/disk/casblob/BUILD.bazel @@ -11,4 +11,9 @@ go_library( go_test( name = "go_default_test", srcs = ["casblob_test.go"], + deps = [ + ":go_default_library", + "//cache/disk/zstdimpl:go_default_library", + "//utils:go_default_library", + ], ) diff --git a/cache/disk/casblob/casblob.go b/cache/disk/casblob/casblob.go index 12a533080..4d0d465a4 100644 --- a/cache/disk/casblob/casblob.go +++ b/cache/disk/casblob/casblob.go @@ -402,7 +402,7 @@ func GetLegacyZstdReadCloser(zstd zstdimpl.ZstdImpl, f *os.File) (io.ReadCloser, pr, pw := io.Pipe() - enc, err := zstd.GetEncoder(f) + enc, err := zstd.GetEncoder(pw) if err != nil { _ = f.Close() return nil, err @@ -425,7 +425,14 @@ func GetLegacyZstdReadCloser(zstd zstdimpl.ZstdImpl, f *os.File) (io.ReadCloser, log.Println("Error while closing encoder:", err) _ = pw.CloseWithError(err) } - _ = f.Close() + + err = f.Close() + if err != nil { + log.Println("Error while closing file:", err) + _ = pw.CloseWithError(err) + } + + _ = pw.Close() }() return pr, nil diff --git a/cache/disk/casblob/casblob_test.go b/cache/disk/casblob/casblob_test.go index c6560059e..599d7282d 100644 --- a/cache/disk/casblob/casblob_test.go +++ b/cache/disk/casblob/casblob_test.go @@ -1,8 +1,18 @@ package casblob_test import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "os" "testing" "unsafe" + + "github.com/buchgr/bazel-remote/v2/cache/disk/casblob" + "github.com/buchgr/bazel-remote/v2/cache/disk/zstdimpl" + testutils "github.com/buchgr/bazel-remote/v2/utils" ) func TestLenSize(t *testing.T) { @@ -17,3 +27,55 @@ func TestLenSize(t *testing.T) { t.Errorf("This should silence linters that think slice is never used") } } + +func TestZstdFromLegacy(t *testing.T) { + size := 1024 + zstd, err := zstdimpl.Get("go") + if err != nil { + t.Fatal(err) + } + + data, hash := testutils.RandomDataAndHash(int64(size)) + dir := testutils.TempDir(t) + filename := fmt.Sprintf("%s/%s", dir, hash) + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0664) + if err != nil { + t.Fatal(err) + } + n, err := file.Write(data) + if err != nil { + t.Fatal(err) + } + if n != size { + t.Fatalf("Unexpected short write %d, expected %d", n, size) + } + file.Close() + + file, err = os.Open(filename) + if err != nil { + t.Fatal(err) + } + zrc, err := casblob.GetLegacyZstdReadCloser(zstd, file) + if err != nil { + t.Fatal(err) + } + rc, err := zstd.GetDecoder(zrc) + if err != nil { + t.Fatal(err) + } + buf := bytes.NewBuffer(nil) + _, err = io.Copy(buf, rc) + if err != nil { + t.Fatal(err) + } + + if buf.Len() != size { + t.Fatalf("Unexpected buf size %d, expected %d", buf.Len(), size) + } + + h := sha256.Sum256(data) + hs := hex.EncodeToString(h[:]) + if hs != hash { + t.Fatalf("Unexpected content sha %s, expected %s", hs, hash) + } +}