Skip to content

Commit

Permalink
Implemented safety backups
Browse files Browse the repository at this point in the history
  • Loading branch information
lflare committed Jan 8, 2022
1 parent 982b843 commit ccc7e0e
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"crypto/sha256"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -55,11 +56,32 @@ func Rewrite(path string, info os.FileInfo, err error) error {
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)

// Open file
f, err := os.OpenFile(path, os.O_RDWR, info.Mode().Perm())
file, err := os.OpenFile(path, os.O_RDWR, info.Mode().Perm())
if err != nil {
return err
}
defer f.Close()
defer file.Close()

// Open backup file
backupPath := fmt.Sprintf("%s.bak", path)
backupFile, err := os.Create(backupPath)
if err != nil {
return err
}
defer backupFile.Close()

// Copy to backup file
written, err := io.Copy(backupFile, file)
if err != nil {
return err
}
log.Infof("Backed up file '%s' of size %d\n", path, written)

// Calculate original hash
oldHash := sha256.New()
if _, err := io.Copy(oldHash, file); err != nil {
return err
}

// Prepare buffer
buf := make([]byte, 2)
Expand All @@ -75,7 +97,7 @@ func Rewrite(path string, info os.FileInfo, err error) error {
// Loop through whole file in steps of block size
for i := int64(0); i < info.Size()-2; i += BLOCKSIZE {
// Read two bytes at offset
_, err = f.ReadAt(buf, i)
_, err = file.ReadAt(buf, i)
if err != nil {
return fmt.Errorf("failed to read to buf: %v", err)
}
Expand All @@ -84,7 +106,7 @@ func Rewrite(path string, info os.FileInfo, err error) error {
buf[0], buf[1] = buf[1], buf[0]

// Write swapped bytes
err = writeSync(f, buf, i)
err = writeSync(file, buf, i)
if err != nil {
return err
}
Expand All @@ -100,7 +122,7 @@ func Rewrite(path string, info os.FileInfo, err error) error {
}

// Force filesystem sync
err = f.Sync()
err = file.Sync()
if err != nil {
return fmt.Errorf("failed to sync: %v", err)
}
Expand All @@ -112,6 +134,23 @@ func Rewrite(path string, info os.FileInfo, err error) error {
}
}

// Calculate new hash
newHash := sha256.New()
if _, err := io.Copy(newHash, file); err != nil {
return err
}

// If for some reason, hashes are not the same, restore backup
oldHashString := fmt.Sprintf("%x", oldHash.Sum(nil))
newHashString := fmt.Sprintf("%x", newHash.Sum(nil))
if oldHashString != newHashString {
io.Copy(file, backupFile)
os.Remove(backupPath)
return fmt.Errorf("unexpected hash of file '%s', '%s' != '%s', restoring backup", path, oldHashString, newHashString)
} else {
os.Remove(backupPath)
}

// Check if signal was raised
select {
case <-done:
Expand Down

0 comments on commit ccc7e0e

Please sign in to comment.