def safe_extract_member(member, extract_to):
    """
    Validate an archive member's path and return the location it may be written to.

    Args:
        member: archive entry to check. Its path is read from ``filename`` when that
            attribute exists (e.g. ``zipfile.ZipInfo``), otherwise from ``name``
            (e.g. ``tarfile.TarInfo``), otherwise from ``str(member)``.
        extract_to: directory the archive is being extracted into.

    Returns:
        The normalised join of ``extract_to`` and the member path.

    Raises:
        ValueError: if the member reports itself as a symbolic or hard link, if its
            normalised path is absolute or contains a ``..`` component, or if the
            resolved target lies outside the resolved ``extract_to`` directory.

    """
    # Pick the attribute that carries the entry's path for the archive flavour at hand.
    if hasattr(member, "filename"):
        raw_name = member.filename
    elif hasattr(member, "name"):
        raw_name = member.name
    else:
        raw_name = str(member)

    # Link entries are refused outright; only tar-style members expose these predicates.
    is_symlink = hasattr(member, "issym") and member.issym()
    if is_symlink:
        raise ValueError(f"Symbolic link detected in archive: {raw_name}")
    is_hardlink = hasattr(member, "islnk") and member.islnk()
    if is_hardlink:
        raise ValueError(f"Hard link detected in archive: {raw_name}")

    relative = os.path.normpath(raw_name)
    climbs_out = any(part == ".." for part in relative.split(os.sep))
    if os.path.isabs(relative) or climbs_out:
        raise ValueError(f"Unsafe path detected in archive: {relative}")

    destination = os.path.normpath(os.path.join(extract_to, relative))

    # Compare fully resolved paths so that links already present on disk cannot
    # redirect the write outside the extraction root.
    root_real = os.path.realpath(extract_to)
    destination_real = os.path.realpath(destination)
    if os.path.commonpath([root_real, destination_real]) != root_real:
        raise ValueError(f"Unsafe path: path traversal {relative}")

    return destination
def _extract_zip(filepath, output_dir):
    """
    Extract a zip archive into ``output_dir`` after validating every member path.

    All members are checked with ``safe_extract_member`` before anything is written,
    so an archive containing an unsafe entry is rejected without leaving a partial
    extraction behind. Directory entries are created explicitly so that empty
    directories in the archive are preserved.

    Args:
        filepath: path of the zip archive to read.
        output_dir: directory to extract into.

    Raises:
        ValueError: propagated from ``safe_extract_member`` for an unsafe member.

    """
    with zipfile.ZipFile(filepath, "r") as zip_file:
        # Validate everything up front; a failure here happens before any write.
        planned = [(member, safe_extract_member(member, output_dir)) for member in zip_file.infolist()]
        for member, safe_path in planned:
            if member.is_dir():
                os.makedirs(safe_path, exist_ok=True)
                continue
            # The validated path is normalised, so with output_dir="." a top-level
            # member has no directory part; os.makedirs("") would raise.
            parent_dir = os.path.dirname(safe_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with zip_file.open(member) as source, open(safe_path, "wb") as target:
                shutil.copyfileobj(source, target)


def _extract_tar(filepath, output_dir):
    """
    Extract a tar archive into ``output_dir`` after validating every member path.

    All members are checked with ``safe_extract_member`` before anything is written.
    Only regular files and directories are materialised; any other member type that
    passes validation (e.g. devices, FIFOs) is skipped.

    Args:
        filepath: path of the tar archive to read (compression is auto-detected by ``"r"`` mode).
        output_dir: directory to extract into.

    Raises:
        ValueError: propagated from ``safe_extract_member`` for an unsafe member.

    """
    with tarfile.open(filepath, "r") as tar_file:
        # Validate everything up front; a failure here happens before any write.
        planned = [(member, safe_extract_member(member, output_dir)) for member in tar_file.getmembers()]
        for member, safe_path in planned:
            if member.isdir():
                os.makedirs(safe_path, exist_ok=True)
                continue
            if not member.isfile():
                continue
            # See _extract_zip: the directory part may be empty for top-level members.
            parent_dir = os.path.dirname(safe_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            source = tar_file.extractfile(member)
            if source is None:
                continue
            with source, open(safe_path, "wb") as target:
                shutil.copyfileobj(source, target)
class TestPathTraversalProtection(unittest.TestCase):
    """Test cases for path traversal attack protection in extractall function."""

    def _assert_extraction_rejected(self, archive_path, extract_dir, expected_fragment):
        """Assert ``extractall`` raises ``ValueError`` whose lower-cased message contains ``expected_fragment``."""
        with self.assertRaises(ValueError) as context:
            extractall(str(archive_path), str(extract_dir))
        self.assertIn(expected_fragment, str(context.exception).lower())

    def test_valid_zip_extraction(self):
        """Test that valid zip files extract successfully without raising exceptions."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_path = Path(tmp_dir) / "valid_test.zip"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            # Create zip with normal file structure
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("normal_file.txt", "This is a normal file")
                zf.writestr("subdir/nested_file.txt", "This is a nested file")
                zf.writestr("another_file.json", '{"key": "value"}')

            # Any exception raised here fails the test with its full traceback.
            extractall(str(zip_path), str(extract_dir))

            # Verify files were extracted correctly
            self.assertTrue((extract_dir / "normal_file.txt").exists())
            self.assertTrue((extract_dir / "subdir" / "nested_file.txt").exists())
            self.assertTrue((extract_dir / "another_file.json").exists())

            # Verify content
            self.assertEqual((extract_dir / "normal_file.txt").read_text(), "This is a normal file")

    def test_malicious_zip_path_traversal(self):
        """Test that malicious zip files with path traversal attempts raise ValueError."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_path = Path(tmp_dir) / "malicious_test.zip"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            # Create zip with malicious paths
            with zipfile.ZipFile(zip_path, "w") as zf:
                # Try to write outside extraction directory
                zf.writestr("../../../etc/malicious.txt", "malicious content")
                zf.writestr("normal_file.txt", "normal content")

            self._assert_extraction_rejected(zip_path, extract_dir, "unsafe path")

    def test_valid_tar_extraction(self):
        """Test that valid tar files extract successfully without raising exceptions."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tar_path = Path(tmp_dir) / "valid_test.tar.gz"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            # Create temporary files to add to tar
            temp_file1 = Path(tmp_dir) / "temp1.txt"
            temp_file1.write_text("This is a normal file")
            temp_file2 = Path(tmp_dir) / "temp2.txt"
            temp_file2.write_text("This is a nested file")

            # Create tar with normal file structure
            with tarfile.open(tar_path, "w:gz") as tf:
                tf.add(temp_file1, arcname="normal_file.txt")
                tf.add(temp_file2, arcname="subdir/nested_file.txt")

            # Any exception raised here fails the test with its full traceback.
            extractall(str(tar_path), str(extract_dir))

            # Verify files were extracted correctly
            self.assertTrue((extract_dir / "normal_file.txt").exists())
            self.assertTrue((extract_dir / "subdir" / "nested_file.txt").exists())

            # Verify content
            self.assertEqual((extract_dir / "normal_file.txt").read_text(), "This is a normal file")

    def test_malicious_tar_path_traversal(self):
        """Test that malicious tar files with path traversal attempts raise ValueError."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tar_path = Path(tmp_dir) / "malicious_test.tar.gz"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            temp_file = Path(tmp_dir) / "temp.txt"
            temp_file.write_text("malicious content")

            # Create tar with malicious paths
            with tarfile.open(tar_path, "w:gz") as tf:
                # Add with malicious path (trying to write outside extraction directory)
                tf.add(temp_file, arcname="../../../etc/malicious.txt")

            self._assert_extraction_rejected(tar_path, extract_dir, "unsafe path")

    def test_absolute_path_protection(self):
        """Test protection against absolute paths in archives."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_path = Path(tmp_dir) / "absolute_path_test.zip"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            with zipfile.ZipFile(zip_path, "w") as zf:
                # Try to use absolute path
                zf.writestr("/etc/passwd_bad", "malicious content")

            self._assert_extraction_rejected(zip_path, extract_dir, "unsafe path")

    def test_malicious_symlink_protection(self):
        """Test protection against malicious symlinks in tar archives."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tar_path = Path(tmp_dir) / "malicious_symlink_test.tar.gz"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            temp_file = Path(tmp_dir) / "normal.txt"
            temp_file.write_text("normal content")

            # Create tar with malicious symlink
            with tarfile.open(tar_path, "w:gz") as tf:
                tf.add(temp_file, arcname="normal.txt")

                symlink_info = tarfile.TarInfo(name="malicious_symlink.txt")
                symlink_info.type = tarfile.SYMTYPE
                symlink_info.linkname = "../../../etc/passwd_bad"
                symlink_info.size = 0
                tf.addfile(symlink_info)

            # Match the wording of the error itself ("Symbolic link detected ...") rather
            # than a substring that also occurs in the member's file name.
            self._assert_extraction_rejected(tar_path, extract_dir, "symbolic link")

    def test_malicious_hardlink_protection(self):
        """Test protection against malicious hard links in tar archives."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tar_path = Path(tmp_dir) / "malicious_hardlink_test.tar.gz"
            extract_dir = Path(tmp_dir) / "extract"
            extract_dir.mkdir()

            temp_file = Path(tmp_dir) / "normal.txt"
            temp_file.write_text("normal content")

            # Create tar with malicious hard link
            with tarfile.open(tar_path, "w:gz") as tf:
                tf.add(temp_file, arcname="normal.txt")

                hardlink_info = tarfile.TarInfo(name="malicious_hardlink.txt")
                hardlink_info.type = tarfile.LNKTYPE
                hardlink_info.linkname = "/etc/passwd_bad"
                hardlink_info.size = 0
                tf.addfile(hardlink_info)

            # Match the wording of the error itself ("Hard link detected ...") rather
            # than a substring that also occurs in the member's file name.
            self._assert_extraction_rejected(tar_path, extract_dir, "hard link")


if __name__ == "__main__":
    unittest.main()