Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(zstdx): uncompress with custom size limit #70

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions zstdx/zstdx.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,20 @@ import (
)

// For protection from decompression bomb
const maxFileSize int64 = 16 * 1024 * 1024 * 1024
const defaultMaxFileSize int64 = 16 * 1024 * 1024 * 1024

// The code at https://go.dev/play/p/A2GXsDFWx9m is used as a reference
// Uncompress with a default max file size limit
func Uncompress(tarball, targetDir string) ([]string, error) {
return uncompress(tarball, targetDir, defaultMaxFileSize)
}

// Uncompress with a specified max file size limit
func UncompressWithCustomSizeLimit(tarball, targetDir string, maxFileSize int64) ([]string, error) {
return uncompress(tarball, targetDir, maxFileSize)
}

// The code at https://go.dev/play/p/A2GXsDFWx9m is used as a reference
func uncompress(tarball, targetDir string, maxFileSize int64) ([]string, error) {
file, err := os.Open(filepath.Clean(tarball))
if err != nil {
return nil, err
Expand All @@ -44,7 +54,7 @@ func Uncompress(tarball, targetDir string) ([]string, error) {
if err != nil {
return nil, err
}
return untar(reader, targetDir)
return untar(reader, targetDir, maxFileSize)
}

func min(a int64, b int64) int64 {
Expand All @@ -54,7 +64,7 @@ func min(a int64, b int64) int64 {
return b
}

func untar(reader io.Reader, targetDir string) ([]string, error) {
func untar(reader io.Reader, targetDir string, maxFileSize int64) ([]string, error) {
var extractedFiles []string
tarReader := tar.NewReader(reader)

Expand Down Expand Up @@ -92,15 +102,15 @@ func untar(reader io.Reader, targetDir string) ([]string, error) {
// if it's a file create it
case tar.TypeReg, tar.TypeGNUSparse:
// tar.Next() will externally only iterate files, so we might have to create intermediate directories here
if err := untarFile(tarReader, header, path); err != nil {
if err := untarFile(tarReader, header, path, maxFileSize); err != nil {
return nil, err
}
extractedFiles = append(extractedFiles, path)
}
}
}

func untarFile(tarReader *tar.Reader, header *tar.Header, path string) error {
func untarFile(tarReader *tar.Reader, header *tar.Header, path string, maxFileSize int64) error {
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return err
}
Expand Down
19 changes: 19 additions & 0 deletions zstdx/zstdx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ func TestUncompress(t *testing.T) {
}
}

func TestUncompressWithCustomSizeLimit(t *testing.T) {
// Create a temporary directory to store the output
tmpDir, err := os.MkdirTemp("", "uncompress-custom-")
defer os.RemoveAll(tmpDir)
require.NoError(t, err)

// Uncompress the tarball
files, err := zstdx.UncompressWithCustomSizeLimit("testdata/sample.tar.zst", tmpDir, 1024)
require.NoError(t, err)

// Check if the files are correctly uncompressed
assert.Equal(t, 2, len(files), "Expected two files to be uncompressed")
for _, f := range files {
info, err := os.Stat(f)
require.NoError(t, err)
assert.False(t, info.IsDir(), "Expected a file, not a directory")
}
}

func TestCompress(t *testing.T) {
// Create a temporary file to store the output
tmpFile, err := os.CreateTemp("", "compress.tar.zst")
Expand Down
Loading