Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache temp files created from base64 data #3197

Merged
merged 23 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ No changes to highlight.

## Full Changelog:
* Fix demos page css and add close demos button by [@aliabd](https://github.com/aliabd) in [PR 3151](https://github.com/gradio-app/gradio/pull/3151)
* Caches temp files created from base64 input data by giving them a deterministic path based on the contents of the data by [@abidlabs](https://github.com/abidlabs) in [PR 3197](https://github.com/gradio-app/gradio/pull/3197)
* Better warnings (when there is a mismatch between the number of output components and values returned by a function, or when the `File` component or `UploadButton` component includes a `file_types` parameter along with `file_count=="dir"`) by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)
* Raises a `gr.Error` instead of a regular Python error when you use `gr.Interface.load()` to load a model and there's an error querying the HF API by [@abidlabs](https://github.com/abidlabs) in [PR 3194](https://github.com/gradio-app/gradio/pull/3194)

Expand Down
43 changes: 21 additions & 22 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,13 +1808,9 @@ def preprocess(self, x: Dict[str, str] | None) -> str | None:
x.get("is_file", False),
)
if is_file:
file = self.make_temp_copy_if_needed(file_name)
file_name = Path(file)
file_name = Path(self.make_temp_copy_if_needed(file_name))
else:
file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
file_name = Path(file.name)
file_name = Path(self.base64_to_temp_file_if_needed(file_data, file_name))

uploaded_format = file_name.suffix.replace(".", "")
modify_format = self.format is not None and uploaded_format != self.format
Expand Down Expand Up @@ -2041,20 +2037,26 @@ def preprocess(
else:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
temp_file_obj = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file_obj.name
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)

sample_rate, data = processing_utils.audio_from_file(
temp_file_path, crop_min=crop_min, crop_max=crop_max
)

# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice, but with different crop min/max.
temp_file_path = Path(temp_file_path)
output_file_name = str(
temp_file_path.with_name(
abidlabs marked this conversation as resolved.
Show resolved Hide resolved
f"{temp_file_path.stem}-{crop_min}-{crop_max}{temp_file_path.suffix}"
)
)

if self.type == "numpy":
return sample_rate, data
elif self.type == "filepath":
processing_utils.audio_to_file(sample_rate, data, temp_file_path)
return temp_file_path
processing_utils.audio_to_file(sample_rate, data, output_file_name)
return output_file_name
else:
raise ValueError(
"Unknown type: "
Expand All @@ -2075,8 +2077,8 @@ def tokenize(self, x):
if x.get("is_file"):
sample_rate, data = processing_utils.audio_from_file(x["name"])
else:
file_obj = processing_utils.decode_base64_to_file(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(x["data"])
sample_rate, data = processing_utils.audio_from_file(file_name)
leave_one_out_sets = []
tokens = []
masks = []
Expand Down Expand Up @@ -2117,14 +2119,14 @@ def tokenize(self, x):
def get_masked_inputs(self, tokens, binary_mask_matrix):
# create a "zero input" vector and get sample rate
x = tokens[0]["data"]
file_obj = processing_utils.decode_base64_to_file(x)
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(x)
sample_rate, data = processing_utils.audio_from_file(file_name)
zero_input = np.zeros_like(data, dtype="int16")
# decode all of the tokens
token_data = []
for token in tokens:
file_obj = processing_utils.decode_base64_to_file(token["data"])
_, data = processing_utils.audio_from_file(file_obj.name)
file_name = self.base64_to_temp_file_if_needed(token["data"])
_, data = processing_utils.audio_from_file(file_name)
token_data.append(data)
# construct the masked version
masked_inputs = []
Expand Down Expand Up @@ -4046,10 +4048,7 @@ def preprocess(self, x: Dict[str, str] | None) -> str | None:
if is_file:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
temp_file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file.name
temp_file_path = self.base64_to_temp_file_if_needed(file_data, file_name)

return temp_file_path

Expand Down
34 changes: 34 additions & 0 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,13 @@ def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
sha1.update(data)
return sha1.hexdigest()

def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
    """Return the SHA-1 hex digest of a base64 string, hashed chunk by chunk.

    Parameters:
        base64_encoding: the string to hash; each chunk is UTF-8 encoded
            before being fed to the hasher.
        chunk_num_blocks: number of SHA-1 blocks' worth of characters to
            process per update call.
    Returns:
        The hexadecimal SHA-1 digest of the whole string.
    """
    hasher = hashlib.sha1()
    # Characters consumed per iteration.
    chunk_size = chunk_num_blocks * hasher.block_size
    for start in range(0, len(base64_encoding), chunk_size):
        chunk = base64_encoding[start : start + chunk_size]
        hasher.update(chunk.encode("utf-8"))
    return hasher.hexdigest()

def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
file_name = Path(file_path_or_url).name
prefix, extension = file_name, None
Expand All @@ -374,6 +381,12 @@ def get_temp_url_path(self, url: str) -> str:
file_hash = self.hash_url(url)
return prefix + file_hash + extension

def get_temp_base64_path(self, base64_encoding: str, prefix: str) -> str:
    """Build the deterministic temp file name for a base64 payload.

    The name is the given prefix, followed by the hash of the encoding
    (from ``hash_base64``), followed by a dot and the extension reported by
    ``get_extension`` when that value is truthy (otherwise no suffix).
    """
    ext = get_extension(base64_encoding)
    suffix = "." + ext if ext else ""
    digest = self.hash_base64(base64_encoding)
    return "".join((prefix, digest, suffix))

def make_temp_copy_if_needed(self, file_path: str) -> str:
"""Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file."""
Expand Down Expand Up @@ -408,6 +421,27 @@ def download_temp_copy_if_needed(self, url: str) -> str:
self.temp_files.add(full_temp_file_path)
return full_temp_file_path

def base64_to_temp_file_if_needed(
    self, base64_encoding: str, file_name: str | None = None
) -> str:
    """Return the path of a temp file holding the decoded base64 data.

    The file name is derived from the contents of ``base64_encoding`` (see
    ``get_temp_base64_path``), so the same data always maps to the same
    path. The file is written only if nothing exists at that path yet;
    otherwise the existing file is reused as-is.

    Parameters:
        base64_encoding: the base64 string to decode and store.
        file_name: optional original file name; when given, its prefix (as
            returned by ``get_prefix_and_extension``) is prepended to the
            generated name.
    Returns:
        The absolute path of the temp file, which is also recorded in
        ``self.temp_files``.
    """
    prefix = self.get_prefix_and_extension(file_name)[0] if file_name else ""
    temp_file_name = self.get_temp_base64_path(base64_encoding, prefix=prefix)

    # Ask for the temp directory directly. Opening a NamedTemporaryFile just
    # to read its parent directory leaks an open handle and leaves an empty
    # stray file behind on every call.
    temp_dir = Path(tempfile.gettempdir())
    full_temp_file_path = str(utils.abspath(str(temp_dir / temp_file_name)))

    if not Path(full_temp_file_path).exists():
        data, _ = decode_base64_to_binary(base64_encoding)
        with open(full_temp_file_path, "wb") as fb:
            fb.write(data)

    self.temp_files.add(full_temp_file_path)
    return full_temp_file_path


def download_tmp_copy_of_file(
url_path: str, access_token: str | None = None, dir: str | None = None
Expand Down
10 changes: 6 additions & 4 deletions test/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,9 +770,9 @@ def test_component_functions(self):
"""
x_wav = deepcopy(media_data.BASE64_AUDIO)
audio_input = gr.Audio()
output = audio_input.preprocess(x_wav)
assert output[0] == 8000
assert output[1].shape == (8046,)
output1 = audio_input.preprocess(x_wav)
assert output1[0] == 8000
assert output1[1].shape == (8046,)
assert filecmp.cmp(
"test/test_files/audio_sample.wav",
audio_input.serialize("test/test_files/audio_sample.wav")["name"],
Expand All @@ -796,7 +796,9 @@ def test_component_functions(self):
assert audio_input.preprocess(None) is None
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
assert audio_input.preprocess(x_wav) is not None
output2 = audio_input.preprocess(x_wav)
assert output2 is not None
assert output1 != output2

audio_input = gr.Audio(type="filepath")
assert isinstance(audio_input.preprocess(x_wav), str)
Expand Down
24 changes: 24 additions & 0 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def test_make_temp_copy_if_needed(self, mock_copy):
)
assert len(temp_file_manager.temp_files) == 2

def test_base64_to_temp_file_if_needed(self):
    """Check that identical base64 data is cached under a single temp file."""
    temp_file_manager = processing_utils.TempFileManager()

    base64_file_1 = media_data.BASE64_IMAGE
    base64_file_2 = media_data.BASE64_AUDIO["data"]

    f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
    try:  # Delete if already exists from before this test
        os.remove(f)
    except OSError:
        pass

    try:
        first = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
        assert os.path.exists(first)
        assert len(temp_file_manager.temp_files) == 1

        # Same data again: the same path is returned and nothing new is tracked.
        second = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
        assert second == first
        assert len(temp_file_manager.temp_files) == 1

        # Different data gets its own file.
        third = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
        assert third != first
        assert len(temp_file_manager.temp_files) == 2
    finally:
        # Clean up even when an assertion fails, so no temp files are left over.
        for file in list(temp_file_manager.temp_files):
            try:
                os.remove(file)
            except OSError:
                pass

@pytest.mark.flaky
@patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy):
Expand Down