Skip to content

Commit

Permalink
Merge pull request #6 from samsarahq/bf/persistent-invalidation-cache
Browse files Browse the repository at this point in the history
Add Persistent Task Invalidation
  • Loading branch information
berfarah authored Nov 6, 2018
2 parents 9efecd0 + 58dd6f1 commit 7fe220e
Show file tree
Hide file tree
Showing 304 changed files with 180,217 additions and 26 deletions.
104 changes: 104 additions & 0 deletions cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package cache

import (
"context"
"log"
"os"
"path"

"github.com/samsarahq/taskrunner"
"github.com/samsarahq/taskrunner/shell"
)

// CachePath is the location of the cache directory relative to the user's home directory.
const CachePath = ".cache/taskrunner"

// CacheDir is the directory that holds cache files: CachePath joined onto the value of the HOME
// environment variable, evaluated once when the package is initialized.
var CacheDir = path.Join(os.Getenv("HOME"), CachePath)

// Cache decides whether a wrapped task may skip its first run, based on the files that changed
// since the snapshot stored in the cache file.
type Cache struct {
// ranOnce records the tasks that have already gone through isFirstRun.
ranOnce map[*taskrunner.Task]bool
// snapshotter reads, diffs and writes snapshots. It is assigned by Start.
snapshotter *snapshotter
// cacheFile is the path of the snapshot file read by Start and written by Finish.
cacheFile string
// allDirty is set by Start when the previous snapshot cannot be read; isValid then reports
// every task as invalid.
allDirty bool
// dirtyFiles holds the files returned by the snapshot diff performed in Start.
dirtyFiles []string
}

// New returns a Cache with an empty run history whose snapshot file is "example.json" inside
// CacheDir.
func New() *Cache {
	c := new(Cache)
	c.cacheFile = path.Join(CacheDir, "example.json")
	c.ranOnce = map[*taskrunner.Task]bool{}
	return c
}

// Start loads the snapshot saved in the cache file and records which files have changed since
// it was taken. Every command issued by the snapshotter runs with opt appended to its options.
// It returns an error only when diffing against the previous snapshot fails.
func (c *Cache) Start(ctx context.Context, opt shell.RunOption) error {
	runWithOpt := func(ctx context.Context, command string, opts ...shell.RunOption) error {
		return shell.Run(ctx, command, append(opts, opt)...)
	}
	c.snapshotter = newSnapshotter(runWithOpt)

	previous, readErr := c.snapshotter.Read(c.cacheFile)
	if readErr != nil {
		// Without a usable cache file, assume that everything is dirty and needs to be re-run.
		c.allDirty = true
		previous = &snapshot{}
	}

	// Empty the cache file once it has been read so that a stale snapshot cannot survive an
	// unexpected termination of taskrunner. The result of the truncation is deliberately ignored.
	_ = os.Truncate(c.cacheFile, 0)

	changed, err := c.snapshotter.Diff(ctx, previous)
	if err != nil {
		return err
	}

	c.dirtyFiles = changed
	return nil
}

// Finish makes sure CacheDir exists and then writes a snapshot of the current state to the
// cache file.
func (c *Cache) Finish(ctx context.Context) error {
	mkdirErr := os.MkdirAll(CacheDir, os.ModePerm)
	if mkdirErr != nil {
		return mkdirErr
	}

	return c.snapshotter.Write(ctx, c.cacheFile)
}

// isFirstRun reports whether this is the first time it has been called for task, and marks the
// task as having run.
func (c *Cache) isFirstRun(task *taskrunner.Task) bool {
	if c.ranOnce[task] {
		return false
	}
	c.ranOnce[task] = true
	return true
}

// isValid reports whether the cached result for task can be trusted: the whole cache must not
// be dirty and none of the dirty files may be a source of the task.
func (c *Cache) isValid(task *taskrunner.Task) bool {
	if c.allDirty {
		return false
	}

	valid := true
	// Stop at the first dirty file that is a source of the task.
	for i := 0; valid && i < len(c.dirtyFiles); i++ {
		valid = !taskrunner.IsTaskSource(task, c.dirtyFiles[i])
	}
	return valid
}

// maybeRun returns a run function for task that skips the real work on the task's first run
// when the cache is still valid for it, and otherwise delegates to task.Run.
func (c *Cache) maybeRun(task *taskrunner.Task) func(context.Context, shell.ShellRun) error {
	return func(ctx context.Context, shellRun shell.ShellRun) error {
		// isFirstRun is evaluated first on every call because it also records the run.
		skip := c.isFirstRun(task) && c.isValid(task)
		if !skip {
			return task.Run(ctx, shellRun)
		}

		// Report that the task wasn't run.
		return shellRun(ctx, `echo "no changes (cache)"`)
	}
}

// WrapWithPersistentCache returns a copy of task whose Run function is skipped on its first run
// when none of the files the task depends on changed since the previous snapshot.
// The process is terminated via log.Fatalf if the task declares no sources.
func (c *Cache) WrapWithPersistentCache(task *taskrunner.Task) *taskrunner.Task {
	if len(task.Sources) == 0 {
		log.Fatalf("Task %s cannot be wrapped with a persistent cache as it has no sources", task.Name)
	}

	wrapped := new(taskrunner.Task)
	*wrapped = *task
	// The wrapper keeps pointing at the original task, so its original Run is what gets invoked.
	wrapped.Run = c.maybeRun(task)
	return wrapped
}
57 changes: 57 additions & 0 deletions cache/git.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package cache

import (
"bytes"
"context"
"fmt"
"strings"

"github.com/samsarahq/taskrunner/shell"
)

// gitClient issues git commands through the shell runner it is given.
type gitClient struct {
// shellRun executes a shell command; the methods of gitClient pass it git command lines and
// capture stdout through shell.Stdout.
shellRun shell.ShellRun
}

// stripStdout returns the contents of buf with leading and trailing newlines removed.
// Other whitespace is left untouched.
func stripStdout(buf bytes.Buffer) string {
	output := buf.String()
	return strings.Trim(output, "\n")
}

// splitStdout returns the lines held in buf, ignoring leading and trailing newlines.
// It returns nil when buf contains nothing but newlines: strings.Split on an empty string would
// otherwise yield a single empty entry that callers would mistake for a file name.
func splitStdout(buf bytes.Buffer) []string {
	trimmed := strings.Trim(buf.String(), "\n")
	if trimmed == "" {
		return nil
	}
	return strings.Split(trimmed, "\n")
}

// currentCommit returns the output of `git rev-parse HEAD` with surrounding newlines removed.
func (g gitClient) currentCommit(ctx context.Context) (commitHash string, err error) {
	var out bytes.Buffer

	err = g.shellRun(ctx, "git rev-parse HEAD", shell.Stdout(&out))
	if err != nil {
		return "", err
	}

	commitHash = stripStdout(out)
	return commitHash, nil
}

// diff returns the file names printed by `git diff --name-only <commitHash>`.
// commitHash is interpolated into the command line as-is; when it is empty the command runs
// without a commit argument.
func (g gitClient) diff(ctx context.Context, commitHash string) (modifiedFiles []string, err error) {
	var buffer bytes.Buffer

	command := fmt.Sprintf("git diff --name-only %s", commitHash)
	if err := g.shellRun(ctx, command, shell.Stdout(&buffer)); err != nil {
		return nil, err
	}

	return splitStdout(buffer), nil
}

// uncomittedFiles parses the output of `git status --porcelain`.
// Entries whose status is "??" are returned in newFiles; every other entry is returned in
// modifiedFiles. Lines too short to carry a path are skipped.
func (g gitClient) uncomittedFiles(ctx context.Context) (newFiles []string, modifiedFiles []string, err error) {
	var buffer bytes.Buffer
	if err := g.shellRun(ctx, "git status --porcelain", shell.Stdout(&buffer)); err != nil {
		return nil, nil, err
	}

	// The path starts after the first three bytes of each line (the status field).
	const pathOffset = 3

	for _, statusLine := range splitStdout(buffer) {
		// An empty status output yields an empty line; slicing it at pathOffset would panic.
		if len(statusLine) <= pathOffset {
			continue
		}

		file := statusLine[pathOffset:]
		if strings.HasPrefix(statusLine, "??") {
			newFiles = append(newFiles, file)
		} else {
			modifiedFiles = append(modifiedFiles, file)
		}
	}

	return newFiles, modifiedFiles, nil
}
35 changes: 35 additions & 0 deletions cache/md5sum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package cache

import (
"crypto/md5"
"fmt"
"io"
"os"
"path/filepath"
)

// hashSum takes in a file or directory and returns a hex-encoded MD5 digest.
// The digest covers the path and modification time of every entry visited while walking the
// given path (for a plain file, just that file); file contents are not read.
// Hashing is best effort: entries that cannot be visited or stat'ed are skipped, so a path that
// does not exist produces the digest of empty input rather than an error.
func hashSum(path string) (string, error) {
	hash := md5.New()

	err := filepath.Walk(path, func(entry string, _ os.FileInfo, walkErr error) error {
		if walkErr != nil {
			// Skip entries that cannot be visited instead of aborting the walk.
			return nil
		}
		s, statErr := os.Stat(entry)
		if statErr != nil {
			return nil
		}
		hashableContent := fmt.Sprintf("%s:%v", entry, s.ModTime())
		io.WriteString(hash, hashableContent)
		return nil
	})
	if err != nil {
		// Propagate the failure instead of returning an empty hash with a nil error.
		return "", err
	}

	return fmt.Sprintf("%x", hash.Sum(nil)), nil
}
17 changes: 17 additions & 0 deletions cache/md5sum_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package cache

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestHashSum checks that hashSum succeeds and yields a 32-character digest for both a single
// file and a directory fixture.
func TestHashSum(t *testing.T) {
	targets := []string{
		"testdata/md5test/md5test.txt",
		"testdata/md5test",
	}

	for _, target := range targets {
		hash, err := hashSum(target)
		assert.NoError(t, err)
		assert.Len(t, hash, 32)
	}
}
182 changes: 182 additions & 0 deletions cache/snapshot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package cache

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"os"

"github.com/samsarahq/taskrunner/shell"
)

// snapshotter tracks the state of a git repository via snapshots.
// Its behavior is supplied through the function fields below; newSnapshotter wires them to a
// gitClient and hashSum.
type snapshotter struct {
// CommitFunc gets the current commit SHA.
CommitFunc func(context.Context) (sha string, err error)
// DiffFunc gets modified files against a previous commit SHA.
DiffFunc func(context.Context, string) (diffFiles []string, err error)
// UncommittedFilesFunc gets a list of all uncommitted files (new and modified), returned as
// two separate slices.
UncommittedFilesFunc func(context.Context) (newFiles []string, modifiedFiles []string, err error)
// HashFunc gets an MD5 hash of the file or directory at the given path.
HashFunc func(string) (hash string, err error)
}

// snapshot is the state of a git repository in time. It records the current commit as well as MD5
// sums of uncommitted files or directories (hashed by name + timestamp).
// When comparing against a previous snapshot, we can therefore run a git diff against the old sha
// and compare uncommitted files manually.
type snapshot struct {
// Commit SHA at HEAD.
CommitSha string `json:"commitSha"`
// Uncommitted files at the time of snapshotting.
UncommittedFiles []uncommittedFile `json:"uncommittedFiles"`
// A map representation of UncommittedFiles for quick lookup: map[filename]md5hash.
// It is unexported, so it is not part of the JSON form; loadMap rebuilds it from
// UncommittedFiles.
uncommittedFilesMap map[string]string
}

// newSnapshotter builds a snapshotter whose git operations go through shellRun and whose
// hashing is done by hashSum.
func newSnapshotter(shellRun shell.ShellRun) *snapshotter {
	git := gitClient{shellRun: shellRun}

	s := new(snapshotter)
	s.CommitFunc = git.currentCommit
	s.DiffFunc = git.diff
	s.UncommittedFilesFunc = git.uncomittedFiles
	s.HashFunc = hashSum
	return s
}

// Diff compares the previous snapshot with the current state and returns the files that differ.
//
// A file is reported when:
//   - DiffFunc lists it for the previous commit and it was not an uncommitted file in the
//     previous snapshot;
//   - it was uncommitted in the previous snapshot and its hash is now different or unavailable;
//   - it is a new uncommitted file now and its hash is missing from, or different in, the
//     previous snapshot.
//
// A file may be listed more than once. An error is returned only when taking the current
// snapshot or running DiffFunc fails; hashing failures are reported on stderr and the affected
// file is treated as modified.
func (c *snapshotter) Diff(ctx context.Context, previous *snapshot) ([]string, error) {
	current, err := c.snapshot(ctx, false)
	if err != nil {
		return nil, err
	}

	gitDiffFiles, err := c.DiffFunc(ctx, previous.CommitSha)
	if err != nil {
		return nil, err
	}

	// Track files that have changed since the last snapshot.
	var modifiedFiles []string

	// Only mark committed files as modified (uncommitted files are handled below).
	for _, f := range gitDiffFiles {
		// A blank entry is not a file name; a diff that printed nothing can produce one.
		if f == "" {
			continue
		}
		if previous.hashFor(f) == "" {
			modifiedFiles = append(modifiedFiles, f)
		}
	}

	// Because we diff against the last commit, any file that was not committed the last time
	// a snapshot was recorded needs to be compared by hash instead.
	for _, file := range previous.UncommittedFiles {
		// If the file isn't currently uncommitted, rehash.
		md5 := current.hashFor(file.Path)
		if md5 == "" {
			// Keep the hashing error local: it is only a warning and must not become
			// the return value of Diff.
			var hashErr error
			md5, hashErr = c.HashFunc(file.Path)
			if hashErr != nil {
				fmt.Fprintf(os.Stderr, "Warning: unable to hash file %s (error: %v)\n", file.Path, hashErr)
			}
		}
		if md5 == "" || md5 != file.MD5 {
			modifiedFiles = append(modifiedFiles, file.Path)
		}
	}

	for _, file := range current.UncommittedFiles {
		// Any uncommitted file that either has a different MD5 hash or wasn't recorded in the
		// previous snapshot counts as different.
		if sha := previous.hashFor(file.Path); sha == "" || sha != file.MD5 {
			modifiedFiles = append(modifiedFiles, file.Path)
		}
	}

	return modifiedFiles, nil
}

// snapshot captures the current state: the commit reported by CommitFunc plus a hash for each
// uncommitted file.
// withChanged selects whether modified files are hashed in addition to new files.
func (c *snapshotter) snapshot(ctx context.Context, withChanged bool) (*snapshot, error) {
	sha, err := c.CommitFunc(ctx)
	if err != nil {
		return nil, err
	}

	newPaths, modifiedPaths, err := c.UncommittedFilesFunc(ctx)
	if err != nil {
		return nil, err
	}

	// Modified files are only wanted when recording a snapshot, not when diffing against an
	// older one (git diff already covers them in that case).
	toHash := newPaths
	if withChanged {
		toHash = append(toHash, modifiedPaths...)
	}

	result := &snapshot{CommitSha: sha}
	for i := range toHash {
		sum, hashErr := c.HashFunc(toHash[i])
		if hashErr != nil {
			return nil, hashErr
		}
		entry := uncommittedFile{Path: toHash[i], MD5: sum}
		result.UncommittedFiles = append(result.UncommittedFiles, entry)
	}

	result.loadMap()
	return result, nil
}

// Write takes a snapshot of the current state, modified files included, and stores it as JSON
// at cacheFilePath with mode 0644.
func (c *snapshotter) Write(ctx context.Context, cacheFilePath string) error {
	state, snapErr := c.snapshot(ctx, true)
	if snapErr != nil {
		return snapErr
	}

	encoded, encodeErr := json.Marshal(state)
	if encodeErr != nil {
		return encodeErr
	}

	return ioutil.WriteFile(cacheFilePath, encoded, 0644)
}

// Read loads a previously written snapshot from the JSON file at path and rebuilds its lookup
// map. It fails when the file cannot be read or does not contain valid JSON.
func (c *snapshotter) Read(path string) (*snapshot, error) {
	raw, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}

	loaded := new(snapshot)
	if err = json.Unmarshal(raw, loaded); err != nil {
		return nil, err
	}

	loaded.loadMap()
	return loaded, nil
}

// hashFor returns the hash recorded for path, or the empty string when the snapshot holds no
// entry for it.
func (s *snapshot) hashFor(path string) string {
	hash, found := s.uncommittedFilesMap[path]
	if !found {
		return ""
	}
	return hash
}

// loadMap rebuilds the path-to-hash lookup map from UncommittedFiles, replacing any map that
// was there before.
func (s *snapshot) loadMap() {
	lookup := make(map[string]string, len(s.UncommittedFiles))
	for i := range s.UncommittedFiles {
		entry := s.UncommittedFiles[i]
		lookup[entry.Path] = entry.MD5
	}
	s.uncommittedFilesMap = lookup
}

// uncommittedFile is one entry of a snapshot: an uncommitted path together with its hash.
type uncommittedFile struct {
// Path is the file or directory path as returned by the snapshotter's UncommittedFilesFunc.
Path string `json:"path"`
// MD5 is the hash the snapshotter's HashFunc produced for Path.
MD5 string `json:"md5"`
}
Loading

0 comments on commit 7fe220e

Please sign in to comment.