Skip to content

Commit

Permalink
Cache temp files created from base64 data (#3197)
Browse files Browse the repository at this point in the history
* changes

* added workflow

* fix action

* fix action

* fix action

* changelog

* formatting

* fix

* Delete benchmark-queue.yml

* Delete benchmark_queue.py

* changelog

* lint

* fix tests

* fix tests

* fix for python 3.7

* formatting
  • Loading branch information
abidlabs authored Feb 15, 2023
1 parent 74f75f0 commit 752ec0e
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ No changes to highlight.

## Full Changelog:
* Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151)
* Caches temp files created from base64 input data by giving each one a deterministic path derived from the contents of the data by [@abidlabs](https://github.com/abidlabs) in [PR 3197](https://github.com/gradio-app/gradio/pull/3197)
* Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
* Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)

Expand Down
43 changes: 21 additions & 22 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,13 +1808,9 @@ def preprocess(self, x: Dict[str, str] | None) -> str | None:
x.get("is_file", False),
)
if is_file:
file = self.make_temp_copy_if_needed(file_name)
file_name = Path(file)
file_name = Path(self.make_temp_copy_if_needed(file_name))
else:
file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
file_name = Path(file.name)
file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))

uploaded_format = file_name.suffix.replace(".", "")
modify_format = self.format is not None and uploaded_format != self.format
Expand Down Expand Up @@ -2041,20 +2037,26 @@ def preprocess(
else:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
temp_file_obj = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file_obj.name
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)

sample_rate, data = processing_utils.audio_from_file(
temp_file_path, crop_min=crop_min, crop_max=crop_max
)

# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice, but with different crop min/max.
temp_file_path = Path(temp_file_path)
output_file_name = str(
temp_file_path.with_name(
f"{temp_file_path.stem}-{crop_min}-{crop_max}{temp_file_path.suffix}"
)
)

if self.type == "numpy":
return sample_rate, data
elif self.type == "filepath":
processing_utils.audio_to_file(sample_rate, data, temp_file_path)
return temp_file_path
processing_utils.audio_to_file(sample_rate, data, output_file_name)
return output_file_name
else:
raise ValueError(
"Unknown type: "
Expand All @@ -2075,8 +2077,8 @@ def tokenize(self, x):
if x.get("is_file"):
sample_rate, data = processing_utils.audio_from_file(x["name"])
else:
file_obj = processing_utils.decode_base64_to_file(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_name)
leave_one_out_sets = []
tokens = []
masks = []
Expand Down Expand Up @@ -2117,14 +2119,14 @@ def tokenize(self, x):
def get_masked_inputs(self, tokens, binary_mask_matrix):
# create a "zero input" vector and get sample rate
x = tokens[0]["data"]
file_obj = processing_utils.decode_base64_to_file(x)
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(x)
sample_rate, data = processing_utils.audio_from_file(file_name)
zero_input = np.zeros_like(data, dtype="int16")
# decode all of the tokens
token_data = []
for token in tokens:
file_obj = processing_utils.decode_base64_to_file(token["data"])
_, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(token["data"])
_, data = processing_utils.audio_from_file(file_name)
token_data.append(data)
# construct the masked version
masked_inputs = []
Expand Down Expand Up @@ -4046,10 +4048,7 @@ def preprocess(self, x: Dict[str, str] | None) -> str | None:
if is_file:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
temp_file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file.name
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)

return temp_file_path

Expand Down
34 changes: 34 additions & 0 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,13 @@ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
sha1.update(data)
return sha1.hexdigest()

def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
    """Return the hex SHA-1 digest of a base64 string, hashing it chunk by chunk.

    The string is fed to the hash in slices of
    ``chunk_num_blocks * sha1.block_size`` characters, each encoded as UTF-8,
    so only one slice is encoded at a time.

    Parameters:
        base64_encoding: the text to hash.
        chunk_num_blocks: number of SHA-1 blocks per slice; must be >= 1.
    Returns:
        The hexadecimal SHA-1 digest of the full string.
    Raises:
        ValueError: if ``chunk_num_blocks`` is not positive.
    """
    # A negative step would make range() empty and silently yield the digest
    # of the empty string for every input, so all inputs would collide.
    if chunk_num_blocks < 1:
        raise ValueError("chunk_num_blocks must be a positive integer")

    sha1 = hashlib.sha1()
    chunk_size = chunk_num_blocks * sha1.block_size
    for start in range(0, len(base64_encoding), chunk_size):
        sha1.update(base64_encoding[start : start + chunk_size].encode("utf-8"))
    return sha1.hexdigest()

def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
file_name = Path(file_path_or_url).name
prefix, extension = file_name, None
Expand All @@ -374,6 +381,12 @@ def get_temp_url_path(self, url: str) -> str:
file_hash = self.hash_url(url)
return prefix + file_hash + extension

def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str:
    """Build the file name used for a base64 payload's temp file.

    The name is ``prefix`` + the payload's hash (from ``self.hash_base64``) +
    a dotted extension. The extension is whatever ``get_extension`` reports
    for the encoding and is omitted when that value is empty.
    """
    ext = get_extension(base64_encoding)
    if ext:
        suffix = "." + ext
    else:
        suffix = ""
    digest = self.hash_base64(base64_encoding)
    return "".join((prefix, digest, suffix))

def make_temp_copy_if_needed(self, file_path: str) -> str:
"""Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file."""
Expand Down Expand Up @@ -408,6 +421,27 @@ def download_temp_copy_if_needed(self, url: str) -> str:
self.temp_files.add(full_temp_file_path)
return full_temp_file_path

def base64_to_temp_file_if_needed(
    self, base64_encoding: str, file_name: str | None = None
) -> str:
    """Write a base64 payload to a deterministic temp file, reusing it if present.

    The file name comes from ``self.get_temp_base64_path`` (prefix + content
    hash + extension), so the same payload always maps to the same path. The
    payload is decoded and written only when no file exists at that path.

    Parameters:
        base64_encoding: the base64 payload to materialize.
        file_name: optional original file name; when given, its prefix (as
            returned by ``self.get_prefix_and_extension``) starts the temp name.
    Returns:
        The absolute path of the temp file, which is also added to
        ``self.temp_files``.
    """
    # Ask tempfile for the directory directly. Creating a throwaway
    # NamedTemporaryFile(delete=False) just to read its parent leaves an open
    # handle and a stray empty file behind on every call.
    temp_dir = Path(tempfile.gettempdir())
    prefix = self.get_prefix_and_extension(file_name)[0] if file_name else ""

    temp_file_path = self.get_temp_base64_path(base64_encoding, prefix=prefix)
    full_temp_file_path = str(utils.abspath(str(temp_dir / temp_file_path)))

    # Only decode and write when this payload has not been cached yet.
    if not Path(full_temp_file_path).exists():
        data, _ = decode_base64_to_binary(base64_encoding)
        with open(full_temp_file_path, "wb") as fb:
            fb.write(data)

    self.temp_files.add(full_temp_file_path)
    return full_temp_file_path


def download_tmp_copy_of_file(
url_path: str, access_token: str | None = None, dir: str | None = None
Expand Down
10 changes: 6 additions & 4 deletions test/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,9 +770,9 @@ def test_component_functions(self):
"""
x_wav = deepcopy(media_data.BASE64_AUDIO)
audio_input = gr.Audio()
output = audio_input.preprocess(x_wav)
assert output[0] == 8000
assert output[1].shape == (8046,)
output1 = audio_input.preprocess(x_wav)
assert output1[0] == 8000
assert output1[1].shape == (8046,)
assert filecmp.cmp(
"test/test_files/audio_sample.wav",
audio_input.serialize("test/test_files/audio_sample.wav")["name"],
Expand All @@ -796,7 +796,9 @@ def test_component_functions(self):
assert audio_input.preprocess(None) is None
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
assert audio_input.preprocess(x_wav) is not None
output2 = audio_input.preprocess(x_wav)
assert output2 is not None
assert output1 != output2

audio_input = gr.Audio(type="filepath")
assert isinstance(audio_input.preprocess(x_wav), str)
Expand Down
24 changes: 24 additions & 0 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def test_make_temp_copy_if_needed(self, mock_copy):
)
assert len(temp_file_manager.temp_files) == 2

def test_base64_to_temp_file_if_needed(self):
    """Repeated payloads must not add tracked temp files; a new payload adds one."""
    manager = processing_utils.TempFileManager()

    image_b64 = media_data.BASE64_IMAGE
    audio_b64 = media_data.BASE64_AUDIO["data"]

    # Remove any file left over from before this test.
    stale_path = manager.base64_to_temp_file_if_needed(image_b64)
    try:
        os.remove(stale_path)
    except OSError:
        pass

    # Each entry: (payload, number of tracked temp files expected after the call).
    expectations = (
        (image_b64, 1),
        (image_b64, 1),
        (audio_b64, 2),
    )
    for payload, expected_count in expectations:
        manager.base64_to_temp_file_if_needed(payload)
        assert len(manager.temp_files) == expected_count

    # Clean up everything the manager tracked.
    for tracked_path in manager.temp_files:
        os.remove(tracked_path)

@pytest.mark.flaky
@patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy):
Expand Down

0 comments on commit 752ec0e

Please sign in to comment.