Skip to content

Commit

Permalink
Allow decoding b64 string without header in processing utils (#4031)
Browse files Browse the repository at this point in the history
* allow decoding b64 string without headers

* install gradio-client in edittable mode

* update GH actions

* add test for decoding headerless b64

* add test for decoding headerless b64 image

* some linting

* fix test

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
1lint and abidlabs authored May 1, 2023
1 parent f1ea4f7 commit f97b18e
Show file tree
Hide file tree
Showing 12 changed files with 33 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
- name: Install Gradio
shell: bash
run: |
bash scripts/install_gradio.sh
pip install -e .
python -m pip install --upgrade pip
- name: Install 3.9 Test Dependencies
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ui.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: '3.x'
- run: bash scripts/install_gradio.sh
- run: pip install -e .
- run: pip install -r demo/outbreak_forecast/requirements.txt
- run: pnpm install --frozen-lockfile
- run: pnpm exec playwright install chromium
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3049,6 +3049,7 @@ We've introduced a lot of new components in `3.0`, including `Model3D`, `Dataset
- Mobile responsive guides by [@aliabd](https://github.com/aliabd) in [PR 1293](https://github.com/gradio-app/gradio/pull/1293)
- Update readme by [@abidlabs](https://github.com/abidlabs) in [PR 1292](https://github.com/gradio-app/gradio/pull/1292)
- gif by [@abidlabs](https://github.com/abidlabs) in [PR 1296](https://github.com/gradio-app/gradio/pull/1296)
- Allow decoding headerless b64 string [@1lint](https://github.com/1lint) in [PR 4031](https://github.com/gradio-app/gradio/pull/4031)

## Contributors Shoutout:

Expand Down
4 changes: 2 additions & 2 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,12 +451,12 @@ def _render_endpoints_info(
name_or_index: str | int,
endpoints_info: Dict[str, List[Dict[str, str]]],
) -> str:
parameter_names = list(p["label"] for p in endpoints_info["parameters"])
parameter_names = [p["label"] for p in endpoints_info["parameters"]]
parameter_names = [utils.sanitize_parameter_names(p) for p in parameter_names]
rendered_parameters = ", ".join(parameter_names)
if rendered_parameters:
rendered_parameters = rendered_parameters + ", "
return_values = list(p["label"] for p in endpoints_info["returns"])
return_values = [p["label"] for p in endpoints_info["returns"]]
return_values = [utils.sanitize_parameter_names(r) for r in return_values]
rendered_return_values = ", ".join(return_values)
if len(return_values) > 1:
Expand Down
2 changes: 1 addition & 1 deletion client/python/gradio_client/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def extract_instance_attr_doc(cls, attr):
code = inspect.getsource(cls.__init__)
lines = [line.strip() for line in code.split("\n")]
i = None
for i, line in enumerate(lines):
for i, line in enumerate(lines): # noqa: B007
if line.startswith("self." + attr + ":") or line.startswith(
"self." + attr + " ="
):
Expand Down
3 changes: 1 addition & 2 deletions client/python/gradio_client/serializing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import json
import os
import uuid
from abc import ABC
from pathlib import Path
from typing import Any, Dict, List, Tuple

from gradio_client import media_data, utils
from gradio_client.data_classes import FileData


class Serializable(ABC):
class Serializable:
def api_info(self) -> Dict[str, List[str]]:
"""
The typing information for this component as a dictionary whose values are a list of 2 strings: [Python type, language-agnostic description].
Expand Down
5 changes: 1 addition & 4 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,7 @@ def encode_url_or_file_to_base64(path: str | Path):

def decode_base64_to_binary(encoding: str) -> Tuple[bytes, str | None]:
extension = get_extension(encoding)
try:
data = encoding.split(",")[1]
except IndexError:
data = ""
data = encoding.rsplit(",", 1)[-1]
return base64.b64decode(data), extension


Expand Down
7 changes: 3 additions & 4 deletions client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_intermediate_outputs(self, count_generator_demo):
def test_break_in_loop_if_error(self, calculator_demo):
with connect(calculator_demo) as client:
job = client.submit("foo", "add", 4, fn_index=0)
output = [o for o in job]
output = list(job)
assert output == []

@pytest.mark.flaky
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_upload_file_private_space(self):
f.write("Hello from private space!")

output = client.submit(f.name, api_name="/upload_btn").result()
open(output).read() == "Hello from private space!"
assert open(output).read() == "Hello from private space!"
upload.assert_called_once()

with patch.object(
Expand Down Expand Up @@ -300,8 +300,7 @@ def test_upload_file_upload_route_does_not_exist(self):
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("Hello from private space!")

output = client.submit(1, "foo", f.name, fn_index=0).result()
open(output).read() == "Hello from private space!"
client.submit(1, "foo", f.name, fn_index=0).result()
serialize.assert_called_once_with(1, "foo", f.name)


Expand Down
8 changes: 8 additions & 0 deletions client/python/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def test_decode_base64_to_binary():
binary = utils.decode_base64_to_binary(deepcopy(media_data.BASE64_IMAGE))
assert deepcopy(media_data.BINARY_IMAGE) == binary

b64_img_without_header = deepcopy(media_data.BASE64_IMAGE).split(",")[1]
binary_without_header, extension = utils.decode_base64_to_binary(
b64_img_without_header
)

assert binary[0] == binary_without_header
assert extension is None


def test_decode_base64_to_file():
temp_file = utils.decode_base64_to_file(deepcopy(media_data.BASE64_IMAGE))
Expand Down
7 changes: 3 additions & 4 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def to_binary(x: str | Dict) -> bytes:
base64str = client_utils.encode_url_or_file_to_base64(x["name"])
else:
base64str = x
return base64.b64decode(base64str.split(",")[1])
return base64.b64decode(extract_base64_data(base64str))


def extract_base64_data(x: str) -> str:
"""Just extracts the base64 data from a general base64 string."""
return x.split("base64,")[1]
return x.rsplit(",", 1)[-1]


#########################
Expand All @@ -48,8 +48,7 @@ def extract_base64_data(x: str) -> str:


def decode_base64_to_image(encoding: str) -> Image.Image:
content = encoding.split(";")[1]
image_encoded = content.split(",")[1]
image_encoded = extract_base64_data(encoding)
img = Image.open(BytesIO(base64.b64decode(image_encoded)))
exif = img.getexif()
# 274 is the code for image rotation and 1 means "correct orientation"
Expand Down
5 changes: 4 additions & 1 deletion scripts/install_gradio.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ source scripts/helpers.sh
pip_required

echo "Installing Gradio..."
pip install -e .
pip install -e .

echo "Installing Gradio Client..."
pip install -e client/python
7 changes: 7 additions & 0 deletions test/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def test_decode_base64_to_image(self):
)
assert isinstance(output_image, Image.Image)

b64_img_without_header = deepcopy(media_data.BASE64_IMAGE).split(",")[1]
output_image_without_header = processing_utils.decode_base64_to_image(
b64_img_without_header
)

assert output_image == output_image_without_header

def test_encode_plot_to_base64(self):
plt.plot([1, 2, 3, 4])
output_base64 = processing_utils.encode_plot_to_base64(plt)
Expand Down

0 comments on commit f97b18e

Please sign in to comment.