Skip to content

Commit

Permalink
BREAKING: When using get_file with extract=True or untar=True, the return value will be the path of the extracted directory, rather than the path of the archive.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 25, 2024
1 parent 829c9aa commit dcefb13
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 51 deletions.
86 changes: 52 additions & 34 deletions keras/src/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,18 @@ def get_file(
```
Args:
fname: Name of the file. If an absolute path, e.g. `"/path/to/file.txt"`
is specified, the file will be saved at that location.
fname: If the target is a single file, this is your desired
local name for the file.
If `None`, the name of the file at `origin` will be used.
If downloading and extracting a directory archive,
the provided `fname` will be used as extraction directory
name (only if it doesn't have an extension).
origin: Original URL of the file.
untar: Deprecated in favor of `extract` argument.
boolean, whether the file should be decompressed
Boolean, whether the file is a tar archive that should
be extracted.
md5_hash: Deprecated in favor of `file_hash` argument.
md5 hash of the file for verification
md5 hash of the file for file integrity verification.
file_hash: The expected hash string of the file after download.
The sha256 and md5 hash algorithms are both supported.
cache_subdir: Subdirectory under the Keras cache dir where the file is
Expand All @@ -179,7 +183,8 @@ def get_file(
hash_algorithm: Select the hash algorithm to verify the file.
options are `"md5'`, `"sha256'`, and `"auto'`.
The default 'auto' detects the hash algorithm in use.
extract: True tries extracting the file as an Archive, like tar or zip.
extract: If `True`, extracts the archive. Only applicable to compressed
archive files like tar or zip.
archive_format: Archive format to try for extracting the file.
Options are `"auto'`, `"tar'`, `"zip'`, and `None`.
`"tar"` includes tar, tar.gz, and tar.bz files.
Expand Down Expand Up @@ -219,36 +224,50 @@ def get_file(
datadir = os.path.join(datadir_base, cache_subdir)
os.makedirs(datadir, exist_ok=True)

provided_fname = fname
fname = path_to_string(fname)

if not fname:
fname = os.path.basename(urllib.parse.urlsplit(origin).path)
if not fname:
raise ValueError(
"Can't parse the file name from the origin provided: "
f"'{origin}'."
"Please specify the `fname` as the input param."
"Please specify the `fname` argument."
)
else:
if os.sep in fname:
raise ValueError(
"Paths are no longer accepted as the `fname` argument. "
"To specify the file's parent directory, use "
f"the `cache_dir` argument. Received: fname={fname}"
)

if untar:
if fname.endswith(".tar.gz"):
fname = pathlib.Path(fname)
# The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
# considers it as 2 suffixes.
fname = fname.with_suffix("").with_suffix("")
fname = str(fname)
untar_fpath = os.path.join(datadir, fname)
fpath = untar_fpath + ".tar.gz"
if extract or untar:
if provided_fname:
if "." in fname:
download_target = os.path.join(datadir, fname)
fname = fname[: fname.find(".")]
extraction_dir = os.path.join(datadir, fname + "_extracted")
else:
extraction_dir = os.path.join(datadir, fname)
download_target = os.path.join(datadir, fname + "_archive")
else:
extraction_dir = os.path.join(datadir, fname)
download_target = os.path.join(datadir, fname + "_archive")
else:
fpath = os.path.join(datadir, fname)
download_target = os.path.join(datadir, fname)

if force_download:
download = True
elif os.path.exists(fpath):
elif os.path.exists(download_target):
# File found in cache.
download = False
# Verify integrity if a hash was provided.
if file_hash is not None:
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
if not validate_file(
download_target, file_hash, algorithm=hash_algorithm
):
io_utils.print_msg(
"A local file was found, but it seems to be "
f"incomplete or outdated because the {hash_algorithm} "
Expand Down Expand Up @@ -288,43 +307,42 @@ def __call__(self, block_num, block_size, total_size):
error_msg = "URL fetch failure on {}: {} -- {}"
try:
try:
urlretrieve(origin, fpath, DLProgbar())
urlretrieve(origin, download_target, DLProgbar())
except urllib.error.HTTPError as e:
raise Exception(error_msg.format(origin, e.code, e.msg))
except urllib.error.URLError as e:
raise Exception(error_msg.format(origin, e.errno, e.reason))
except (Exception, KeyboardInterrupt):
if os.path.exists(fpath):
os.remove(fpath)
if os.path.exists(download_target):
os.remove(download_target)
raise

# Validate download if succeeded and user provided an expected hash
# Security conscious users would get the hash of the file from a
# separate channel and pass it to this API to prevent MITM / corruption:
if os.path.exists(fpath) and file_hash is not None:
if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
if os.path.exists(download_target) and file_hash is not None:
if not validate_file(
download_target, file_hash, algorithm=hash_algorithm
):
raise ValueError(
"Incomplete or corrupted file detected. "
f"The {hash_algorithm} "
"file hash does not match the provided value "
f"of {file_hash}."
)

if untar:
if not os.path.exists(untar_fpath):
status = extract_archive(fpath, datadir, archive_format="tar")
if not status:
warnings.warn("Could not extract archive.", stacklevel=2)
return untar_fpath
if extract or untar:
if untar:
archive_format = "tar"

if extract:
status = extract_archive(fpath, datadir, archive_format)
status = extract_archive(
download_target, extraction_dir, archive_format
)
if not status:
warnings.warn("Could not extract archive.", stacklevel=2)
return extraction_dir

# TODO: return extracted fpath if we extracted an archive,
# rather than the archive path.
return fpath
return download_target


def resolve_hasher(algorithm, file_hash=None):
Expand Down
31 changes: 14 additions & 17 deletions keras/src/utils/file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_valid_tar_extraction(self):
"""Test valid tar.gz extraction and hash validation."""
dest_dir = self.get_temp_dir()
orig_dir = self.get_temp_dir()
text_file_path, tar_file_path = self._create_tar_file(orig_dir)
_, tar_file_path = self._create_tar_file(orig_dir)
self._test_file_extraction_and_validation(
dest_dir, tar_file_path, "tar.gz"
)
Expand All @@ -328,7 +328,7 @@ def test_valid_zip_extraction(self):
"""Test valid zip extraction and hash validation."""
dest_dir = self.get_temp_dir()
orig_dir = self.get_temp_dir()
text_file_path, zip_file_path = self._create_zip_file(orig_dir)
_, zip_file_path = self._create_zip_file(orig_dir)
self._test_file_extraction_and_validation(
dest_dir, zip_file_path, "zip"
)
Expand All @@ -348,7 +348,7 @@ def test_get_file_with_tgz_extension(self):
"""Test extraction of file with .tar.gz extension."""
dest_dir = self.get_temp_dir()
orig_dir = dest_dir
text_file_path, tar_file_path = self._create_tar_file(orig_dir)
_, tar_file_path = self._create_tar_file(orig_dir)

origin = urllib.parse.urljoin(
"file://",
Expand All @@ -358,8 +358,8 @@ def test_get_file_with_tgz_extension(self):
path = file_utils.get_file(
"test.txt.tar.gz", origin, untar=True, cache_subdir=dest_dir
)
self.assertTrue(path.endswith(".txt"))
self.assertTrue(os.path.exists(path))
self.assertTrue(os.path.exists(os.path.join(path, "test.txt")))

def test_get_file_with_integrity_check(self):
"""Test file download with integrity check."""
Expand Down Expand Up @@ -459,7 +459,7 @@ def _create_tar_file(self, directory):
text_file.write("Float like a butterfly, sting like a bee.")

with tarfile.open(tar_file_path, "w:gz") as tar_file:
tar_file.add(text_file_path)
tar_file.add(text_file_path, arcname="test.txt")

return text_file_path, tar_file_path

Expand All @@ -471,7 +471,7 @@ def _create_zip_file(self, directory):
text_file.write("Float like a butterfly, sting like a bee.")

with zipfile.ZipFile(zip_file_path, "w") as zip_file:
zip_file.write(text_file_path)
zip_file.write(text_file_path, arcname="test.txt")

return text_file_path, zip_file_path

Expand All @@ -484,7 +484,6 @@ def _test_file_extraction_and_validation(
urllib.request.pathname2url(os.path.abspath(file_path)),
)

hashval_sha256 = file_utils.hash_file(file_path)
hashval_md5 = file_utils.hash_file(file_path, algorithm="md5")

if archive_type:
Expand All @@ -499,17 +498,15 @@ def _test_file_extraction_and_validation(
extract=extract,
cache_subdir=dest_dir,
)
path = file_utils.get_file(
"test",
origin,
file_hash=hashval_sha256,
extract=extract,
cache_subdir=dest_dir,
)
if extract:
fpath = path + "_archive"
else:
fpath = path

self.assertTrue(os.path.exists(path))
self.assertTrue(file_utils.validate_file(path, hashval_sha256))
self.assertTrue(file_utils.validate_file(path, hashval_md5))
os.remove(path)
self.assertTrue(file_utils.validate_file(fpath, hashval_md5))
if extract:
self.assertTrue(os.path.exists(os.path.join(path, "test.txt")))

def test_exists(self):
temp_dir = self.get_temp_dir()
Expand Down

0 comments on commit dcefb13

Please sign in to comment.