Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for URLs for Video, Audio, and Image #2256

Merged
merged 7 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DataframeData(TypedDict):
from ffmpy import FFmpeg
from markdown_it import MarkdownIt

from gradio import media_data, processing_utils
from gradio import media_data, processing_utils, utils
from gradio.blocks import Block
from gradio.documentation import document, set_documentation_group
from gradio.events import (
Expand All @@ -55,7 +55,6 @@ class DataframeData(TypedDict):
Serializable,
SimpleSerializable,
)
from gradio.utils import component_or_layout_class

set_documentation_group("component")

Expand Down Expand Up @@ -1349,27 +1348,20 @@ def preprocess(self, x: str | Dict) -> np.array | PIL.Image | str | None:
def postprocess(self, y: np.ndarray | PIL.Image | str | Path) -> str:
"""
Parameters:
y: image as a numpy array, PIL Image, string filepath, or Path filepath
y: image as a numpy array, PIL Image, string/Path filepath, or string URL
Returns:
base64 url data
"""
if y is None:
return None
if isinstance(y, np.ndarray):
dtype = "numpy"
return processing_utils.encode_array_to_base64(y)
elif isinstance(y, PIL.Image.Image):
dtype = "pil"
return processing_utils.encode_pil_to_base64(y)
elif isinstance(y, (str, Path)):
dtype = "file"
return processing_utils.encode_url_or_file_to_base64(y)
else:
raise ValueError("Cannot process this value as an Image")
if dtype == "pil":
out_y = processing_utils.encode_pil_to_base64(y)
elif dtype == "numpy":
out_y = processing_utils.encode_array_to_base64(y)
elif dtype == "file":
out_y = processing_utils.encode_url_or_file_to_base64(y)
return out_y

def set_interpret_parameters(self, segments: int = 16):
"""
Expand Down Expand Up @@ -1530,7 +1522,7 @@ class Video(Changeable, Clearable, Playable, IOComponent, FileSerializable):
combinations are .mp4 with h264 codec, .ogg with theora codec, and .webm with vp9 codec. If the component detects
that the output video would not be playable in the browser it will attempt to convert it to a playable mp4 video.
If the conversion fails, the original video is returned.
Preprocessing: passes the uploaded video as a {str} filepath whose extension can be set by `format`.
Preprocessing: passes the uploaded video as a {str} filepath or URL whose extension can be modified by `format`.
Postprocessing: expects a {str} filepath to a video which is displayed.
Examples-format: a {str} filepath to a local file that contains the video.
Demos: video_identity
Expand Down Expand Up @@ -1656,14 +1648,20 @@ def postprocess(self, y: str | None) -> Dict[str, str] | None:
Processes a video to ensure that it is in the correct format before
returning it to the front end.
Parameters:
y: a path to video file
y: a path or URL to the video file
Returns:
a dictionary with the following keys: 'name' (containing the file path
to a temporary copy of the video), 'data' (None), and 'is_file` (True).
"""
if y is None:
return None

is_temp_file = False

if utils.validate_url(y):
y = processing_utils.download_to_file(y, dir=self.temp_dir).name
is_temp_file = True

returned_format = y.split(".")[-1].lower()
if (
processing_utils.ffmpeg_installed()
Expand All @@ -1679,7 +1677,8 @@ def postprocess(self, y: str | None) -> Dict[str, str] | None:
ff.run()
y = output_file_name

y = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
if not is_temp_file:
y = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)

return {"name": y.name, "data": None, "is_file": True}

Expand Down Expand Up @@ -1709,7 +1708,7 @@ class Audio(Changeable, Clearable, Playable, Streamable, IOComponent, FileSerial
"""
Creates an audio component that can be used to upload/record audio (as an input) or display audio (as an output).
Preprocessing: passes the uploaded audio as a {Tuple(int, numpy.array)} corresponding to (sample rate, data) or as a {str} filepath, depending on `type`
Postprocessing: expects a {Tuple(int, numpy.array)} corresponding to (sample rate, data) or as a {str} filepath to an audio file, which gets displayed
Postprocessing: expects a {Tuple(int, numpy.array)} corresponding to (sample rate, data) or as a {str} filepath or URL to an audio file, which gets displayed
Examples-format: a {str} filepath to a local file that contains audio.
Demos: main_note, generate_tone, reverse_audio
Guides: real_time_speech_recognition
Expand Down Expand Up @@ -1927,21 +1926,24 @@ def generate_sample(self):
def postprocess(self, y: Tuple[int, np.array] | str | None) -> str | None:
"""
Parameters:
y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string of the path to an audio file, or None.
y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
Returns:
base64 url data
"""
if y is None:
return None
if isinstance(y, tuple):

if utils.validate_url(y):
y = processing_utils.download_to_file(y, dir=self.temp_dir).name
elif isinstance(y, tuple):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(
prefix="sample", suffix=".wav", delete=False
)
processing_utils.audio_to_file(sample_rate, data, file.name)
y = file.name

y = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
else:
y = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)

return {"name": y.name, "data": None, "is_file": True}

Expand Down Expand Up @@ -3974,7 +3976,7 @@ def update(


def component(cls_name: str) -> Component:
obj = component_or_layout_class(cls_name)()
obj = utils.component_or_layout_class(cls_name)()
return obj


Expand All @@ -3986,7 +3988,7 @@ def get_component_instance(comp: str | dict | Component, render=True) -> Compone
return component_obj
elif isinstance(comp, dict):
name = comp.pop("name")
component_cls = component_or_layout_class(name)
component_cls = utils.component_or_layout_class(name)
component_obj = component_cls(**comp)
if not (render):
component_obj.unrender()
Expand Down
16 changes: 12 additions & 4 deletions gradio/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ffmpy import FFmpeg, FFprobe, FFRuntimeError
from PIL import Image, ImageOps, PngImagePlugin

from gradio import encryptor
from gradio import encryptor, utils

with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
Expand All @@ -31,10 +31,9 @@ def decode_base64_to_image(encoding):


def encode_url_or_file_to_base64(path, encryption_key=None):
try:
requests.get(path)
if utils.validate_url(path):
return encode_url_to_base64(path, encryption_key=encryption_key)
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
else:
return encode_file_to_base64(path, encryption_key=encryption_key)


Expand Down Expand Up @@ -90,6 +89,15 @@ def encode_plot_to_base64(plt):
return "data:image/png;base64," + base64_str


def download_to_file(url, dir=None):
file_suffix = os.path.splitext(url)[1]
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=file_suffix, dir=dir)
with requests.get(url, stream=True) as r:
with open(file_obj.name, "wb") as f:
shutil.copyfileobj(r.raw, f)
return file_obj


def save_array_to_file(image_array, dir=None):
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
Expand Down
7 changes: 5 additions & 2 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import orjson
import pkg_resources
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
Expand All @@ -28,9 +29,9 @@
from starlette.websockets import WebSocket, WebSocketState

import gradio
from gradio import encryptor
from gradio import encryptor, utils
from gradio.exceptions import Error
from gradio.queue import Estimation, Event, Queue
from gradio.queue import Estimation, Event

mimetypes.init()

Expand Down Expand Up @@ -221,6 +222,8 @@ async def favicon():

@app.get("/file={path:path}", dependencies=[Depends(login_check)])
def file(path: str):
if utils.validate_url(path):
return RedirectResponse(url=path, status_code=status.HTTP_302_FOUND)
if (
app.blocks.encrypt
and isinstance(app.blocks.examples, str)
Expand Down
9 changes: 9 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,12 @@ def append_unique_suffix(name: str, list_of_names: List[str]):
suffix_counter += 1
new_name = name + f"_{suffix_counter}"
return new_name


def validate_url(possible_url: str) -> bool:
try:
if requests.get(possible_url).status_code == 200:
return True
except Exception:
pass
return False