diff --git a/client/python/README.md b/client/python/README.md
new file mode 100644
index 0000000000000..f27d88f198d8c
--- /dev/null
+++ b/client/python/README.md
@@ -0,0 +1,105 @@
+# `gradio_client`: Use any Gradio app as an API -- in 3 lines of Python
+
+This directory contains the source code for `gradio_client`, a lightweight Python library that makes it very easy to use any Gradio app as an API. Warning: This library is **currently in alpha, and APIs may change**.
+
+As an example, consider the Stable Diffusion Gradio app, which is hosted on Hugging Face Spaces and generates images given a text prompt. Using the `gradio_client` library, we can easily use this Gradio app as an API to generate images programmatically.
+
+Here's the entire code to do it:
+
+```python
+import gradio_client as grc
+
+client = grc.Client(space="stability-ai/stable-diffusion")
+job = client.predict("a hyperrealistic portrait of a cat wearing cyberpunk armor")
+job.result()
+
+>> https://stabilityai-stable-diffusion.hf.space/kjbcxadsk3ada9k/image.png # URL to generated image
+
+```
+
+## Installation
+
+If you already have a recent version of `gradio`, then `gradio_client` is included as a dependency.
+
+Otherwise, the lightweight `gradio_client` package can be installed from pip (or pip3) and works with Python versions 3.7 or higher:
+
+```bash
+$ pip install gradio_client
+```
+
+## Usage
+
+### Connecting to a Space or a Gradio app
+
+Start by instantiating a `Client` object and connecting it to a Gradio app
+that is running on Spaces (or anywhere else)!
+
+**Connecting to a Space**
+
+```python
+import gradio_client as grc
+
+client = grc.Client(space="abidlabs/en2fr")
+```
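+
+**Connecting to a private Space**
+
+If the Space you are connecting to is private, pass your Hugging Face access token via the `hf_token` argument as well. A minimal sketch; the Space name below is just a hypothetical placeholder:
+
+```python
+import gradio_client as grc
+
+# "your-username/your-private-space" is a hypothetical private Space;
+# replace it (and the elided token) with your own values
+client = grc.Client(space="your-username/your-private-space", hf_token="hf_...")
+```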
+
+**Connecting to a general Gradio app**
+
+If your app is running somewhere else, provide the full URL to the `src` argument instead. Here's an example of making predictions to a Gradio app that is running on a share URL:
+
+```python
+import gradio_client as grc
+
+client = grc.Client(src="btd372-js72hd.gradio.app")
+```
+
+### Making a prediction
+
+The simplest way to make a prediction is to call the `.predict()` function with the appropriate arguments and then immediately call `.result()`, like this:
+
+```python
+import gradio_client as grc
+
+client = grc.Client(space="abidlabs/en2fr")
+
+client.predict("Hello").result()
+
+>> Bonjour
+```
+
+**Running jobs asynchronously**
+
+One should note that `.result()` is a *blocking* operation, as it waits for the operation to complete before returning the prediction.
+
+In many cases, you may be better off letting the job run asynchronously and waiting to call `.result()` only when you need the results of the prediction. For example:
+
+```python
+import gradio_client as grc
+
+client = grc.Client(space="abidlabs/en2fr")
+
+job = client.predict("Hello")
+
+# Do something else
+
+job.result()
+
+>> Bonjour
+```
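+
+Because a `Job` is a thin wrapper around a `concurrent.futures.Future`, `Future` methods such as `.done()` are forwarded to it, so you can also poll a job instead of blocking on it. A minimal sketch, reusing the same `abidlabs/en2fr` Space:
+
+```python
+import time
+
+import gradio_client as grc
+
+client = grc.Client(space="abidlabs/en2fr")
+job = client.predict("Hello")
+
+# Poll until the prediction has finished, then read the result
+while not job.done():
+    time.sleep(0.1)
+print(job.result())
+```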
+
+**Adding callbacks**
+
+Alternatively, one can add callbacks to perform actions after the job has completed running, like this:
+
+```python
+import gradio_client as grc
+
+def print_result(x):
+    print(f"The translated result is: {x}")
+
+client = grc.Client(space="abidlabs/en2fr")
+
+job = client.predict("Hello", result_callbacks=[print_result])
+```
diff --git a/client/python/gradio_client/__init__.py b/client/python/gradio_client/__init__.py
new file mode 100644
index 0000000000000..e0e1c94e23ff8
--- /dev/null
+++ b/client/python/gradio_client/__init__.py
@@ -0,0 +1,2 @@
+from gradio_client.client import Client
+from gradio_client.utils import __version__
diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py
new file mode 100644
index 0000000000000..0447f24952f9f
--- /dev/null
+++ b/client/python/gradio_client/client.py
@@ -0,0 +1,292 @@
+"""The main Client class for the Python client."""
+from __future__ import annotations
+
+import concurrent.futures
+import json
+import re
+import threading
+import uuid
+from concurrent.futures import Future
+from typing import Any, Callable, Dict, List, Tuple
+
+import huggingface_hub
+import requests
+import websockets
+from gradio_client import serializing, utils
+from gradio_client.serializing import Serializable
+from huggingface_hub.utils import build_hf_headers, send_telemetry
+from packaging import version
+
+
+class Client:
+    def __init__(
+        self,
+        space: str | None = None,
+        src: str | None = None,
+        hf_token: str | None = None,
+        max_workers: int = 40,
+    ):
+        self.hf_token = hf_token
+        self.headers = build_hf_headers(
+            token=hf_token,
+            library_name="gradio_client",
+            library_version=utils.__version__,
+        )
+
+        if space is None and src is None:
+            raise ValueError("Either `space` or `src` must be provided")
+        elif space and src:
+            raise ValueError("Only one of `space` or `src` should be provided")
+        self.src = src or self._space_name_to_src(space)
+        if self.src is None:
+            raise ValueError(
+                f"Could not find Space: {space}. If it is a private Space, please provide an hf_token."
+            )
+        else:
+            print(f"Loaded as API: {self.src} ✔")
+
+        self.api_url = utils.API_URL.format(self.src)
+        self.ws_url = utils.WS_URL.format(self.src).replace("http", "ws", 1)
+        self.config = self._get_config()
+
+        self.endpoints = [
+            Endpoint(self, fn_index, dependency)
+            for fn_index, dependency in enumerate(self.config["dependencies"])
+        ]
+
+        # Create a pool of threads to handle the requests
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+
+        # Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
+        threading.Thread(target=self._telemetry_thread).start()
+
+    def predict(
+        self,
+        *args,
+        api_name: str | None = None,
+        fn_index: int = 0,
+        result_callbacks: Callable | List[Callable] | None = None,
+    ) -> Future:
+        if api_name:
+            fn_index = self._infer_fn_index(api_name)
+
+        end_to_end_fn = self.endpoints[fn_index].end_to_end_fn
+        future = self.executor.submit(end_to_end_fn, *args)
+        job = Job(future)
+
+        if result_callbacks:
+            if isinstance(result_callbacks, Callable):
+                result_callbacks = [result_callbacks]
+
+            def create_fn(callback) -> Callable:
+                def fn(future):
+                    if isinstance(future.result(), tuple):
+                        callback(*future.result())
+                    else:
+                        callback(future.result())
+
+                return fn
+
+            for callback in result_callbacks:
+                job.add_done_callback(create_fn(callback))
+
+        return job
+
+    def _telemetry_thread(self) -> None:
+        # Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
+        data = {
+            "src": self.src,
+        }
+        try:
+            send_telemetry(
+                topic="py_client/initiated",
+                library_name="gradio_client",
+                library_version=utils.__version__,
+                user_agent=data,
+            )
+        except Exception:
+            pass
+
+    def _infer_fn_index(self, api_name: str) -> int:
+        for i, d in enumerate(self.config["dependencies"]):
+            if d.get("api_name") == api_name:
+                return i
+        raise ValueError(f"Cannot find a function with api_name: {api_name}")
+
+    def __del__(self):
+        if hasattr(self, "executor"):
+            self.executor.shutdown(wait=True)
+
+    def _space_name_to_src(self, space) -> str | None:
+        return huggingface_hub.space_info(space, token=self.hf_token).host  # type: ignore
+
+    def _get_config(self) -> Dict:
+        assert self.src is not None
+        r = requests.get(self.src, headers=self.headers)
+        # some basic regex to extract the config
+        result = re.search(r"window.gradio_config = (.*?);[\s]*", r.text)
+        try:
+            config = json.loads(result.group(1))  # type: ignore
+        except AttributeError:
+            raise ValueError(f"Could not get Gradio config from: {self.src}")
+        if "allow_flagging" in config:
+            raise ValueError(
+                "Gradio 2.x is not supported by this client. Please upgrade this app to Gradio 3.x."
+ ) + return config + + +class Endpoint: + """Helper class for storing all the information about a single API endpoint.""" + + def __init__(self, client: Client, fn_index: int, dependency: Dict): + self.api_url = client.api_url + self.ws_url = client.ws_url + self.fn_index = fn_index + self.dependency = dependency + self.headers = client.headers + self.config = client.config + self.use_ws = self._use_websocket(self.dependency) + self.hf_token = client.hf_token + try: + self.serializers, self.deserializers = self._setup_serializers() + self.is_valid = self.dependency[ + "backend_fn" + ] # Only a real API endpoint if backend_fn is True + except AssertionError: + self.is_valid = False + + def end_to_end_fn(self, *data): + if not self.is_valid: + raise utils.InvalidAPIEndpointError() + inputs = self.serialize(*data) + predictions = self.predict(*inputs) + outputs = self.deserialize(*predictions) + if len(self.dependency["outputs"]) == 1: + return outputs[0] + return outputs + + def predict(self, *data) -> Tuple: + data = json.dumps({"data": data, "fn_index": self.fn_index}) + hash_data = json.dumps( + {"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())} + ) + if self.use_ws: + result = utils.synchronize_async(self._ws_fn, data, hash_data) + output = result["data"] + else: + response = requests.post(self.api_url, headers=self.headers, data=data) + result = json.loads(response.content.decode("utf-8")) + try: + output = result["data"] + except KeyError: + if "error" in result and "429" in result["error"]: + raise utils.TooManyRequestsError( + "Too many requests to the Hugging Face API" + ) + raise KeyError( + f"Could not find 'data' key in response. Response received: {result}" + ) + return tuple(output) + + def _predict_resolve(self, *data) -> Any: + """Needed for gradio.load(), which has a slightly different signature for serializing/deserializing""" + outputs = self.predict(*data) + if len(self.dependency["outputs"]) == 1: + return outputs[0] + return outputs + + def serialize(self, *data) -> Tuple: + assert len(data) == len( + self.serializers + ), f"Expected {len(self.serializers)} arguments, got {len(data)}" + return tuple([s.serialize(d) for s, d in zip(self.serializers, data)]) + + def deserialize(self, *data) -> Tuple: + assert len(data) == len( + self.deserializers + ), f"Expected {len(self.deserializers)} outputs, got {len(data)}" + return tuple( + [ + s.deserialize(d, hf_token=self.hf_token) + for s, d in zip(self.deserializers, data) + ] + ) + + def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]: + inputs = self.dependency["inputs"] + serializers = [] + + for i in inputs: + for component in self.config["components"]: + if component["id"] == i: + if component.get("serializer"): + serializer_name = component["serializer"] + assert ( + serializer_name in serializing.SERIALIZER_MAPPING + ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." + serializer = serializing.SERIALIZER_MAPPING[serializer_name] + else: + component_name = component["type"] + assert ( + component_name in serializing.COMPONENT_MAPPING + ), f"Unknown component: {component_name}, you may need to update your gradio_client version." 
+                        serializer = serializing.COMPONENT_MAPPING[component_name]
+                    serializers.append(serializer())  # type: ignore
+
+        outputs = self.dependency["outputs"]
+        deserializers = []
+        for i in outputs:
+            for component in self.config["components"]:
+                if component["id"] == i:
+                    if component.get("serializer"):
+                        serializer_name = component["serializer"]
+                        assert (
+                            serializer_name in serializing.SERIALIZER_MAPPING
+                        ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
+                        deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
+                    else:
+                        component_name = component["type"]
+                        assert (
+                            component_name in serializing.COMPONENT_MAPPING
+                        ), f"Unknown component: {component_name}, you may need to update your gradio_client version."
+                        deserializer = serializing.COMPONENT_MAPPING[component_name]
+                    deserializers.append(deserializer())  # type: ignore
+
+        return serializers, deserializers
+
+    def _use_websocket(self, dependency: Dict) -> bool:
+        queue_enabled = self.config.get("enable_queue", False)
+        queue_uses_websocket = version.parse(
+            self.config.get("version", "2.0")
+        ) >= version.Version("3.2")
+        dependency_uses_queue = dependency.get("queue", False) is not False
+        return queue_enabled and queue_uses_websocket and dependency_uses_queue
+
+    async def _ws_fn(self, data, hash_data):
+        async with websockets.connect(  # type: ignore
+            self.ws_url, open_timeout=10, extra_headers=self.headers
+        ) as websocket:
+            return await utils.get_pred_from_ws(websocket, data, hash_data)
+
+
+class Job(Future):
+    """A Job is a thin wrapper over the Future class that can be cancelled."""
+
+    def __init__(self, future: Future):
+        self.future = future
+
+    def __getattr__(self, name):
+        """Forwards any properties to the Future class."""
+        return getattr(self.future, name)
+
+    def cancel(self) -> bool:
+        """Cancels the job."""
+        if self.future.cancelled() or self.future.done():
+            return False
+        elif self.future.running():
+            return True  # TODO: Handle this case
+        else:
+            return self.future.cancel()
diff --git a/gradio/serializing.py b/client/python/gradio_client/serializing.py
similarity index 60%
rename from gradio/serializing.py
rename to client/python/gradio_client/serializing.py
index 6780ef7d289a0..6c8814e58cbd3 100644
--- a/gradio/serializing.py
+++ b/client/python/gradio_client/serializing.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+import json
+import os
+import uuid
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, Dict, List
 
-from gradio import processing_utils, utils
-from gradio.context import Context
+from gradio_client import utils
 
 
 class Serializable(ABC):
@@ -22,6 +24,7 @@ def deserialize(
         x: Any,
         save_dir: str | Path | None = None,
         root_url: str | None = None,
+        hf_token: str | None = None,
     ):
         """
         Convert data from serialized format for a browser to human-readable format.
@@ -44,6 +47,7 @@ def deserialize(
         x: Any,
         save_dir: str | Path | None = None,
         root_url: str | None = None,
+        hf_token: str | None = None,
     ):
         """
         Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
@@ -51,6 +55,7 @@ def deserialize(
         x: Input data to deserialize
         save_dir: Ignored
         root_url: Ignored
+        hf_token: Ignored
         """
         return x
 
@@ -70,15 +75,16 @@
         """
         if x is None or x == "":
             return None
-        is_url = utils.validate_url(x)
+        is_url = utils.is_valid_url(x)
         path = x if is_url else Path(load_dir) / x
-        return processing_utils.encode_url_or_file_to_base64(path)
+        return utils.encode_url_or_file_to_base64(path)
 
     def deserialize(
         self,
         x: str | None,
         save_dir: str | Path | None = None,
         root_url: str | None = None,
+        hf_token: str | None = None,
     ) -> str | None:
         """
         Convert from serialized representation of a file (base64) to a human-friendly
@@ -87,10 +93,11 @@ def deserialize(
         x: Base64 representation of image to deserialize into a string filepath
         save_dir: Path to directory to save the deserialized image to
         root_url: Ignored
+        hf_token: Ignored
         """
         if x is None or x == "":
             return None
-        file = processing_utils.decode_base64_to_file(x, dir=save_dir)
+        file = utils.decode_base64_to_file(x, dir=save_dir)
         return file.name
@@ -112,7 +119,7 @@ def serialize(
         filename = str(Path(load_dir) / x)
         return {
             "name": filename,
-            "data": processing_utils.encode_url_or_file_to_base64(filename),
+            "data": utils.encode_url_or_file_to_base64(filename),
             "orig_name": Path(filename).name,
             "is_file": False,
         }
@@ -122,6 +129,7 @@ def deserialize(
         x: str | Dict | None,
         save_dir: Path | str | None = None,
         root_url: str | None = None,
+        hf_token: str | None = None,
     ) -> str | None:
         """
         Convert from serialized representation of a file (base64) to a human-friendly
@@ -130,29 +138,28 @@
         x: Base64 representation of file to deserialize into a string filepath
         save_dir: Path to directory to save the deserialized file to
         root_url: If this component is loaded from an external Space, this is the URL of the Space
+        hf_token: If this component is loaded from an external private Space, this is the access token for the Space
         """
         if x is None:
             return None
         if isinstance(save_dir, Path):
             save_dir = str(save_dir)
         if isinstance(x, str):
-            file_name = processing_utils.decode_base64_to_file(x, dir=save_dir).name
+            file_name = utils.decode_base64_to_file(x, dir=save_dir).name
         elif isinstance(x, dict):
             if x.get("is_file", False):
                 if root_url is not None:
-                    file_name = processing_utils.download_tmp_copy_of_file(
+                    file_name = utils.download_tmp_copy_of_file(
                         root_url + "file=" + x["name"],
-                        access_token=Context.access_token,
+                        hf_token=hf_token,
                         dir=save_dir,
                     ).name
                 else:
-                    file_name = processing_utils.create_tmp_copy_of_file(
+                    file_name = utils.create_tmp_copy_of_file(
                         x["name"], dir=save_dir
                     ).name
             else:
-                file_name = processing_utils.decode_base64_to_file(
-                    x["data"], dir=save_dir
-                ).name
+                file_name = utils.decode_base64_to_file(x["data"], dir=save_dir).name
         else:
             raise ValueError(
                 f"A FileSerializable component can only deserialize a string or a dict, not a: {type(x)}"
@@ -175,13 +182,14 @@
         """
         if x is None or x == "":
             return None
-        return processing_utils.file_to_json(Path(load_dir) / x)
+        return utils.file_to_json(Path(load_dir) / x)
 
     def deserialize(
         self,
         x: str | Dict,
         save_dir: str | Path | None = None,
         root_url: str | None = None,
+        hf_token: str | None = None,
     ) -> str | None:
         """
         Convert from serialized representation (json string) to a human-friendly
@@ -190,7 +198,84 @@
         x: Json string
         save_dir: Path to save the deserialized json file to
         root_url: Ignored
+        hf_token: Ignored
         """
         if x is None:
             return None
-        return processing_utils.dict_or_str_to_json_file(x,
dir=save_dir).name + return utils.dict_or_str_to_json_file(x, dir=save_dir).name + + +class GallerySerializable(Serializable): + def serialize( + self, x: str | None, load_dir: str | Path = "" + ) -> List[List[str]] | None: + if x is None or x == "": + return None + files = [] + captions_file = Path(x) / "captions.json" + with captions_file.open("r") as captions_json: + captions = json.load(captions_json) + for file_name, caption in captions.items(): + img = FileSerializable().serialize(file_name) + files.append([img, caption]) + return files + + def deserialize( + self, + x: Any, + save_dir: str = "", + root_url: str | None = None, + hf_token: str | None = None, + ) -> None | str: + if x is None: + return None + gallery_path = Path(save_dir) / str(uuid.uuid4()) + gallery_path.mkdir(exist_ok=True, parents=True) + captions = {} + for img_data in x: + if isinstance(img_data, list) or isinstance(img_data, tuple): + img_data, caption = img_data + else: + caption = None + name = FileSerializable().deserialize( + img_data, gallery_path, root_url=root_url, hf_token=hf_token + ) + captions[name] = caption + captions_file = gallery_path / "captions.json" + with captions_file.open("w") as captions_json: + json.dump(captions, captions_json) + return os.path.abspath(gallery_path) + + +SERIALIZER_MAPPING = {cls.__name__: cls for cls in Serializable.__subclasses__()} + +COMPONENT_MAPPING = { + "textbox": SimpleSerializable, + "number": SimpleSerializable, + "slider": SimpleSerializable, + "checkbox": SimpleSerializable, + "checkboxgroup": SimpleSerializable, + "radio": SimpleSerializable, + "dropdown": SimpleSerializable, + "image": ImgSerializable, + "video": FileSerializable, + "audio": FileSerializable, + "file": FileSerializable, + "dataframe": JSONSerializable, + "timeseries": JSONSerializable, + "state": SimpleSerializable, + "button": SimpleSerializable, + "uploadbutton": FileSerializable, + "colorpicker": SimpleSerializable, + "label": JSONSerializable, + "highlightedtext": JSONSerializable, + "json": JSONSerializable, + "html": SimpleSerializable, + "gallery": GallerySerializable, + "chatbot": JSONSerializable, + "model3d": FileSerializable, + "plot": JSONSerializable, + "markdown": SimpleSerializable, + "dataset": SimpleSerializable, + "code": SimpleSerializable, +} diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py new file mode 100644 index 0000000000000..423009fcf6e2a --- /dev/null +++ b/client/python/gradio_client/utils.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import base64 +import json +import mimetypes +import os +import pkgutil +import shutil +import tempfile +from pathlib import Path +from typing import Any, Callable, Dict, Tuple + +import fsspec.asyn +import requests +from websockets.legacy.protocol import WebSocketCommonProtocol + +API_URL = "{}/api/predict/" +WS_URL = "{}/queue/join" + +__version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip() + + +class TooManyRequestsError(Exception): + """Raised when the API returns a 429 status code.""" + + pass + + +class QueueError(Exception): + """Raised when the queue is full or there is an issue adding a job to the queue.""" + + pass + + +class InvalidAPIEndpointError(Exception): + """Raised when the API endpoint is invalid.""" + + pass + + +######################## +# Network utils +######################## + + +def is_valid_url(possible_url: str) -> bool: + headers = {"User-Agent": "gradio (https://gradio.app/; team@gradio.app)"} + try: + head_request 
= requests.head(possible_url, headers=headers) + if head_request.status_code == 405: + return requests.get(possible_url, headers=headers).ok + return head_request.ok + except Exception: + return False + + +async def get_pred_from_ws( + websocket: WebSocketCommonProtocol, data: str, hash_data: str +) -> Dict[str, Any]: + completed = False + resp = {} + while not completed: + msg = await websocket.recv() + resp = json.loads(msg) + if resp["msg"] == "queue_full": + raise QueueError("Queue is full! Please try again.") + if resp["msg"] == "send_hash": + await websocket.send(hash_data) + elif resp["msg"] == "send_data": + await websocket.send(data) + completed = resp["msg"] == "process_completed" + return resp["output"] + + +######################## +# Data processing utils +######################## + + +def download_tmp_copy_of_file( + url_path: str, hf_token: str | None = None, dir: str | None = None +) -> tempfile._TemporaryFileWrapper: + if dir is not None: + os.makedirs(dir, exist_ok=True) + headers = {"Authorization": "Bearer " + hf_token} if hf_token else {} + prefix = Path(url_path).stem + suffix = Path(url_path).suffix + file_obj = tempfile.NamedTemporaryFile( + delete=False, + prefix=prefix, + suffix=suffix, + dir=dir, + ) + with requests.get(url_path, headers=headers, stream=True) as r: + with open(file_obj.name, "wb") as f: + shutil.copyfileobj(r.raw, f) + return file_obj + + +def create_tmp_copy_of_file( + file_path: str, dir: str | None = None +) -> tempfile._TemporaryFileWrapper: + if dir is not None: + os.makedirs(dir, exist_ok=True) + prefix = Path(file_path).stem + suffix = Path(file_path).suffix + file_obj = tempfile.NamedTemporaryFile( + delete=False, + prefix=prefix, + suffix=suffix, + dir=dir, + ) + shutil.copy2(file_path, file_obj.name) + return file_obj + + +def get_mimetype(filename: str) -> str | None: + mimetype = mimetypes.guess_type(filename)[0] + if mimetype is not None: + mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac") + return mimetype + + +def get_extension(encoding: str) -> str | None: + encoding = encoding.replace("audio/wav", "audio/x-wav") + type = mimetypes.guess_type(encoding)[0] + if type == "audio/flac": # flac is not supported by mimetypes + return "flac" + elif type is None: + return None + extension = mimetypes.guess_extension(type) + if extension is not None and extension.startswith("."): + extension = extension[1:] + return extension + + +def encode_file_to_base64(f): + with open(f, "rb") as file: + encoded_string = base64.b64encode(file.read()) + base64_str = str(encoded_string, "utf-8") + mimetype = get_mimetype(f) + return ( + "data:" + + (mimetype if mimetype is not None else "") + + ";base64," + + base64_str + ) + + +def encode_url_to_base64(url): + encoded_string = base64.b64encode(requests.get(url).content) + base64_str = str(encoded_string, "utf-8") + mimetype = get_mimetype(url) + return ( + "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str + ) + + +def encode_url_or_file_to_base64(path: str | Path): + path = str(path) + if is_valid_url(path): + return encode_url_to_base64(path) + else: + return encode_file_to_base64(path) + + +def decode_base64_to_binary(encoding) -> Tuple[bytes, str | None]: + extension = get_extension(encoding) + try: + data = encoding.split(",")[1] + except IndexError: + data = "" + return base64.b64decode(data), extension + + +def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> str: + """Strips invalid characters from a filename and ensures 
that the file_length is less than `max_bytes` bytes.""" + filename = "".join([char for char in filename if char.isalnum() or char in "._- "]) + filename_len = len(filename.encode()) + if filename_len > max_bytes: + while filename_len > max_bytes: + if len(filename) == 0: + break + filename = filename[:-1] + filename_len = len(filename.encode()) + return filename + + +def decode_base64_to_file(encoding, file_path=None, dir=None, prefix=None): + if dir is not None: + os.makedirs(dir, exist_ok=True) + data, extension = decode_base64_to_binary(encoding) + if file_path is not None and prefix is None: + filename = Path(file_path).name + prefix = filename + if "." in filename: + prefix = filename[0 : filename.index(".")] + extension = filename[filename.index(".") + 1 :] + + if prefix is not None: + prefix = strip_invalid_filename_characters(prefix) + + if extension is None: + file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir) + else: + file_obj = tempfile.NamedTemporaryFile( + delete=False, + prefix=prefix, + suffix="." + extension, + dir=dir, + ) + file_obj.write(data) + file_obj.flush() + return file_obj + + +def dict_or_str_to_json_file(jsn, dir=None): + if dir is not None: + os.makedirs(dir, exist_ok=True) + + file_obj = tempfile.NamedTemporaryFile( + delete=False, suffix=".json", dir=dir, mode="w+" + ) + if isinstance(jsn, str): + jsn = json.loads(jsn) + json.dump(jsn, file_obj) + file_obj.flush() + return file_obj + + +def file_to_json(file_path: str | Path) -> Dict: + with open(file_path) as f: + return json.load(f) + + +######################## +# Misc utils +######################## + + +def synchronize_async(func: Callable, *args, **kwargs) -> Any: + """ + Runs async functions in sync scopes. Can be used in any scope. + + Example: + if inspect.iscoroutinefunction(block_fn.fn): + predictions = utils.synchronize_async(block_fn.fn, *processed_input) + + Args: + func: + *args: + **kwargs: + """ + return fsspec.asyn.sync(fsspec.asyn.get_loop(), func, *args, **kwargs) # type: ignore diff --git a/client/python/gradio_client/version.txt b/client/python/gradio_client/version.txt new file mode 100644 index 0000000000000..05b19b1f76ec5 --- /dev/null +++ b/client/python/gradio_client/version.txt @@ -0,0 +1 @@ +0.0.4 \ No newline at end of file diff --git a/client/python/pyproject.toml b/client/python/pyproject.toml new file mode 100644 index 0000000000000..9c64c12f987a6 --- /dev/null +++ b/client/python/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["hatchling", "hatch-requirements-txt", "hatch-fancy-pypi-readme>=22.5.0"] +build-backend = "hatchling.build" + +[project] +name = "gradio_client" +dynamic = ["version", "dependencies", "readme"] +description = "Python library for easily interacting with trained machine learning models" +license = "Apache-2.0" +requires-python = ">=3.7" +authors = [ + { name = "Abubakar Abid", email = "team@gradio.app" }, + { name = "Ali Abid", email = "team@gradio.app" }, + { name = "Ali Abdalla", email = "team@gradio.app" }, + { name = "Dawood Khan", email = "team@gradio.app" }, + { name = "Ahsen Khaliq", email = "team@gradio.app" }, + { name = "Pete Allen", email = "team@gradio.app" }, + { name = "Freddy Boulton", email = "team@gradio.app" }, +] +keywords = ["machine learning", "client", "API"] + +[project.urls] +Homepage = "https://github.com/gradio-app/gradio" + +[tool.hatch.version] +path = "gradio_client/version.txt" +pattern = "(?P.+)" + +[tool.hatch.metadata.hooks.requirements_txt] +filename = "requirements.txt" + 
+[tool.hatch.metadata.hooks.fancy-pypi-readme] +content-type = "text/markdown" +fragments = [ + { path = "README.md" }, +] + +[tool.hatch.build.targets.sdist] +include = [ + "/gradio_client", + "/README.md", + "/requirements.txt", +] diff --git a/client/python/requirements.txt b/client/python/requirements.txt new file mode 100644 index 0000000000000..83587c7dea8af --- /dev/null +++ b/client/python/requirements.txt @@ -0,0 +1,5 @@ +requests +websockets +packaging +fsspec +huggingface_hub>=0.13.0 diff --git a/client/python/scripts/ci.sh b/client/python/scripts/ci.sh new file mode 100644 index 0000000000000..0d97a227c6e44 --- /dev/null +++ b/client/python/scripts/ci.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +cd "$(dirname ${0})/.." + +echo "Linting..." +python -m black --check test gradio_client +python -m isort --profile=black --check-only test gradio_client +python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403,F541 test gradio_client --exclude gradio_client/__init__.py + +echo "Testing..." +python -m pytest test diff --git a/client/python/scripts/format.sh b/client/python/scripts/format.sh new file mode 100644 index 0000000000000..e37f54f8a8233 --- /dev/null +++ b/client/python/scripts/format.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +cd "$(dirname ${0})/.." + +echo "Formatting the backend... Our style follows the Black code style." +python -m black test gradio_client +python -m isort --profile=black test gradio_client +python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403 test gradio_client --exclude gradio_client/__init__.py diff --git a/client/python/scripts/upload_pypi.sh b/client/python/scripts/upload_pypi.sh new file mode 100644 index 0000000000000..02b7ea57ca147 --- /dev/null +++ b/client/python/scripts/upload_pypi.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +cd "$(dirname ${0})/.." 
+ +python -m pip install build twine +rm -rf dist/* +python -m build +twine upload dist/* \ No newline at end of file diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py new file mode 100644 index 0000000000000..cdc028cfcaee1 --- /dev/null +++ b/client/python/test/test_client.py @@ -0,0 +1,22 @@ +import json + +import pytest + +from gradio_client import Client + + +class TestPredictionsFromSpaces: + @pytest.mark.flaky + def test_numerical_to_label_space(self): + client = Client(space="abidlabs/titanic-survival") + output = client.predict("male", 77, 10).result() + assert json.load(open(output))["label"] == "Perishes" + + @pytest.mark.flaky + def test_private_space(self): + hf_token = "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" # Intentionally revealing this key for testing purposes + client = Client( + space="gradio-tests/not-actually-private-space", hf_token=hf_token + ) + output = client.predict("abc").result() + assert output == "abc" diff --git a/client/python/test/test_utils.py b/client/python/test/test_utils.py new file mode 100644 index 0000000000000..5dca48b0f3ed9 --- /dev/null +++ b/client/python/test/test_utils.py @@ -0,0 +1,100 @@ +import json +import tempfile +from copy import deepcopy +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from gradio import media_data + +from gradio_client import utils + + +def test_encode_url_or_file_to_base64(): + output_base64 = utils.encode_url_or_file_to_base64( + Path(__file__).parent / "../../../gradio/test_data/test_image.png" + ) + assert output_base64 == deepcopy(media_data.BASE64_IMAGE) + + +def test_encode_file_to_base64(): + output_base64 = utils.encode_file_to_base64( + Path(__file__).parent / "../../../gradio/test_data/test_image.png" + ) + assert output_base64 == deepcopy(media_data.BASE64_IMAGE) + + +@pytest.mark.flaky +def test_encode_url_to_base64(): + output_base64 = utils.encode_url_to_base64( + "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png" + ) + assert output_base64 == deepcopy(media_data.BASE64_IMAGE) + + +def test_decode_base64_to_binary(): + binary = utils.decode_base64_to_binary(deepcopy(media_data.BASE64_IMAGE)) + assert deepcopy(media_data.BINARY_IMAGE) == binary + + +def test_decode_base64_to_file(): + temp_file = utils.decode_base64_to_file(deepcopy(media_data.BASE64_IMAGE)) + assert isinstance(temp_file, tempfile._TemporaryFileWrapper) + + +def test_download_private_file(): + url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg" + hf_token = "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" # Intentionally revealing this key for testing purposes + file = utils.download_tmp_copy_of_file(url_path=url_path, hf_token=hf_token) + assert file.name.endswith(".jpg") + + +@pytest.mark.parametrize( + "orig_filename, new_filename", + [ + ("abc", "abc"), + ("$$AAabc&3", "AAabc3"), + ("$$AAabc&3", "AAabc3"), + ("$$AAa..b-c&3_", "AAa..b-c3_"), + ("$$AAa..b-c&3_", "AAa..b-c3_"), + ( + "ゆかりです。私、こんなかわいい服は初めて着ました…。なんだかうれしくって、楽しいです。歌いたくなる気分って、初めてです。これがアイドルってことなのかもしれませんね", + "ゆかりです私こんなかわいい服は初めて着ましたなんだかうれしくって楽しいです歌いたくなる気分って初めてですこれがアイドルってことなの", + ), + ], +) +def test_strip_invalid_filename_characters(orig_filename, new_filename): + assert utils.strip_invalid_filename_characters(orig_filename) == new_filename + + +class AsyncMock(MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + +@pytest.mark.asyncio +async def test_get_pred_from_ws(): + 
mock_ws = AsyncMock(name="ws") + messages = [ + json.dumps({"msg": "estimation"}), + json.dumps({"msg": "send_data"}), + json.dumps({"msg": "process_generating"}), + json.dumps({"msg": "process_completed", "output": {"data": ["result!"]}}), + ] + mock_ws.recv.side_effect = messages + data = json.dumps({"data": ["foo"], "fn_index": "foo"}) + hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"}) + output = await utils.get_pred_from_ws(mock_ws, data, hash_data) + assert output == {"data": ["result!"]} + mock_ws.send.assert_called_once_with(data) + + +@pytest.mark.asyncio +async def test_get_pred_from_ws_raises_if_queue_full(): + mock_ws = AsyncMock(name="ws") + messages = [json.dumps({"msg": "queue_full"})] + mock_ws.recv.side_effect = messages + data = json.dumps({"data": ["foo"], "fn_index": "foo"}) + hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"}) + with pytest.raises(utils.QueueError, match="Queue is full!"): + await utils.get_pred_from_ws(mock_ws, data, hash_data) diff --git a/demo/autocomplete/run.ipynb b/demo/autocomplete/run.ipynb index d07b7fce83c96..51e547bf6d2a1 100644 --- a/demo/autocomplete/run.ipynb +++ b/demo/autocomplete/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: autocomplete\n", "### This text generation demo works like autocomplete. There's only one textbox and it's used for both the input and the output. The demo loads the model as an interface, and uses that interface as an API. It then uses blocks to create the UI. All of this is done in less than 10 lines of code.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "# save your HF API token from https:/hf.co/settings/tokens as an env variable to avoid rate limiting\n", "auth_token = os.getenv(\"auth_token\")\n", "\n", "# load a model from https://hf.co/models as an interface, then use it as an api \n", "# you can remove the api_key parameter if you don't care about rate limiting. \n", "api = gr.Interface.load(\"huggingface/EleutherAI/gpt-j-6B\", api_key=auth_token)\n", "\n", "def complete_with_gpt(text):\n", " return text[:-50] + api(text[-50:])\n", "\n", "with gr.Blocks() as demo:\n", " textbox = gr.Textbox(placeholder=\"Type here...\", lines=4)\n", " btn = gr.Button(\"Autocomplete\")\n", " \n", " # define what will run when the button is clicked, here the textbox is used as both an input and an output\n", " btn.click(fn=complete_with_gpt, inputs=textbox, outputs=textbox, queue=False)\n", "\n", "demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: autocomplete\n", "### This text generation demo works like autocomplete. There's only one textbox and it's used for both the input and the output. The demo loads the model as an interface, and uses that interface as an API. It then uses blocks to create the UI. 
All of this is done in less than 10 lines of code.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "# save your HF API token from https:/hf.co/settings/tokens as an env variable to avoid rate limiting\n", "auth_token = os.getenv(\"auth_token\")\n", "\n", "# load a model from https://hf.co/models as an interface, then use it as an api \n", "# you can remove the api_key parameter if you don't care about rate limiting. \n", "api = gr.load(\"huggingface/EleutherAI/gpt-j-6B\", api_key=auth_token)\n", "\n", "def complete_with_gpt(text):\n", " return text[:-50] + api(text[-50:])\n", "\n", "with gr.Blocks() as demo:\n", " textbox = gr.Textbox(placeholder=\"Type here...\", lines=4)\n", " btn = gr.Button(\"Autocomplete\")\n", " \n", " # define what will run when the button is clicked, here the textbox is used as both an input and an output\n", " btn.click(fn=complete_with_gpt, inputs=textbox, outputs=textbox, queue=False)\n", "\n", "demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/autocomplete/run.py b/demo/autocomplete/run.py index fab6b8cf46f04..172ac9cb99b34 100644 --- a/demo/autocomplete/run.py +++ b/demo/autocomplete/run.py @@ -6,7 +6,7 @@ # load a model from https://hf.co/models as an interface, then use it as an api # you can remove the api_key parameter if you don't care about rate limiting. -api = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B", api_key=auth_token) +api = gr.load("huggingface/EleutherAI/gpt-j-6B", api_key=auth_token) def complete_with_gpt(text): return text[:-50] + api(text[-50:]) diff --git a/demo/automatic-speech-recognition/run.ipynb b/demo/automatic-speech-recognition/run.ipynb index 2758720871cef..b41b1f186c104 100644 --- a/demo/automatic-speech-recognition/run.ipynb +++ b/demo/automatic-speech-recognition/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: automatic-speech-recognition\n", "### Automatic speech recognition English. Record from your microphone and the app will transcribe the audio.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "# save your HF API token from https:/hf.co/settings/tokens as an env variable to avoid rate limiting\n", "auth_token = os.getenv(\"auth_token\")\n", "\n", "# automatically load the interface from a HF model \n", "# you can remove the api_key parameter if you don't care about rate limiting. 
\n", "demo = gr.Interface.load(\n", " \"huggingface/facebook/wav2vec2-base-960h\",\n", " title=\"Speech-to-text\",\n", " inputs=\"mic\",\n", " description=\"Let me try to guess what you're saying!\",\n", " api_key=auth_token\n", ")\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: automatic-speech-recognition\n", "### Automatic speech recognition English. Record from your microphone and the app will transcribe the audio.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "# save your HF API token from https:/hf.co/settings/tokens as an env variable to avoid rate limiting\n", "auth_token = os.getenv(\"auth_token\")\n", "\n", "# automatically load the interface from a HF model \n", "# you can remove the api_key parameter if you don't care about rate limiting. \n", "demo = gr.load(\n", " \"huggingface/facebook/wav2vec2-base-960h\",\n", " title=\"Speech-to-text\",\n", " inputs=\"mic\",\n", " description=\"Let me try to guess what you're saying!\",\n", " api_key=auth_token\n", ")\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/automatic-speech-recognition/run.py b/demo/automatic-speech-recognition/run.py index b18231f4c6fca..299e458781a5c 100644 --- a/demo/automatic-speech-recognition/run.py +++ b/demo/automatic-speech-recognition/run.py @@ -6,7 +6,7 @@ # automatically load the interface from a HF model # you can remove the api_key parameter if you don't care about rate limiting. 
-demo = gr.Interface.load( +demo = gr.load( "huggingface/facebook/wav2vec2-base-960h", title="Speech-to-text", inputs="mic", diff --git a/demo/blocks_gpt/run.ipynb b/demo/blocks_gpt/run.ipynb index 032ab40b437e4..c5e7694496e1b 100644 --- a/demo/blocks_gpt/run.ipynb +++ b/demo/blocks_gpt/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: blocks_gpt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "api = gr.Interface.load(\"huggingface/EleutherAI/gpt-j-6B\")\n", "\n", "def complete_with_gpt(text):\n", " # Use the last 50 characters of the text as context\n", " return text[:-50] + api(text[-50:])\n", "\n", "with gr.Blocks() as demo:\n", " textbox = gr.Textbox(placeholder=\"Type here and press enter...\", lines=4)\n", " btn = gr.Button(\"Generate\")\n", " \n", " btn.click(complete_with_gpt, textbox, textbox)\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: blocks_gpt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "api = gr.load(\"huggingface/EleutherAI/gpt-j-6B\")\n", "\n", "def complete_with_gpt(text):\n", " # Use the last 50 characters of the text as context\n", " return text[:-50] + api(text[-50:])\n", "\n", "with gr.Blocks() as demo:\n", " textbox = gr.Textbox(placeholder=\"Type here and press enter...\", lines=4)\n", " btn = gr.Button(\"Generate\")\n", " \n", " btn.click(complete_with_gpt, textbox, textbox)\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/blocks_gpt/run.py b/demo/blocks_gpt/run.py index 8799cde02d2e4..3f360ce3abe8b 100644 --- a/demo/blocks_gpt/run.py +++ b/demo/blocks_gpt/run.py @@ -1,6 +1,6 @@ import gradio as gr -api = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B") +api = gr.load("huggingface/EleutherAI/gpt-j-6B") def complete_with_gpt(text): # Use the last 50 characters of the text as context diff --git a/demo/gpt_j/run.ipynb b/demo/gpt_j/run.ipynb index 06d255f9f390c..f96d85442027f 100644 --- a/demo/gpt_j/run.ipynb +++ b/demo/gpt_j/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: gpt_j"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "title = \"GPT-J-6B\"\n", "\n", "examples = [\n", " [\"The tower is 324 metres (1,063 ft) tall,\"],\n", " [\"The Moon's orbit around Earth has\"],\n", " 
[\"The smooth Borealis basin in the Northern Hemisphere covers 40%\"],\n", "]\n", "\n", "demo = gr.Interface.load(\n", " \"huggingface/EleutherAI/gpt-j-6B\",\n", " inputs=gr.Textbox(lines=5, max_lines=6, label=\"Input Text\"),\n", " title=title,\n", " examples=examples,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: gpt_j"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "title = \"GPT-J-6B\"\n", "\n", "examples = [\n", " [\"The tower is 324 metres (1,063 ft) tall,\"],\n", " [\"The Moon's orbit around Earth has\"],\n", " [\"The smooth Borealis basin in the Northern Hemisphere covers 40%\"],\n", "]\n", "\n", "demo = gr.load(\n", " \"huggingface/EleutherAI/gpt-j-6B\",\n", " inputs=gr.Textbox(lines=5, max_lines=6, label=\"Input Text\"),\n", " title=title,\n", " examples=examples,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/gpt_j/run.py b/demo/gpt_j/run.py index 18dbd4b107f74..29c6fa5f4206a 100644 --- a/demo/gpt_j/run.py +++ b/demo/gpt_j/run.py @@ -8,7 +8,7 @@ ["The smooth Borealis basin in the Northern Hemisphere covers 40%"], ] -demo = gr.Interface.load( +demo = gr.load( "huggingface/EleutherAI/gpt-j-6B", inputs=gr.Textbox(lines=5, max_lines=6, label="Input Text"), title=title, diff --git a/demo/gpt_j_unified/run.ipynb b/demo/gpt_j_unified/run.ipynb index d840f1d52ec7a..8732b5c942906 100644 --- a/demo/gpt_j_unified/run.ipynb +++ b/demo/gpt_j_unified/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: gpt_j_unified"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "component = gr.Textbox(lines=5, label=\"Text\")\n", "api = gr.Interface.load(\"huggingface/EleutherAI/gpt-j-6B\")\n", "\n", "demo = gr.Interface(\n", " fn=lambda x: x[:-50] + api(x[-50:]),\n", " inputs=component,\n", " outputs=component,\n", " title=\"GPT-J-6B\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: gpt_j_unified"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "component = gr.Textbox(lines=5, label=\"Text\")\n", "api = gr.load(\"huggingface/EleutherAI/gpt-j-6B\")\n", "\n", "demo = 
gr.Interface(\n", " fn=lambda x: x[:-50] + api(x[-50:]),\n", " inputs=component,\n", " outputs=component,\n", " title=\"GPT-J-6B\",\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/gpt_j_unified/run.py b/demo/gpt_j_unified/run.py index b561f89509172..30c13507c6cee 100644 --- a/demo/gpt_j_unified/run.py +++ b/demo/gpt_j_unified/run.py @@ -1,7 +1,7 @@ import gradio as gr component = gr.Textbox(lines=5, label="Text") -api = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B") +api = gr.load("huggingface/EleutherAI/gpt-j-6B") demo = gr.Interface( fn=lambda x: x[:-50] + api(x[-50:]), diff --git a/demo/image_classifier_interface_load/run.ipynb b/demo/image_classifier_interface_load/run.ipynb index ffdb026b9d844..122dfa407fde1 100644 --- a/demo/image_classifier_interface_load/run.ipynb +++ b/demo/image_classifier_interface_load/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: image_classifier_interface_load"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_classifier_interface_load/cheetah1.jpeg\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_classifier_interface_load/cheetah1.jpg\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_classifier_interface_load/lion.jpg"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import pathlib\n", "\n", "current_dir = pathlib.Path(__file__).parent\n", "\n", "images = [str(current_dir / \"cheetah1.jpeg\"), str(current_dir / \"cheetah1.jpg\"), str(current_dir / \"lion.jpg\")]\n", "\n", "\n", "img_classifier = gr.Interface.load(\n", " \"models/google/vit-base-patch16-224\", examples=images, cache_examples=False\n", ")\n", "\n", "\n", "def func(img, text):\n", " return img_classifier(img), text\n", "\n", "\n", "using_img_classifier_as_function = gr.Interface(\n", " func,\n", " [gr.Image(type=\"filepath\"), \"text\"],\n", " [\"label\", \"text\"],\n", " examples=[\n", " [str(current_dir / \"cheetah1.jpeg\"), None],\n", " [str(current_dir / \"cheetah1.jpg\"), \"cheetah\"],\n", " [str(current_dir / \"lion.jpg\"), \"lion\"],\n", " ],\n", " cache_examples=False,\n", ")\n", "demo = gr.TabbedInterface([using_img_classifier_as_function, img_classifier])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: image_classifier_interface_load"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q 
https://github.com/gradio-app/gradio/raw/main/demo/image_classifier_interface_load/cheetah1.jpeg\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_classifier_interface_load/cheetah1.jpg\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/image_classifier_interface_load/lion.jpg"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import pathlib\n", "\n", "current_dir = pathlib.Path(__file__).parent\n", "\n", "images = [str(current_dir / \"cheetah1.jpeg\"), str(current_dir / \"cheetah1.jpg\"), str(current_dir / \"lion.jpg\")]\n", "\n", "\n", "img_classifier = gr.load(\n", " \"models/google/vit-base-patch16-224\", examples=images, cache_examples=False\n", ")\n", "\n", "\n", "def func(img, text):\n", " return img_classifier(img), text\n", "\n", "\n", "using_img_classifier_as_function = gr.Interface(\n", " func,\n", " [gr.Image(type=\"filepath\"), \"text\"],\n", " [\"label\", \"text\"],\n", " examples=[\n", " [str(current_dir / \"cheetah1.jpeg\"), None],\n", " [str(current_dir / \"cheetah1.jpg\"), \"cheetah\"],\n", " [str(current_dir / \"lion.jpg\"), \"lion\"],\n", " ],\n", " cache_examples=False,\n", ")\n", "demo = gr.TabbedInterface([using_img_classifier_as_function, img_classifier])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/image_classifier_interface_load/run.py b/demo/image_classifier_interface_load/run.py index b2e6b515be0da..86d8dc8bc8def 100644 --- a/demo/image_classifier_interface_load/run.py +++ b/demo/image_classifier_interface_load/run.py @@ -6,7 +6,7 @@ images = [str(current_dir / "cheetah1.jpeg"), str(current_dir / "cheetah1.jpg"), str(current_dir / "lion.jpg")] -img_classifier = gr.Interface.load( +img_classifier = gr.load( "models/google/vit-base-patch16-224", examples=images, cache_examples=False ) diff --git a/demo/interface_parallel_load/run.ipynb b/demo/interface_parallel_load/run.ipynb index e29e99ca2d23e..59f81075bbfde 100644 --- a/demo/interface_parallel_load/run.ipynb +++ b/demo/interface_parallel_load/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: interface_parallel_load"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "generator1 = gr.Interface.load(\"huggingface/gpt2\")\n", "generator2 = gr.Interface.load(\"huggingface/EleutherAI/gpt-neo-2.7B\")\n", "generator3 = gr.Interface.load(\"huggingface/EleutherAI/gpt-j-6B\")\n", "\n", "demo = gr.Parallel(generator1, generator2, generator3)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: interface_parallel_load"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 
288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "generator1 = gr.load(\"huggingface/gpt2\")\n", "generator2 = gr.load(\"huggingface/EleutherAI/gpt-neo-2.7B\")\n", "generator3 = gr.load(\"huggingface/EleutherAI/gpt-j-6B\")\n", "\n", "demo = gr.Parallel(generator1, generator2, generator3)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/interface_parallel_load/run.py b/demo/interface_parallel_load/run.py index cae6397329484..2267f4c514ef1 100644 --- a/demo/interface_parallel_load/run.py +++ b/demo/interface_parallel_load/run.py @@ -1,8 +1,8 @@ import gradio as gr -generator1 = gr.Interface.load("huggingface/gpt2") -generator2 = gr.Interface.load("huggingface/EleutherAI/gpt-neo-2.7B") -generator3 = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B") +generator1 = gr.load("huggingface/gpt2") +generator2 = gr.load("huggingface/EleutherAI/gpt-neo-2.7B") +generator3 = gr.load("huggingface/EleutherAI/gpt-j-6B") demo = gr.Parallel(generator1, generator2, generator3) diff --git a/demo/interface_series_load/run.ipynb b/demo/interface_series_load/run.ipynb index a2fe9eb44f713..601225736450d 100644 --- a/demo/interface_series_load/run.ipynb +++ b/demo/interface_series_load/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: interface_series_load"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "generator = gr.Interface.load(\"huggingface/gpt2\")\n", "translator = gr.Interface.load(\"huggingface/t5-small\")\n", "\n", "demo = gr.Series(generator, translator, description=\"This demo combines two Spaces: a text generator (`huggingface/gpt2`) and a text translator (`huggingface/t5-small`). The first Space takes a prompt as input and generates a text. The second Space takes the generated text as input and translates it into another language.\")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: interface_series_load"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "generator = gr.load(\"huggingface/gpt2\")\n", "translator = gr.load(\"huggingface/t5-small\")\n", "\n", "demo = gr.Series(generator, translator, description=\"This demo combines two Spaces: a text generator (`huggingface/gpt2`) and a text translator (`huggingface/t5-small`). The first Space takes a prompt as input and generates a text. 
The second Space takes the generated text as input and translates it into another language.\")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/interface_series_load/run.py b/demo/interface_series_load/run.py index a3a0981d55f05..13703ccadbf81 100644 --- a/demo/interface_series_load/run.py +++ b/demo/interface_series_load/run.py @@ -1,7 +1,7 @@ import gradio as gr -generator = gr.Interface.load("huggingface/gpt2") -translator = gr.Interface.load("huggingface/t5-small") +generator = gr.load("huggingface/gpt2") +translator = gr.load("huggingface/t5-small") demo = gr.Series(generator, translator, description="This demo combines two Spaces: a text generator (`huggingface/gpt2`) and a text translator (`huggingface/t5-small`). The first Space takes a prompt as input and generates a text. The second Space takes the generated text as input and translates it into another language.") diff --git a/demo/question-answering/run.ipynb b/demo/question-answering/run.ipynb index 61c3732b70261..d634d3da8ae53 100644 --- a/demo/question-answering/run.ipynb +++ b/demo/question-answering/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: question-answering"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "context = \"The Amazon rainforest, also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. 
The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.\"\n", "question = \"Which continent is the Amazon rainforest in?\"\n", "gr.Interface.load(\n", " \"huggingface/deepset/roberta-base-squad2\",\n", " inputs=[gr.inputs.Textbox(lines=7, default=context, label=\"Context Paragraph\"), gr.inputs.Textbox(lines=2, default=question, label=\"Question\")],\n", " outputs=[gr.outputs.Textbox(label=\"Answer\"), gr.outputs.Textbox(label=\"Score\")],\n", " title=None).launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: question-answering"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "context = \"The Amazon rainforest, also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species.\"\n", "question = \"Which continent is the Amazon rainforest in?\"\n", "gr.load(\n", " \"huggingface/deepset/roberta-base-squad2\",\n", " inputs=[gr.inputs.Textbox(lines=7, default=context, label=\"Context Paragraph\"), gr.inputs.Textbox(lines=2, default=question, label=\"Question\")],\n", " outputs=[gr.outputs.Textbox(label=\"Answer\"), gr.outputs.Textbox(label=\"Score\")],\n", " title=None).launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/question-answering/run.py b/demo/question-answering/run.py index 1369aa37e79e2..5674c329c60d2 100644 --- a/demo/question-answering/run.py +++ b/demo/question-answering/run.py @@ -1,7 +1,7 @@ import gradio as gr context = "The Amazon rainforest, also known in English as Amazonia or the Amazon Jungle, is a moist broadleaf forest that covers most of the Amazon basin of South America. This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres (2,100,000 sq mi) are covered by the rainforest. This region includes territory belonging to nine nations. The majority of the forest is contained within Brazil, with 60% of the rainforest, followed by Peru with 13%, Colombia with 10%, and with minor amounts in Venezuela, Ecuador, Bolivia, Guyana, Suriname and French Guiana. 
The Amazon represents over half of the planet's remaining rainforests, and comprises the largest and most biodiverse tract of tropical rainforest in the world, with an estimated 390 billion individual trees divided into 16,000 species." question = "Which continent is the Amazon rainforest in?" -gr.Interface.load( +gr.load( "huggingface/deepset/roberta-base-squad2", inputs=[gr.inputs.Textbox(lines=7, default=context, label="Context Paragraph"), gr.inputs.Textbox(lines=2, default=question, label="Question")], outputs=[gr.outputs.Textbox(label="Answer"), gr.outputs.Textbox(label="Score")], diff --git a/demo/stt_or_tts/run.ipynb b/demo/stt_or_tts/run.ipynb index 028c02cbfabe6..0e7da7abeb49b 100644 --- a/demo/stt_or_tts/run.ipynb +++ b/demo/stt_or_tts/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stt_or_tts"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "title = \"GPT-J-6B\"\n", "\n", "tts_examples = [\n", " \"I love learning machine learning\",\n", " \"How do you do?\",\n", "]\n", "\n", "tts_demo = gr.Interface.load(\n", " \"huggingface/facebook/fastspeech2-en-ljspeech\",\n", " title=None,\n", " examples=tts_examples,\n", " description=\"Give me something to say!\",\n", ")\n", "\n", "stt_demo = gr.Interface.load(\n", " \"huggingface/facebook/wav2vec2-base-960h\",\n", " title=None,\n", " inputs=\"mic\",\n", " description=\"Let me try to guess what you're saying!\",\n", ")\n", "\n", "demo = gr.TabbedInterface([tts_demo, stt_demo], [\"Text-to-speech\", \"Speech-to-text\"])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stt_or_tts"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "title = \"GPT-J-6B\"\n", "\n", "tts_examples = [\n", " \"I love learning machine learning\",\n", " \"How do you do?\",\n", "]\n", "\n", "tts_demo = gr.load(\n", " \"huggingface/facebook/fastspeech2-en-ljspeech\",\n", " title=None,\n", " examples=tts_examples,\n", " description=\"Give me something to say!\",\n", ")\n", "\n", "stt_demo = gr.load(\n", " \"huggingface/facebook/wav2vec2-base-960h\",\n", " title=None,\n", " inputs=\"mic\",\n", " description=\"Let me try to guess what you're saying!\",\n", ")\n", "\n", "demo = gr.TabbedInterface([tts_demo, stt_demo], [\"Text-to-speech\", \"Speech-to-text\"])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/stt_or_tts/run.py b/demo/stt_or_tts/run.py index 55f86ce01829c..8235d24ed33e7 100644 --- a/demo/stt_or_tts/run.py +++ b/demo/stt_or_tts/run.py @@ -7,14 +7,14 @@ "How do you do?", ] -tts_demo = gr.Interface.load( +tts_demo = gr.load( 
"huggingface/facebook/fastspeech2-en-ljspeech", title=None, examples=tts_examples, description="Give me something to say!", ) -stt_demo = gr.Interface.load( +stt_demo = gr.load( "huggingface/facebook/wav2vec2-base-960h", title=None, inputs="mic", diff --git a/gradio/__init__.py b/gradio/__init__.py index da895876cb9e2..cf7881f5a14d7 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -54,6 +54,7 @@ ) from gradio.events import SelectData from gradio.exceptions import Error +from gradio.external import load from gradio.flagging import ( CSVLogger, FlaggingCallback, diff --git a/gradio/blocks.py b/gradio/blocks.py index a350305a94385..03018f703eb02 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -17,6 +17,7 @@ import anyio import requests from anyio import CapacityLimiter +from gradio_client import utils as client_utils from typing_extensions import Literal from gradio import components, external, networking, queueing, routes, strings, utils @@ -634,7 +635,7 @@ def iterate_over_children(children_list): # add the event triggers for dependency, fn in zip(config["dependencies"], fns): # We used to add a "fake_event" to the config to cache examples - # without removing it. This was causing bugs in calling gr.Interface.load + # without removing it. This was causing bugs in calling gr.load # We fixed the issue by removing "fake_event" from the config in examples.py # but we still need to skip these events when loading the config to support # older demos @@ -808,7 +809,7 @@ def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None): if batch: processed_inputs = [[inp] for inp in processed_inputs] - outputs = utils.synchronize_async( + outputs = client_utils.synchronize_async( self.process_api, fn_index=fn_index, inputs=processed_inputs, @@ -940,7 +941,9 @@ def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]: assert isinstance( block, components.IOComponent ), f"{block.__class__} Component with id {output_id} not a valid output component." - deserialized = block.deserialize(outputs[o], root_url=block.root_url) + deserialized = block.deserialize( + outputs[o], root_url=block.root_url, hf_token=Context.hf_token + ) predictions.append(deserialized) return predictions @@ -1128,15 +1131,16 @@ def getLayout(block): config["layout"] = getLayout(self) for _id, block in self.blocks.items(): - config["components"].append( - { - "id": _id, - "type": (block.get_block_name()), - "props": utils.delete_none(block.get_config()) - if hasattr(block, "get_config") - else {}, - } - ) + props = block.get_config() if hasattr(block, "get_config") else {} + block_config = { + "id": _id, + "type": block.get_block_name(), + "props": utils.delete_none(props), + } + serializer = utils.get_serializer_name(block) + if serializer: + block_config["serializer"] = serializer + config["components"].append(block_config) config["dependencies"] = self.dependencies return config @@ -1188,7 +1192,7 @@ def load( method, the two of which, confusingly, do two completely different things. - Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Equivalent to gradio.Interface.load() + Class method: loads a demo from a Hugging Face Spaces repo and creates it locally and returns a block instance. Warning: this method will be deprecated. Use the equivalent `gradio.load()` instead. Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below. 
@@ -1221,11 +1225,14 @@ def get_time(): """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if isinstance(self_or_cls, type): + warnings.warn("gr.Blocks.load() will be deprecated. Use gr.load() instead.") if name is None: raise ValueError( "Blocks.load() requires passing parameters as keyword arguments" ) - return external.load_blocks_from_repo(name, src, api_key, alias, **kwargs) + return external.load( + name=name, src=src, hf_token=api_key, alias=alias, **kwargs + ) else: return self_or_cls.set_event_trigger( event_name="load", diff --git a/gradio/components.py b/gradio/components.py index 88d26704659b5..fb5b87db9fd93 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -14,7 +14,6 @@ import shutil import tempfile import urllib.request -import uuid import warnings from copy import deepcopy from enum import Enum @@ -32,6 +31,15 @@ import requests from fastapi import UploadFile from ffmpy import FFmpeg +from gradio_client import utils as client_utils +from gradio_client.serializing import ( + FileSerializable, + GallerySerializable, + ImgSerializable, + JSONSerializable, + Serializable, + SimpleSerializable, +) from pandas.api.types import is_numeric_dtype from PIL import Image as _Image # using _ to minimize namespace pollution from typing_extensions import Literal @@ -56,13 +64,6 @@ ) from gradio.interpretation import NeighborInterpretable, TokenInterpretable from gradio.layouts import Column, Form, Row -from gradio.serializing import ( - FileSerializable, - ImgSerializable, - JSONSerializable, - Serializable, - SimpleSerializable, -) if TYPE_CHECKING: from typing import TypedDict @@ -81,7 +82,7 @@ class _Keywords(Enum): FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) -class Component(Block): +class Component(Block, Serializable): """ A base class for defining the methods that all gradio components should have. """ @@ -163,7 +164,7 @@ def style( return self -class IOComponent(Component, Serializable): +class IOComponent(Component): """ A base class for defining methods that all input/output components should have. 
""" @@ -242,7 +243,7 @@ def make_temp_copy_if_needed(self, file_path: str) -> str: temp_dir.mkdir(exist_ok=True, parents=True) f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = utils.strip_invalid_filename_characters(Path(file_path).name) + f.name = client_utils.strip_invalid_filename_characters(Path(file_path).name) full_temp_file_path = str(utils.abspath(temp_dir / f.name)) if not Path(full_temp_file_path).exists(): @@ -261,7 +262,9 @@ async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str: if file.filename: file_name = Path(file.filename).name - output_file_obj.name = utils.strip_invalid_filename_characters(file_name) + output_file_obj.name = client_utils.strip_invalid_filename_characters( + file_name + ) full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name)) @@ -282,7 +285,7 @@ def download_temp_copy_if_needed(self, url: str) -> str: temp_dir.mkdir(exist_ok=True, parents=True) f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = utils.strip_invalid_filename_characters(Path(url).name) + f.name = client_utils.strip_invalid_filename_characters(Path(url).name) full_temp_file_path = str(utils.abspath(temp_dir / f.name)) if not Path(full_temp_file_path).exists(): @@ -302,19 +305,19 @@ def base64_to_temp_file_if_needed( temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir.mkdir(exist_ok=True, parents=True) - guess_extension = processing_utils.get_extension(base64_encoding) + guess_extension = client_utils.get_extension(base64_encoding) if file_name: - file_name = utils.strip_invalid_filename_characters(file_name) + file_name = client_utils.strip_invalid_filename_characters(file_name) elif guess_extension: file_name = "file." + guess_extension else: file_name = "file" f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = file_name - full_temp_file_path = str(utils.abspath(temp_dir / f.name)) + f.name = file_name # type: ignore + full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore if not Path(full_temp_file_path).exists(): - data, _ = processing_utils.decode_base64_to_binary(base64_encoding) + data, _ = client_utils.decode_base64_to_binary(base64_encoding) with open(full_temp_file_path, "wb") as fb: fb.write(data) @@ -1754,7 +1757,7 @@ def postprocess( elif isinstance(y, _Image.Image): return processing_utils.encode_pil_to_base64(y) elif isinstance(y, (str, Path)): - return processing_utils.encode_url_or_file_to_base64(y) + return client_utils.encode_url_or_file_to_base64(y) else: raise ValueError("Cannot process this value as an Image") @@ -2318,7 +2321,7 @@ def tokenize(self, x): leave_one_out_data[start:stop] = 0 file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name) - out_data = processing_utils.encode_file_to_base64(file.name) + out_data = client_utils.encode_file_to_base64(file.name) leave_one_out_sets.append(out_data) file.close() Path(file.name).unlink() @@ -2329,7 +2332,7 @@ def tokenize(self, x): token[stop:] = 0 file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") processing_utils.audio_to_file(sample_rate, token, file.name) - token_data = processing_utils.encode_file_to_base64(file.name) + token_data = client_utils.encode_file_to_base64(file.name) file.close() Path(file.name).unlink() @@ -2360,7 +2363,7 @@ def get_masked_inputs(self, tokens, binary_mask_matrix): masked_input = masked_input + t * int(b) file = tempfile.NamedTemporaryFile(delete=False) 
processing_utils.audio_to_file(sample_rate, masked_input, file.name) - masked_data = processing_utils.encode_file_to_base64(file.name) + masked_data = client_utils.encode_file_to_base64(file.name) file.close() Path(file.name).unlink() masked_inputs.append(masked_data) @@ -2582,9 +2585,7 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: file.name = temp_file_path file.orig_name = file_name # type: ignore else: - file = processing_utils.decode_base64_to_file( - data, file_path=file_name - ) + file = client_utils.decode_base64_to_file(data, file_path=file_name) file.orig_name = file_name # type: ignore self.temp_files.add(str(utils.abspath(file.name))) return file @@ -2594,7 +2595,7 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: if is_file: with open(file_name, "rb") as file_data: return file_data.read() - return processing_utils.decode_base64_to_binary(data)[0] + return client_utils.decode_base64_to_binary(data)[0] else: raise ValueError( "Unknown type: " @@ -3322,9 +3323,7 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: file.name = temp_file_path file.orig_name = file_name # type: ignore else: - file = processing_utils.decode_base64_to_file( - data, file_path=file_name - ) + file = client_utils.decode_base64_to_file(data, file_path=file_name) file.orig_name = file_name # type: ignore self.temp_files.add(str(utils.abspath(file.name))) return file @@ -3332,7 +3331,7 @@ def process_single_file(f) -> bytes | tempfile._TemporaryFileWrapper: if is_file: with open(file_name, "rb") as file_data: return file_data.read() - return processing_utils.decode_base64_to_binary(data)[0] + return client_utils.decode_base64_to_binary(data)[0] else: raise ValueError( "Unknown type: " @@ -3969,7 +3968,7 @@ def style(self): @document("style") -class Gallery(IOComponent, FileSerializable, Selectable): +class Gallery(IOComponent, GallerySerializable, Selectable): """ Used to display a list of images as a gallery that can be scrolled through. Preprocessing: this component does *not* accept input. 
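The Gallery-specific `deserialize`/`serialize` pair removed in the next hunk (its logic now lives in `gradio_client.serializing.GallerySerializable`) persisted a gallery as a directory of image files plus a `captions.json` mapping each saved file to its caption (or `None`). A minimal sketch of reading that layout back, assuming a hypothetical directory path:

```python
import json
from pathlib import Path

gallery_dir = Path("/tmp/gallery-example")  # hypothetical deserialized gallery

# captions.json maps each saved image path to its caption (possibly None).
with (gallery_dir / "captions.json").open() as captions_json:
    captions = json.load(captions_json)

for file_name, caption in captions.items():
    print(file_name, caption)
```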
@@ -4111,41 +4110,6 @@ def style( Component.style(self, container=container, **kwargs) return self - def deserialize( - self, - x: Any, - save_dir: str = "", - root_url: str | None = None, - ) -> None | str: - if x is None: - return None - gallery_path = Path(save_dir) / str(uuid.uuid4()) - gallery_path.mkdir(exist_ok=True, parents=True) - captions = {} - for img_data in x: - if isinstance(img_data, list) or isinstance(img_data, tuple): - img_data, caption = img_data - else: - caption = None - name = FileSerializable.deserialize( - self, img_data, gallery_path, root_url=root_url - ) - captions[name] = caption - captions_file = gallery_path / "captions.json" - with captions_file.open("w") as captions_json: - json.dump(captions, captions_json) - return str(utils.abspath(gallery_path)) - - def serialize(self, x: Any, load_dir: str = "", called_directly: bool = False): - files = [] - captions_file = Path(x) / "captions.json" - with captions_file.open("r") as captions_json: - captions = json.load(captions_json) - for file_name, caption in captions.items(): - img = FileSerializable.serialize(self, file_name) - files.append([img, caption]) - return files - class Carousel(IOComponent, Changeable, SimpleSerializable): """ @@ -4289,7 +4253,7 @@ def _postprocess_chat_messages( return None elif isinstance(chat_message, (tuple, list)): filepath = chat_message[0] - mime_type = processing_utils.get_mimetype(filepath) + mime_type = client_utils.get_mimetype(filepath) filepath = self.make_temp_copy_if_needed(filepath) return { "name": filepath, @@ -4552,7 +4516,7 @@ def postprocess(self, y) -> Dict[str, str] | None: """ if y is None: return None - if isinstance(y, (ModuleType, matplotlib.figure.Figure)): + if isinstance(y, (ModuleType, matplotlib.figure.Figure)): # type: ignore dtype = "matplotlib" out_y = processing_utils.encode_plot_to_base64(y) elif "bokeh" in y.__module__: @@ -5779,7 +5743,7 @@ def style(self): @document("style") -class Dataset(Clickable, Selectable, Component): +class Dataset(Clickable, Selectable, Component, SimpleSerializable): """ Used to create an output widget for showing datasets. Used to render the examples box. @@ -5887,7 +5851,7 @@ def style(self, **kwargs): @document() -class Interpretation(Component): +class Interpretation(Component, SimpleSerializable): """ Used to create an interpretation widget for a component. Preprocessing: this component does *not* accept input. @@ -5938,7 +5902,7 @@ def style(self): return self -class StatusTracker(Component): +class StatusTracker(Component, SimpleSerializable): def __init__( self, **kwargs, diff --git a/gradio/context.py b/gradio/context.py index 6048312eb0179..18eecbc84ec8f 100644 --- a/gradio/context.py +++ b/gradio/context.py @@ -13,6 +13,4 @@ class Context: block: BlockContext | None = None # The current block that children are added to. id: int = 0 # Running id to uniquely refer to any block that gets defined ip_address: str | None = None # The IP address of the user. - access_token: str | None = ( - None # The HF token that is provided when loading private models or Spaces - ) + hf_token: str | None = None # The token provided when loading private HF repos diff --git a/gradio/external.py b/gradio/external.py index 8e6686470910b..2f1a12536214b 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -1,30 +1,28 @@ """This module should not be used directly as its API is subject to change. 
Instead, -use the `gr.Blocks.load()` or `gr.Interface.load()` functions.""" +use the `gr.Blocks.load()` or `gr.load()` functions.""" from __future__ import annotations import json import re -import uuid import warnings -from copy import deepcopy from typing import TYPE_CHECKING, Callable, Dict import requests +from gradio_client import Client import gradio from gradio import components, utils from gradio.context import Context +from gradio.documentation import document, set_documentation_group from gradio.exceptions import Error, TooManyRequestsError from gradio.external_utils import ( cols_to_rows, encode_to_base64, get_tabular_examples, - get_ws_fn, postprocess_label, rows_to_cols, streamline_spaces_interface, - use_websocket, ) from gradio.processing_utils import to_binary @@ -33,6 +31,45 @@ from gradio.interface import Interface +set_documentation_group("helpers") + + +@document() +def load( + name: str, + src: str | None = None, + api_key: str | None = None, + hf_token: str | None = None, + alias: str | None = None, + **kwargs, +) -> Blocks: + """ + Method that constructs a Blocks from a Hugging Face repo. Can accept + model repos (if src is "models") or Space repos (if src is "spaces"). The input + and output components are automatically loaded from the repo. + Parameters: + name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base") + src: the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`) + api_key: Deprecated. Please use the `hf_token` parameter instead. + hf_token: optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens + alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x) + Returns: + a Gradio Blocks object for the given model + Example: + import gradio as gr + demo = gr.load("gradio/question-answering", src="spaces") + demo.launch() + """ + if hf_token is None and api_key: + warnings.warn( + "The `api_key` parameter will be deprecated. Please use the `hf_token` parameter going forward." + ) + hf_token = api_key + return load_blocks_from_repo( + name=name, src=src, api_key=hf_token, alias=alias, **kwargs + ) + + def load_blocks_from_repo( name: str, src: str | None = None, @@ -61,11 +98,11 @@ def load_blocks_from_repo( ) if api_key is not None: - if Context.access_token is not None and Context.access_token != api_key: + if Context.hf_token is not None and Context.hf_token != api_key: warnings.warn( """You are loading a model/Space with a different access token than the one you used to load a previous model/Space. This is not recommended, as it may cause unexpected behavior.""" ) - Context.access_token = api_key + Context.hf_token = api_key blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs) return blocks @@ -411,58 +448,13 @@ def from_spaces( "Blocks or Interface locally. 
You may find this Guide helpful: " "https://gradio.app/using_blocks_like_functions/" ) - return from_spaces_blocks(config, api_key, iframe_url) + return from_spaces_blocks(space=space_name, api_key=api_key) -def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Blocks: - api_url = "{}/api/predict/".format(iframe_url) - - headers = {"Content-Type": "application/json"} - if api_key is not None: - headers["Authorization"] = f"Bearer {api_key}" - ws_url = "{}/queue/join".format(iframe_url).replace("https", "wss") - - ws_fn = get_ws_fn(ws_url, headers) - - fns = [] - for d, dependency in enumerate(config["dependencies"]): - if dependency["backend_fn"]: - - def get_fn(outputs, fn_index, use_ws): - def fn(*data): - data = json.dumps({"data": data, "fn_index": fn_index}) - hash_data = json.dumps( - {"fn_index": fn_index, "session_hash": str(uuid.uuid4())} - ) - if use_ws: - result = utils.synchronize_async(ws_fn, data, hash_data) - output = result["data"] - else: - response = requests.post(api_url, headers=headers, data=data) - result = json.loads(response.content.decode("utf-8")) - try: - output = result["data"] - except KeyError: - if "error" in result and "429" in result["error"]: - raise TooManyRequestsError( - "Too many requests to the Hugging Face API" - ) - raise KeyError( - f"Could not find 'data' key in response from external Space. Response received: {result}" - ) - if len(outputs) == 1: - output = output[0] - return output - - return fn - - fn = get_fn( - deepcopy(dependency["outputs"]), d, use_websocket(config, dependency) - ) - fns.append(fn) - else: - fns.append(None) - return gradio.Blocks.from_config(config, fns, iframe_url) +def from_spaces_blocks(space: str, api_key: str | None) -> Blocks: + client = Client(space=space, hf_token=api_key) + predict_fns = [endpoint._predict_resolve for endpoint in client.endpoints] + return gradio.Blocks.from_config(client.config, predict_fns, client.src) def from_spaces_interface( diff --git a/gradio/external_utils.py b/gradio/external_utils.py index 82294add0fe12..cb402e98d99c0 100644 --- a/gradio/external_utils.py +++ b/gradio/external_utils.py @@ -1,20 +1,16 @@ """Utility function for gradio/external.py""" import base64 -import json import math import operator import re import warnings -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple import requests -import websockets import yaml -from packaging import version -from websockets.legacy.protocol import WebSocketCommonProtocol -from gradio import components, exceptions +from gradio import components ################## # Helper functions for processing tabular data @@ -116,48 +112,6 @@ def encode_to_base64(r: requests.Response) -> str: return new_base64 -################## -# Helper functions for connecting to websockets -################## - - -async def get_pred_from_ws( - websocket: WebSocketCommonProtocol, data: str, hash_data: str -) -> Dict[str, Any]: - completed = False - resp = {} - while not completed: - msg = await websocket.recv() - resp = json.loads(msg) - if resp["msg"] == "queue_full": - raise exceptions.Error("Queue is full! 
Please try again.") - if resp["msg"] == "send_hash": - await websocket.send(hash_data) - elif resp["msg"] == "send_data": - await websocket.send(data) - completed = resp["msg"] == "process_completed" - return resp["output"] - - -def get_ws_fn(ws_url, headers): - async def ws_fn(data, hash_data): - async with websockets.connect( # type: ignore - ws_url, open_timeout=10, extra_headers=headers - ) as websocket: - return await get_pred_from_ws(websocket, data, hash_data) - - return ws_fn - - -def use_websocket(config, dependency): - queue_enabled = config.get("enable_queue", False) - queue_uses_websocket = version.parse( - config.get("version", "2.0") - ) >= version.Version("3.2") - dependency_uses_queue = dependency.get("queue", False) is not False - return queue_enabled and queue_uses_websocket and dependency_uses_queue - - ################## # Helper function for cleaning up an Interface loaded from HF Spaces ################## diff --git a/gradio/flagging.py b/gradio/flagging.py index cc19927d9812e..b22eb1f21663c 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, List import pkg_resources +from gradio_client import utils as client_utils import gradio as gr from gradio import utils @@ -139,9 +140,9 @@ def flag( csv_data = [] for component, sample in zip(self.components, flag_data): - save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters( - component.label or "" - ) + save_dir = Path( + flagging_dir + ) / client_utils.strip_invalid_filename_characters(component.label or "") csv_data.append( component.deserialize( sample, @@ -205,7 +206,9 @@ def flag( csv_data = [] for idx, (component, sample) in enumerate(zip(self.components, flag_data)): - save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters( + save_dir = Path( + flagging_dir + ) / client_utils.strip_invalid_filename_characters( getattr(component, "label", None) or f"component {idx}" ) if utils.is_update(sample): @@ -339,7 +342,9 @@ def flag( for component, sample in zip(self.components, flag_data): save_dir = Path( self.dataset_dir - ) / utils.strip_invalid_filename_characters(component.label or "") + ) / client_utils.strip_invalid_filename_characters( + component.label or "" + ) filepath = component.deserialize(sample, save_dir, None) csv_data.append(filepath) if isinstance(component, tuple(file_preview_types)): @@ -474,7 +479,9 @@ def flag( headers.append(component.label) try: - save_dir = Path(folder_name) / utils.strip_invalid_filename_characters( + save_dir = Path( + folder_name + ) / client_utils.strip_invalid_filename_characters( component.label or "" ) filepath = component.deserialize(sample, save_dir, None) diff --git a/gradio/helpers.py b/gradio/helpers.py index 3619c40c69288..493a7fe8db910 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -19,6 +19,7 @@ import numpy as np import PIL import PIL.Image +from gradio_client import utils as client_utils from gradio import processing_utils, routes, utils from gradio.context import Context @@ -67,7 +68,7 @@ def create_examples( batch=batch, _initiated_directly=False, ) - utils.synchronize_async(examples_obj.create) + client_utils.synchronize_async(examples_obj.create) return examples_obj diff --git a/gradio/interface.py b/gradio/interface.py index 284f3a3cf9e11..8096d5fe22b89 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -13,7 +13,7 @@ import weakref from typing import TYPE_CHECKING, Any, Callable, List, Tuple -from gradio import Examples, interpretation, utils 
+from gradio import Examples, external, interpretation, utils from gradio.blocks import Blocks from gradio.components import ( Button, @@ -77,9 +77,10 @@ def load( api_key: str | None = None, alias: str | None = None, **kwargs, - ) -> Interface: + ) -> Blocks: """ - Class method that constructs an Interface from a Hugging Face repo. Can accept + Warning: this method will be deprecated. Use the equivalent `gradio.load()` instead. This is a class + method that constructs a Blocks from a Hugging Face repo. Can accept model repos (if src is "models") or Space repos (if src is "spaces"). The input and output components are automatically loaded from the repo. Parameters: @@ -89,14 +90,11 @@ def load( alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x) Returns: a Gradio Interface object for the given model - Example: - import gradio as gr - description = "Story generation with GPT" - examples = [["An adventurer is approached by a mysterious stranger in the tavern for a new quest."]] - demo = gr.Interface.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples) - demo.launch() """ - return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs) + warnings.warn("gr.Interface.load() will be deprecated. Use gr.load() instead.") + return external.load( + name=name, src=src, hf_token=api_key, alias=alias, **kwargs + ) @classmethod def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface: @@ -242,10 +240,10 @@ def __init__( self.cache_examples = False self.input_components = [ - get_component_instance(i, render=False) for i in inputs + get_component_instance(i, render=False) for i in inputs # type: ignore ] self.output_components = [ - get_component_instance(o, render=False) for o in outputs + get_component_instance(o, render=False) for o in outputs # type: ignore ] for component in self.input_components + self.output_components: diff --git a/gradio/interpretation.py b/gradio/interpretation.py index f48feb379e71b..a4e64c5d2433e 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -8,8 +8,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple import numpy as np +from gradio_client import utils as client_utils -from gradio import components, utils +from gradio import components if TYPE_CHECKING: # Only import for type checking (is False at runtime). 
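For reference, the `Example` block removed from the `Interface.load` docstring above maps one-for-one onto the new top-level helper; a sketch of the equivalent call:

```python
import gradio as gr

description = "Story generation with GPT"
examples = [["An adventurer is approached by a mysterious stranger in the tavern for a new quest."]]

# Same model repo as the removed docstring example, loaded via gr.load().
demo = gr.load("models/EleutherAI/gpt-neo-1.3B", description=description, examples=examples)
demo.launch()
```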
from gradio import Interface @@ -139,7 +140,9 @@ async def run_interpret(interface: Interface, raw_input: List): ( neighbor_values, interpret_kwargs, - ) = input_component.get_interpretation_neighbors(x) + ) = input_component.get_interpretation_neighbors( + x + ) # type: ignore interface_scores = [] alternative_output = [] for neighbor_input in neighbor_values: @@ -208,7 +211,7 @@ def get_masked_prediction(binary_mask): for masked_x in masked_xs: processed_masked_input = copy.deepcopy(processed_input) processed_masked_input[i] = input_component.preprocess(masked_x) - new_output = utils.synchronize_async( + new_output = client_utils.synchronize_async( interface.call_function, 0, processed_masked_input ) new_output = new_output["prediction"] diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 2ca9e04729ec4..ad7d027a9bb3f 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -2,23 +2,19 @@ import base64 import json -import mimetypes -import os import shutil import subprocess import tempfile import warnings from io import BytesIO from pathlib import Path -from typing import Dict, Tuple +from typing import Dict import numpy as np -import requests from ffmpy import FFmpeg, FFprobe, FFRuntimeError +from gradio_client import utils as client_utils from PIL import Image, ImageOps, PngImagePlugin -from gradio import utils - with warnings.catch_warnings(): warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed from pydub import AudioSegment @@ -35,7 +31,7 @@ def to_binary(x: str | Dict) -> bytes: if x.get("data"): base64str = x["data"] else: - base64str = encode_url_or_file_to_base64(x["name"]) + base64str = client_utils.encode_url_or_file_to_base64(x["name"]) else: base64str = x return base64.b64decode(base64str.split(",")[1]) @@ -57,56 +53,6 @@ def decode_base64_to_image(encoding: str) -> Image.Image: return img -def encode_url_or_file_to_base64(path: str | Path): - path = str(path) - if utils.validate_url(path): - return encode_url_to_base64(path) - else: - return encode_file_to_base64(path) - - -def get_mimetype(filename: str) -> str | None: - mimetype = mimetypes.guess_type(filename)[0] - if mimetype is not None: - mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac") - return mimetype - - -def get_extension(encoding: str) -> str | None: - encoding = encoding.replace("audio/wav", "audio/x-wav") - type = mimetypes.guess_type(encoding)[0] - if type == "audio/flac": # flac is not supported by mimetypes - return "flac" - elif type is None: - return None - extension = mimetypes.guess_extension(type) - if extension is not None and extension.startswith("."): - extension = extension[1:] - return extension - - -def encode_file_to_base64(f): - with open(f, "rb") as file: - encoded_string = base64.b64encode(file.read()) - base64_str = str(encoded_string, "utf-8") - mimetype = get_mimetype(f) - return ( - "data:" - + (mimetype if mimetype is not None else "") - + ";base64," - + base64_str - ) - - -def encode_url_to_base64(url): - encoded_string = base64.b64encode(requests.get(url).content) - base64_str = str(encoded_string, "utf-8") - mimetype = get_mimetype(url) - return ( - "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str - ) - - def encode_plot_to_base64(plt): with BytesIO() as output_bytes: plt.savefig(output_bytes, format="png") @@ -257,99 +203,6 @@ def convert_to_16_bit_wav(data): ################## -def decode_base64_to_binary(encoding) -> Tuple[bytes, str | None]: - extension = 
get_extension(encoding) - try: - data = encoding.split(",")[1] - except IndexError: - data = "" - return base64.b64decode(data), extension - - -def decode_base64_to_file(encoding, file_path=None, dir=None, prefix=None): - if dir is not None: - os.makedirs(dir, exist_ok=True) - data, extension = decode_base64_to_binary(encoding) - if file_path is not None and prefix is None: - filename = Path(file_path).name - prefix = filename - if "." in filename: - prefix = filename[0 : filename.index(".")] - extension = filename[filename.index(".") + 1 :] - - if prefix is not None: - prefix = utils.strip_invalid_filename_characters(prefix) - - if extension is None: - file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir) - else: - file_obj = tempfile.NamedTemporaryFile( - delete=False, - prefix=prefix, - suffix="." + extension, - dir=dir, - ) - file_obj.write(data) - file_obj.flush() - return file_obj - - -def dict_or_str_to_json_file(jsn, dir=None): - if dir is not None: - os.makedirs(dir, exist_ok=True) - - file_obj = tempfile.NamedTemporaryFile( - delete=False, suffix=".json", dir=dir, mode="w+" - ) - if isinstance(jsn, str): - jsn = json.loads(jsn) - json.dump(jsn, file_obj) - file_obj.flush() - return file_obj - - -def file_to_json(file_path: str | Path) -> Dict: - with open(file_path) as f: - return json.load(f) - - -def download_tmp_copy_of_file( - url_path: str, access_token: str | None = None, dir: str | None = None -) -> tempfile._TemporaryFileWrapper: - if dir is not None: - os.makedirs(dir, exist_ok=True) - headers = {"Authorization": "Bearer " + access_token} if access_token else {} - prefix = Path(url_path).stem - suffix = Path(url_path).suffix - file_obj = tempfile.NamedTemporaryFile( - delete=False, - prefix=prefix, - suffix=suffix, - dir=dir, - ) - with requests.get(url_path, headers=headers, stream=True) as r: - with open(file_obj.name, "wb") as f: - shutil.copyfileobj(r.raw, f) - return file_obj - - -def create_tmp_copy_of_file( - file_path: str, dir: str | None = None -) -> tempfile._TemporaryFileWrapper: - if dir is not None: - os.makedirs(dir, exist_ok=True) - prefix = Path(file_path).stem - suffix = Path(file_path).suffix - file_obj = tempfile.NamedTemporaryFile( - delete=False, - prefix=prefix, - suffix=suffix, - dir=dir, - ) - shutil.copy2(file_path, file_obj.name) - return file_obj - - def _convert(image, dtype, force_copy=False, uniform=False): """ Adapted from: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/dtype.py#L510-L531 diff --git a/gradio/routes.py b/gradio/routes.py index b27cb8f9017d9..3a017983c5694 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -282,8 +282,8 @@ async def reverse_proxy(url_path: str): # Adapted from: https://github.com/tiangolo/fastapi/issues/1788 url = httpx.URL(url_path) headers = {} - if Context.access_token is not None: - headers["Authorization"] = f"Bearer {Context.access_token}" + if Context.hf_token is not None: + headers["Authorization"] = f"Bearer {Context.hf_token}" rp_req = client.build_request("GET", url, headers=headers) rp_resp = await client.send(rp_req, stream=True) return StreamingResponse( diff --git a/gradio/test_data/blocks_configs.py b/gradio/test_data/blocks_configs.py index b1701588bd6bf..d9d2df2e83079 100644 --- a/gradio/test_data/blocks_configs.py +++ b/gradio/test_data/blocks_configs.py @@ -1,11 +1,11 @@ XRAY_CONFIG = { - "version": "3.4b3\n", + "version": "3.21.0\n", "mode": "blocks", "dev_mode": True, "analytics_enabled": False, "components": [ { - "id": 27, + "id": 1, 
"type": "markdown", "props": { "value": "

Detect Disease From Scan

\n

With this model you can lorem ipsum

\n\n", @@ -13,9 +13,10 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 28, + "id": 2, "type": "checkboxgroup", "props": { "choices": ["Covid", "Malaria", "Lung Cancer"], @@ -26,15 +27,16 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, - {"id": 29, "type": "tabs", "props": {"visible": True, "style": {}}}, + {"id": 3, "type": "tabs", "props": {"visible": True, "style": {}}}, { - "id": 30, + "id": 4, "type": "tabitem", "props": {"label": "X-ray", "visible": True, "style": {}}, }, { - "id": 31, + "id": 5, "type": "row", "props": { "type": "row", @@ -44,7 +46,7 @@ }, }, { - "id": 32, + "id": 6, "type": "image", "props": { "image_mode": "RGB", @@ -57,14 +59,16 @@ "visible": True, "style": {}, }, + "serializer": "ImgSerializable", }, { - "id": 33, + "id": 7, "type": "json", "props": {"show_label": True, "name": "json", "visible": True, "style": {}}, + "serializer": "JSONSerializable", }, { - "id": 34, + "id": 8, "type": "button", "props": { "value": "Run", @@ -74,14 +78,15 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 35, + "id": 9, "type": "tabitem", "props": {"label": "CT Scan", "visible": True, "style": {}}, }, { - "id": 36, + "id": 10, "type": "row", "props": { "type": "row", @@ -91,7 +96,7 @@ }, }, { - "id": 37, + "id": 11, "type": "image", "props": { "image_mode": "RGB", @@ -104,26 +109,29 @@ "visible": True, "style": {}, }, + "serializer": "ImgSerializable", }, { - "id": 38, + "id": 12, "type": "json", "props": {"show_label": True, "name": "json", "visible": True, "style": {}}, + "serializer": "JSONSerializable", }, { - "id": 39, + "id": 13, "type": "button", "props": { "value": "Run", "variant": "secondary", - "name": "button", "interactive": True, + "name": "button", "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 40, + "id": 14, "type": "textbox", "props": { "lines": 1, @@ -135,14 +143,15 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 41, + "id": 15, "type": "form", "props": {"type": "form", "visible": True, "style": {}}, }, { - "id": 42, + "id": 16, "type": "form", "props": {"type": "form", "visible": True, "style": {}}, }, @@ -151,73 +160,74 @@ "title": "Gradio", "is_space": False, "enable_queue": None, - "show_error": False, + "show_error": True, "show_api": True, + "is_colab": False, "layout": { - "id": 26, + "id": 0, "children": [ - {"id": 27}, - {"id": 41, "children": [{"id": 28}]}, + {"id": 1}, + {"id": 15, "children": [{"id": 2}]}, { - "id": 29, + "id": 3, "children": [ { - "id": 30, + "id": 4, "children": [ - {"id": 31, "children": [{"id": 32}, {"id": 33}]}, - {"id": 34}, + {"id": 5, "children": [{"id": 6}, {"id": 7}]}, + {"id": 8}, ], }, { - "id": 35, + "id": 9, "children": [ - {"id": 36, "children": [{"id": 37}, {"id": 38}]}, - {"id": 39}, + {"id": 10, "children": [{"id": 11}, {"id": 12}]}, + {"id": 13}, ], }, ], }, - {"id": 42, "children": [{"id": 40}]}, + {"id": 16, "children": [{"id": 14}]}, ], }, "dependencies": [ { - "targets": [34], + "targets": [8], "trigger": "click", - "inputs": [28, 32], - "outputs": [33], + "inputs": [2, 6], + "outputs": [7], "backend_fn": True, "js": None, "queue": None, "api_name": None, "scroll_to_output": False, "show_progress": True, + "every": None, "batch": False, "max_batch_size": 4, "cancels": [], - "every": None, - "collects_event_data": False, "types": {"continuous": False, "generator": False}, + "collects_event_data": False, "trigger_after": None, 
"trigger_only_on_success": False, }, { - "targets": [39], + "targets": [13], "trigger": "click", - "inputs": [28, 37], - "outputs": [38], + "inputs": [2, 11], + "outputs": [12], "backend_fn": True, "js": None, "queue": None, "api_name": None, "scroll_to_output": False, "show_progress": True, + "every": None, "batch": False, "max_batch_size": 4, "cancels": [], - "every": None, - "collects_event_data": False, "types": {"continuous": False, "generator": False}, + "collects_event_data": False, "trigger_after": None, "trigger_only_on_success": False, }, @@ -225,19 +235,19 @@ "targets": [], "trigger": "load", "inputs": [], - "outputs": [40], + "outputs": [14], "backend_fn": True, "js": None, "queue": None, "api_name": None, "scroll_to_output": False, "show_progress": True, + "every": None, "batch": False, "max_batch_size": 4, "cancels": [], - "every": None, - "collects_event_data": False, "types": {"continuous": False, "generator": False}, + "collects_event_data": False, "trigger_after": None, "trigger_only_on_success": False, }, @@ -246,13 +256,13 @@ XRAY_CONFIG_DIFF_IDS = { - "version": "3.4b3\n", + "version": "3.21.0\n", "mode": "blocks", - "analytics_enabled": False, "dev_mode": True, + "analytics_enabled": False, "components": [ { - "id": 27, + "id": 1, "type": "markdown", "props": { "value": "
<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>
\n\n", @@ -260,9 +270,10 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 28, + "id": 2, "type": "checkboxgroup", "props": { "choices": ["Covid", "Malaria", "Lung Cancer"], @@ -273,15 +284,16 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, - {"id": 29, "type": "tabs", "props": {"visible": True, "style": {}}}, + {"id": 3, "type": "tabs", "props": {"visible": True, "style": {}}}, { - "id": 30, + "id": 4, "type": "tabitem", "props": {"label": "X-ray", "visible": True, "style": {}}, }, { - "id": 31, + "id": 5, "type": "row", "props": { "type": "row", @@ -291,7 +303,7 @@ }, }, { - "id": 32, + "id": 6, "type": "image", "props": { "image_mode": "RGB", @@ -304,14 +316,16 @@ "visible": True, "style": {}, }, + "serializer": "ImgSerializable", }, { - "id": 33, + "id": 7, "type": "json", "props": {"show_label": True, "name": "json", "visible": True, "style": {}}, + "serializer": "JSONSerializable", }, { - "id": 34, + "id": 8, "type": "button", "props": { "value": "Run", @@ -321,14 +335,15 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 35, + "id": 9, "type": "tabitem", "props": {"label": "CT Scan", "visible": True, "style": {}}, }, { - "id": 36, + "id": 10, "type": "row", "props": { "type": "row", @@ -338,7 +353,7 @@ }, }, { - "id": 37, + "id": 11, "type": "image", "props": { "image_mode": "RGB", @@ -351,14 +366,16 @@ "visible": True, "style": {}, }, + "serializer": "ImgSerializable", }, { - "id": 38, + "id": 1212, "type": "json", "props": {"show_label": True, "name": "json", "visible": True, "style": {}}, + "serializer": "JSONSerializable", }, { - "id": 933, + "id": 13, "type": "button", "props": { "value": "Run", @@ -368,9 +385,10 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 40, + "id": 14, "type": "textbox", "props": { "lines": 1, @@ -382,14 +400,15 @@ "visible": True, "style": {}, }, + "serializer": "SimpleSerializable", }, { - "id": 41, + "id": 15, "type": "form", "props": {"type": "form", "visible": True, "style": {}}, }, { - "id": 42, + "id": 16, "type": "form", "props": {"type": "form", "visible": True, "style": {}}, }, @@ -398,73 +417,74 @@ "title": "Gradio", "is_space": False, "enable_queue": None, - "show_error": False, + "show_error": True, "show_api": True, + "is_colab": False, "layout": { - "id": 26, + "id": 0, "children": [ - {"id": 27}, - {"id": 41, "children": [{"id": 28}]}, + {"id": 1}, + {"id": 15, "children": [{"id": 2}]}, { - "id": 29, + "id": 3, "children": [ { - "id": 30, + "id": 4, "children": [ - {"id": 31, "children": [{"id": 32}, {"id": 33}]}, - {"id": 34}, + {"id": 5, "children": [{"id": 6}, {"id": 7}]}, + {"id": 8}, ], }, { - "id": 35, + "id": 9, "children": [ - {"id": 36, "children": [{"id": 37}, {"id": 38}]}, - {"id": 933}, + {"id": 10, "children": [{"id": 11}, {"id": 1212}]}, + {"id": 13}, ], }, ], }, - {"id": 42, "children": [{"id": 40}]}, + {"id": 16, "children": [{"id": 14}]}, ], }, "dependencies": [ { - "targets": [34], + "targets": [8], "trigger": "click", - "inputs": [28, 32], - "outputs": [33], + "inputs": [2, 6], + "outputs": [7], "backend_fn": True, "js": None, "queue": None, "api_name": None, "scroll_to_output": False, "show_progress": True, + "every": None, "batch": False, "max_batch_size": 4, "cancels": [], - "every": None, - "collects_event_data": False, "types": {"continuous": False, "generator": False}, + "collects_event_data": False, "trigger_after": None, "trigger_only_on_success": False, }, { - "targets": 
[933], + "targets": [13], "trigger": "click", - "inputs": [28, 37], - "outputs": [38], + "inputs": [2, 11], + "outputs": [1212], "backend_fn": True, "js": None, "queue": None, "api_name": None, "scroll_to_output": False, "show_progress": True, + "every": None, "batch": False, "max_batch_size": 4, "cancels": [], - "every": None, - "collects_event_data": False, "types": {"continuous": False, "generator": False}, + "collects_event_data": False, "trigger_after": None, "trigger_only_on_success": False, }, @@ -472,19 +492,19 @@ "targets": [], "trigger": "load", "inputs": [], - "outputs": [40], + "outputs": [14], "backend_fn": True, "js": None, "queue": None, "api_name": None, "scroll_to_output": False, "show_progress": True, + "every": None, "batch": False, "max_batch_size": 4, "cancels": [], - "every": None, - "collects_event_data": False, "types": {"continuous": False, "generator": False}, + "collects_event_data": False, "trigger_after": None, "trigger_only_on_success": False, }, diff --git a/gradio/utils.py b/gradio/utils.py index 56b6b7d863339..5147b019e4489 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -4,6 +4,7 @@ import asyncio import copy +import functools import inspect import json import json.decoder @@ -29,7 +30,6 @@ Dict, Generator, List, - NewType, Tuple, Type, TypeVar, @@ -37,7 +37,6 @@ ) import aiohttp -import fsspec.asyn import httpx import matplotlib.pyplot as plt import requests @@ -45,14 +44,14 @@ from markdown_it import MarkdownIt from mdit_py_plugins.dollarmath.index import dollarmath_plugin from mdit_py_plugins.footnote.index import footnote_plugin -from pydantic import BaseModel, Json, parse_obj_as +from pydantic import BaseModel, parse_obj_as import gradio from gradio.context import Context from gradio.strings import en if TYPE_CHECKING: # Only import for type checking (is False at runtime). - from gradio.blocks import BlockContext + from gradio.blocks import Block, BlockContext from gradio.components import Component analytics_url = "https://api.gradio.app/" @@ -498,24 +497,6 @@ def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockCont raise ValueError(f"No such component or layout: {cls_name}") -def synchronize_async(func: Callable, *args, **kwargs) -> Any: - """ - Runs async functions in sync scopes. - - Can be used in any scope. See run_coro_in_background for more details. - - Example: - if inspect.iscoroutinefunction(block_fn.fn): - predictions = utils.synchronize_async(block_fn.fn, *processed_input) - - Args: - func: - *args: - **kwargs: - """ - return fsspec.asyn.sync(fsspec.asyn.get_loop(), func, *args, **kwargs) - - def run_coro_in_background(func: Callable, *args, **kwargs): """ Runs coroutines in background. @@ -571,7 +552,6 @@ class AsyncRequest: You can see example usages in test_utils.py. """ - ResponseJson = NewType("ResponseJson", Json) client = httpx.AsyncClient() class Method(str, Enum): @@ -674,9 +654,7 @@ def _create_request(method: Method, url: str, **kwargs) -> httpx.Request: request = httpx.Request(method, url, **kwargs) return request - def _validate_response_data( - self, response: ResponseJson - ) -> Union[BaseModel, ResponseJson | None]: + def _validate_response_data(self, response): """ Validate response using given validation methods. If there is a validation method and response is not valid, validation functions will raise an exception for them. 
@@ -705,7 +683,7 @@ def _validate_response_data( return validated_response - def _validate_response_by_model(self, response: ResponseJson) -> BaseModel: + def _validate_response_by_model(self, response) -> BaseModel: """ Validate response json using the validation model. Args: @@ -718,9 +696,7 @@ def _validate_response_by_model(self, response: ResponseJson) -> BaseModel: validated_data = parse_obj_as(self._validation_model, response) return validated_data - def _validate_response_by_validation_function( - self, response: ResponseJson - ) -> ResponseJson | None: + def _validate_response_by_validation_function(self, response): """ Validate response json using the validation function. Args: @@ -787,19 +763,6 @@ def set_directory(path: Path | str): os.chdir(origin) -def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> str: - """Strips invalid characters from a filename and ensures that the file_length is less than `max_bytes` bytes.""" - filename = "".join([char for char in filename if char.isalnum() or char in "._- "]) - filename_len = len(filename.encode()) - if filename_len > max_bytes: - while filename_len > max_bytes: - if len(filename) == 0: - break - filename = filename[:-1] - filename_len = len(filename.encode()) - return filename - - def sanitize_value_for_csv(value: str | Number) -> str | Number: """ Sanitizes a value that is being written to a CSV file to prevent CSV injection attacks. @@ -1009,6 +972,38 @@ def abspath(path: str | Path) -> Path: return Path(path).resolve() +def get_serializer_name(block: Block) -> str | None: + if not hasattr(block, "serialize"): + return None + + def get_class_that_defined_method(meth: Callable): + # Adapted from: https://stackoverflow.com/a/25959545/5209347 + if isinstance(meth, functools.partial): + return get_class_that_defined_method(meth.func) + if inspect.ismethod(meth) or ( + inspect.isbuiltin(meth) + and getattr(meth, "__self__", None) is not None + and getattr(meth.__self__, "__class__", None) + ): + for cls in inspect.getmro(meth.__self__.__class__): + if meth.__name__ in cls.__dict__: + return cls + meth = getattr(meth, "__func__", meth) # fallback to __qualname__ parsing + if inspect.isfunction(meth): + cls = getattr( + inspect.getmodule(meth), + meth.__qualname__.split(".", 1)[0].rsplit(".", 1)[0], + None, + ) + if isinstance(cls, type): + return cls + return getattr(meth, "__objclass__", None) + + cls = get_class_that_defined_method(block.serialize) # type: ignore + if cls: + return cls.__name__ + + def get_markdown_parser() -> MarkdownIt: md = ( MarkdownIt( diff --git a/requirements.txt b/requirements.txt index 041f6ec3a7b63..9a71de09051ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,12 @@ +aiofiles aiohttp altair>=4.2.0 fastapi ffmpy +gradio_client>=0.0.4 +httpx huggingface_hub>=0.13.0 +Jinja2 markdown-it-py[linkify]>=2.0.0 mdit-py-plugins<=0.3.3 markupsafe @@ -11,17 +15,12 @@ numpy orjson pandas pillow +pydantic python-multipart pydub pyyaml requests +semantic_version +typing_extensions uvicorn -Jinja2 -fsspec -httpx -pydantic websockets>=10.0 -typing_extensions -aiofiles -huggingface_hub -semantic_version \ No newline at end of file diff --git a/scripts/format_backend.sh b/scripts/format_backend.sh index dbab16eca22a6..d3eea12b628b1 100755 --- a/scripts/format_backend.sh +++ b/scripts/format_backend.sh @@ -3,6 +3,6 @@ cd "$(dirname ${0})/.." echo "Formatting the backend... Our style follows the Black code style." 
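The new `get_serializer_name` helper above is what populates the `"serializer"` key now emitted for each component in the config (see the updated `blocks_configs.py` fixtures). A minimal sketch of what it reports, with the expected outputs taken from those fixtures:

```python
import gradio as gr
from gradio.utils import get_serializer_name

# Walks the MRO of a component's serialize() method to find the
# gradio_client serializing mixin that defines it.
print(get_serializer_name(gr.Image()))    # "ImgSerializable"
print(get_serializer_name(gr.Textbox()))  # "SimpleSerializable"
```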
-python -m black gradio test
-python -m isort --profile=black gradio test
-python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403 gradio test --exclude gradio/__init__.py
+python -m black gradio test client/python/gradio_client
+python -m isort --profile=black gradio test client/python/gradio_client
+python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403 gradio test client/python/gradio_client --exclude gradio/__init__.py,client/python/gradio_client/__init__.py
diff --git a/scripts/lint_backend.sh b/scripts/lint_backend.sh
index a3c39b215f984..f4d92f939cf5e 100644
--- a/scripts/lint_backend.sh
+++ b/scripts/lint_backend.sh
@@ -2,6 +2,6 @@

 cd "$(dirname ${0})/.."

-python -m black --check gradio test
-python -m isort --profile=black --check-only gradio test
-python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403,F541 gradio test --exclude gradio/__init__.py
\ No newline at end of file
+python -m black --check gradio test client/python/gradio_client
+python -m isort --profile=black --check-only gradio test client/python/gradio_client
+python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403,F541 gradio test client/python/gradio_client --exclude gradio/__init__.py,client/python/gradio_client/__init__.py
\ No newline at end of file
diff --git a/scripts/type_check_backend.sh b/scripts/type_check_backend.sh
index d0e40eeeea391..07e69872a2ba5 100644
--- a/scripts/type_check_backend.sh
+++ b/scripts/type_check_backend.sh
@@ -5,4 +5,4 @@ pip_required

 pip install --upgrade pip
 pip install pyright==1.1.298
-pyright gradio/*.py
+pyright gradio/*.py client/python/gradio_client/*.py
diff --git a/test/test_components.py b/test/test_components.py
index 514107fc524cc..bfee1c769b7f5 100644
--- a/test/test_components.py
+++ b/test/test_components.py
@@ -22,6 +22,7 @@
 import PIL
 import pytest
 import vega_datasets
+from gradio_client import utils as client_utils
 from scipy.io import wavfile

 import gradio as gr
@@ -679,7 +680,7 @@ def test_component_functions(self):

     @pytest.mark.flaky
     def test_serialize_url(self):
         img = "https://gradio.app/assets/img/header-image.jpg"
-        expected = processing_utils.encode_url_or_file_to_base64(img)
+        expected = client_utils.encode_url_or_file_to_base64(img)
         assert gr.Image().serialize(img) == expected

     def test_in_interface_as_input(self):
@@ -820,7 +821,7 @@ def test_component_functions(self):
             gr.Audio(type="unknown")

         # Output functionalities
-        y_audio = gr.processing_utils.decode_base64_to_file(
+        y_audio = client_utils.decode_base64_to_file(
             deepcopy(media_data.BASE64_AUDIO)["data"]
         )
         audio_output = gr.Audio(type="filepath")
@@ -879,7 +880,7 @@ def reverse_audio(audio):
         iface = gr.Interface(reverse_audio, "audio", "audio")
         reversed_file = iface("test/test_files/audio_sample.wav")
         reversed_reversed_file = iface(reversed_file)
-        reversed_reversed_data = gr.processing_utils.encode_url_or_file_to_base64(
+        reversed_reversed_data = client_utils.encode_url_or_file_to_base64(
             reversed_reversed_file
         )
         similarity = SequenceMatcher(
@@ -1985,10 +1986,8 @@ def test_gallery(self, mock_uuid):
         gallery = gr.Gallery()
         test_file_dir = Path(Path(__file__).parent, "test_files")
         data = [
-            gr.processing_utils.encode_file_to_base64(Path(test_file_dir, "bus.png")),
-            gr.processing_utils.encode_file_to_base64(
-                Path(test_file_dir, "cheetah1.jpg")
-            ),
+            client_utils.encode_file_to_base64(Path(test_file_dir, "bus.png")),
+            client_utils.encode_file_to_base64(Path(test_file_dir, "cheetah1.jpg")),
         ]

         with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/test/test_external.py b/test/test_external.py
index a3d80813d2888..6eae11f0ff9c4 100644
--- a/test/test_external.py
+++ b/test/test_external.py
@@ -1,6 +1,5 @@
 import json
 import os
-import sys
 import textwrap
 import warnings
 from pathlib import Path
@@ -9,25 +8,18 @@
 import pytest
 from fastapi.testclient import TestClient

-import gradio
 import gradio as gr
 from gradio import media_data
 from gradio.context import Context
 from gradio.exceptions import InvalidApiName
-from gradio.external import (
-    TooManyRequestsError,
-    cols_to_rows,
-    get_tabular_examples,
-    use_websocket,
-)
-from gradio.external_utils import get_pred_from_ws
+from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples

 """
 WARNING: These tests have an external dependency: namely that Hugging Face's
 Hub and Space APIs do not change, and they keep their most famous models up.
 So if, e.g. Spaces is down, then these tests will not pass.

-These tests actually test gr.Interface.load() and gr.Blocks.load() but are
+These tests actually test gr.load() and gr.Blocks.load() but are
 included in a separate file because of the above-mentioned dependency.
 """

@@ -40,7 +32,7 @@ class TestLoadInterface:
     def test_audio_to_audio(self):
         model_type = "audio-to-audio"
-        interface = gr.Interface.load(
+        interface = gr.load(
             name="speechbrain/mtl-mimic-voicebank",
             src="models",
             alias=model_type,
@@ -62,7 +54,7 @@ def test_question_answering(self):

     def test_text_generation(self):
         model_type = "text_generation"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/gpt2", alias=model_type, description="This is a test description"
         )
         assert interface.__name__ == model_type
@@ -75,7 +67,7 @@ def test_text_generation(self):

     def test_summarization(self):
         model_type = "summarization"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/facebook/bart-large-cnn", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -84,7 +76,7 @@ def test_summarization(self):

     def test_translation(self):
         model_type = "translation"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/facebook/bart-large-cnn", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -93,7 +85,7 @@ def test_translation(self):

     def test_text2text_generation(self):
         model_type = "text2text-generation"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/sshleifer/tiny-mbart", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -102,7 +94,7 @@ def test_text2text_generation(self):

     def test_text_classification(self):
         model_type = "text-classification"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/distilbert-base-uncased-finetuned-sst-2-english",
             api_key=None,
             alias=model_type,
@@ -113,16 +105,14 @@ def test_text_classification(self):

     def test_fill_mask(self):
         model_type = "fill-mask"
-        interface = gr.Interface.load(
-            "models/bert-base-uncased", api_key=None, alias=model_type
-        )
+        interface = gr.load("models/bert-base-uncased", api_key=None, alias=model_type)
         assert interface.__name__ == model_type
         assert isinstance(interface.input_components[0], gr.Textbox)
         assert isinstance(interface.output_components[0], gr.Label)

     def test_zero_shot_classification(self):
         model_type = "zero-shot-classification"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/facebook/bart-large-mnli", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -133,7 +123,7 @@ def test_zero_shot_classification(self):

     def test_automatic_speech_recognition(self):
         model_type = "automatic-speech-recognition"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -142,7 +132,7 @@ def test_automatic_speech_recognition(self):

     def test_image_classification(self):
         model_type = "image-classification"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/google/vit-base-patch16-224", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -151,7 +141,7 @@ def test_image_classification(self):

     def test_feature_extraction(self):
         model_type = "feature-extraction"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/sentence-transformers/distilbert-base-nli-mean-tokens",
             api_key=None,
             alias=model_type,
@@ -162,7 +152,7 @@ def test_feature_extraction(self):

     def test_sentence_similarity(self):
         model_type = "text-to-speech"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
             api_key=None,
             alias=model_type,
@@ -173,7 +163,7 @@ def test_sentence_similarity(self):

     def test_text_to_speech(self):
         model_type = "text-to-speech"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
             api_key=None,
             alias=model_type,
@@ -184,7 +174,7 @@ def test_text_to_speech(self):

     def test_text_to_image(self):
         model_type = "text-to-image"
-        interface = gr.Interface.load(
+        interface = gr.load(
             "models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
         )
         assert interface.__name__ == model_type
@@ -193,12 +183,12 @@ def test_text_to_image(self):

     def test_english_to_spanish(self):
         with pytest.warns(UserWarning):
-            io = gr.Interface.load("spaces/abidlabs/english_to_spanish", title="hi")
+            io = gr.load("spaces/abidlabs/english_to_spanish", title="hi")
         assert isinstance(io.input_components[0], gr.Textbox)
         assert isinstance(io.output_components[0], gr.Textbox)

     def test_sentiment_model(self):
-        io = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")
+        io = gr.load("models/distilbert-base-uncased-finetuned-sst-2-english")
         try:
             output = io("I am happy, I love you")
             assert json.load(open(output))["label"] == "POSITIVE"
@@ -222,7 +212,7 @@ def test_translation_model(self):
         pass

     def test_numerical_to_label_space(self):
-        io = gr.Interface.load("spaces/abidlabs/titanic-survival")
+        io = gr.load("spaces/abidlabs/titanic-survival")
         try:
             output = io("male", 77, 10)
             assert json.load(open(output))["label"] == "Perishes"
@@ -230,7 +220,7 @@ def test_numerical_to_label_space(self):
         pass

     def test_image_to_text(self):
-        io = gr.Interface.load("models/nlpconnect/vit-gpt2-image-captioning")
+        io = gr.load("models/nlpconnect/vit-gpt2-image-captioning")
         try:
             output = io("gradio/test_data/lion.jpg")
             assert isinstance(output, str)
@@ -238,7 +228,7 @@ def test_image_to_text(self):
         pass

     def test_conversational(self):
-        io = gr.Interface.load("models/microsoft/DialoGPT-medium")
+        io = gr.load("models/microsoft/DialoGPT-medium")
         app, _, _ = io.launch(prevent_thread_lock=True)
         client = TestClient(app)
         assert app.state_holder == {}
@@ -252,7 +242,7 @@ def test_conversational(self):
         assert isinstance(app.state_holder["foo"], dict)

     def test_speech_recognition_model(self):
-        io = gr.Interface.load("models/facebook/wav2vec2-base-960h")
+        io = gr.load("models/facebook/wav2vec2-base-960h")
         try:
             output = io("gradio/test_data/test_audio.wav")
             assert output is not None
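Nearly every hunk in this test file makes the same mechanical swap from the `gr.Interface.load` classmethod to the new top-level `gr.load` entry point. A minimal sketch of the call shape the updated tests rely on (the Space name and inputs are borrowed from `test_numerical_to_label_space` above; actual outputs depend on the loaded app):

```python
import gradio as gr

# gr.load(...) replaces gr.Interface.load(...); it accepts the same
# "models/..." or "spaces/..." names and returns a callable demo.
io = gr.load("spaces/abidlabs/titanic-survival")
print(io("male", 77, 10))  # invoked like a plain function, as the tests do
```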
@@ -281,7 +271,7 @@ def test_speech_recognition_model(self):
         io.close()

     def test_text_to_image_model(self):
-        io = gr.Interface.load("models/osanseviero/BigGAN-deep-128")
+        io = gr.load("models/osanseviero/BigGAN-deep-128")
         try:
             filename = io("chest")
             assert filename.endswith(".jpg") or filename.endswith(".jpeg")
@@ -290,9 +280,7 @@ def test_private_space(self):
         api_key = "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"  # Intentionally revealing this key for testing purposes
-        io = gr.Interface.load(
-            "spaces/gradio-tests/not-actually-private-space", api_key=api_key
-        )
+        io = gr.load("spaces/gradio-tests/not-actually-private-space", api_key=api_key)
         try:
             output = io("abc")
             assert output == "abc"
@@ -301,7 +289,7 @@ def test_private_space(self):

     def test_private_space_audio(self):
         api_key = "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"  # Intentionally revealing this key for testing purposes
-        io = gr.Interface.load(
+        io = gr.load(
             "spaces/gradio-tests/not-actually-private-space-audio", api_key=api_key
         )
         try:
@@ -313,17 +301,15 @@ def test_private_space_audio(self):

     def test_multiple_spaces_one_private(self):
         api_key = "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"  # Intentionally revealing this key for testing purposes
         with gr.Blocks():
-            gr.Interface.load(
-                "spaces/gradio-tests/not-actually-private-space", api_key=api_key
-            )
-            gr.Interface.load(
+            gr.load("spaces/gradio-tests/not-actually-private-space", api_key=api_key)
+            gr.load(
                 "spaces/gradio/test-loading-examples",
             )
-        assert Context.access_token == api_key
+        assert Context.hf_token == api_key

     def test_loading_files_via_proxy_works(self):
         api_key = "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"  # Intentionally revealing this key for testing purposes
-        io = gr.Interface.load(
+        io = gr.load(
             "spaces/gradio-tests/test-loading-examples-private", api_key=api_key
         )
         app, _, _ = io.launch(prevent_thread_lock=True)
@@ -338,7 +324,7 @@ class TestLoadInterfaceWithExamples:
     def test_interface_load_examples(self, tmp_path):
         test_file_dir = Path(Path(__file__).parent, "test_files")
         with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
-            gr.Interface.load(
+            gr.load(
                 name="models/google/vit-base-patch16-224",
                 examples=[Path(test_file_dir, "cheetah1.jpg")],
                 cache_examples=False,
@@ -347,14 +333,14 @@ def test_interface_load_examples(self, tmp_path):
     def test_interface_load_cache_examples(self, tmp_path):
         test_file_dir = Path(Path(__file__).parent, "test_files")
         with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
-            gr.Interface.load(
+            gr.load(
                 name="models/google/vit-base-patch16-224",
                 examples=[Path(test_file_dir, "cheetah1.jpg")],
                 cache_examples=True,
             )

     def test_root_url(self):
-        demo = gr.Interface.load("spaces/gradio/test-loading-examples")
+        demo = gr.load("spaces/gradio/test-loading-examples")
         assert all(
             [
                 c["props"]["root_url"]
@@ -364,17 +350,17 @@ def test_root_url(self):
         )

     def test_root_url_deserialization(self):
-        demo = gr.Interface.load("spaces/gradio/simple_gallery")
+        demo = gr.load("spaces/gradio/simple_gallery")
         path_to_files = demo("test")
         assert (Path(path_to_files) / "captions.json").exists()

     def test_interface_with_examples(self):
         # This demo has the "fake_event" correctly removed
-        demo = gr.Interface.load("spaces/freddyaboulton/calculator")
+        demo = gr.load("spaces/freddyaboulton/calculator")
         assert demo(2, "add", 3) == 5

         # This demo still has the "fake_event". Both should work.
-        demo = gr.Interface.load("spaces/abidlabs/test-calculator-2")
+        demo = gr.load("spaces/abidlabs/test-calculator-2")
         assert demo(2, "add", 4) == 6
@@ -445,13 +431,13 @@ def check_dataset(config, readme_examples):

 def test_load_blocks_with_default_values():
-    io = gr.Interface.load("spaces/abidlabs/min-dalle")
+    io = gr.load("spaces/abidlabs/min-dalle")
     assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)

-    io = gr.Interface.load("spaces/abidlabs/min-dalle-later")
+    io = gr.load("spaces/abidlabs/min-dalle-later")
     assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)

-    io = gr.Interface.load("spaces/freddyaboulton/dataframe_load")
+    io = gr.load("spaces/freddyaboulton/dataframe_load")
     assert io.get_config_file()["components"][0]["props"]["value"] == {
         "headers": ["a", "b"],
         "data": [[1, 4], [2, 5], [3, 6]],
@@ -472,76 +458,11 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
     with patch(
         "gradio.external.get_tabular_examples", return_value=hypothetical_readme
     ):
-        io = gr.Interface.load("models/scikit-learn/tabular-playground")
+        io = gr.load("models/scikit-learn/tabular-playground")
     check_dataframe(io.config)
     check_dataset(io.config, hypothetical_readme)


-@pytest.mark.parametrize(
-    "config, dependency, answer",
-    [
-        ({"version": "3.3", "enable_queue": True}, {"queue": True}, True),
-        ({"version": "3.3", "enable_queue": False}, {"queue": None}, False),
-        ({"version": "3.3", "enable_queue": True}, {"queue": None}, True),
-        ({"version": "3.3", "enable_queue": True}, {"queue": False}, False),
-        ({"enable_queue": True}, {"queue": False}, False),
-        ({"version": "3.2", "enable_queue": False}, {"queue": None}, False),
-        ({"version": "3.2", "enable_queue": True}, {"queue": None}, True),
-        ({"version": "3.2", "enable_queue": True}, {"queue": False}, False),
-        ({"version": "3.1.3", "enable_queue": True}, {"queue": None}, False),
-        ({"version": "3.1.3", "enable_queue": False}, {"queue": True}, False),
-    ],
-)
-def test_use_websocket_after_315(config, dependency, answer):
-    assert use_websocket(config, dependency) == answer
-
-
-class AsyncMock(MagicMock):
-    async def __call__(self, *args, **kwargs):
-        return super(AsyncMock, self).__call__(*args, **kwargs)
-
-
-@pytest.mark.asyncio
-async def test_get_pred_from_ws():
-    mock_ws = AsyncMock(name="ws")
-    messages = [
-        json.dumps({"msg": "estimation"}),
-        json.dumps({"msg": "send_data"}),
-        json.dumps({"msg": "process_generating"}),
-        json.dumps({"msg": "process_completed", "output": {"data": ["result!"]}}),
-    ]
-    mock_ws.recv.side_effect = messages
-    data = json.dumps({"data": ["foo"], "fn_index": "foo"})
-    hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
-    output = await get_pred_from_ws(mock_ws, data, hash_data)
-    assert output == {"data": ["result!"]}
-    mock_ws.send.assert_called_once_with(data)
-
-
-@pytest.mark.asyncio
-async def test_get_pred_from_ws_raises_if_queue_full():
-    mock_ws = AsyncMock(name="ws")
-    messages = [json.dumps({"msg": "queue_full"})]
-    mock_ws.recv.side_effect = messages
-    data = json.dumps({"data": ["foo"], "fn_index": "foo"})
-    hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
-    with pytest.raises(gradio.Error, match="Queue is full!"):
-        await get_pred_from_ws(mock_ws, data, hash_data)
-
-
-@pytest.mark.skipif(
-    sys.version_info < (3, 8),
-    reason="Mocks of async context manager don't work for 3.7",
-)
-def test_respect_queue_when_load_from_config():
-    with patch("websockets.connect"):
patch("websockets.connect"): - with patch( - "gradio.external_utils.get_pred_from_ws", return_value={"data": ["foo"]} - ): - interface = gr.Interface.load("spaces/freddyaboulton/saymyname") - assert interface("bob") == "foo" - - def test_raise_value_error_when_api_name_invalid(): with pytest.raises(InvalidApiName): demo = gr.Blocks.load(name="spaces/gradio/hello_world") diff --git a/test/test_mix.py b/test/test_mix.py index d58de1ece96c3..5729f46a4d8bb 100644 --- a/test/test_mix.py +++ b/test/test_mix.py @@ -24,8 +24,8 @@ def test_in_interface(self): @pytest.mark.flaky def test_with_external(self): - io1 = gr.Interface.load("spaces/abidlabs/image-identity") - io2 = gr.Interface.load("spaces/abidlabs/image-classifier") + io1 = gr.load("spaces/abidlabs/image-identity") + io2 = gr.load("spaces/abidlabs/image-classifier") series = mix.Series(io1, io2) try: output = series("gradio/test_data/lion.jpg") @@ -55,8 +55,8 @@ def test_multiple_return_in_interface(self): @pytest.mark.flaky def test_with_external(self): - io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish") - io2 = gr.Interface.load("spaces/abidlabs/english2german") + io1 = gr.load("spaces/abidlabs/english_to_spanish") + io2 = gr.load("spaces/abidlabs/english2german") parallel = mix.Parallel(io1, io2) try: hello_es, hello_de = parallel("Hello") diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 6e1f6a8367001..8c87e0e17a6ed 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -23,25 +23,6 @@ def test_decode_base64_to_image(self): ) assert isinstance(output_image, Image.Image) - def test_encode_url_or_file_to_base64(self): - output_base64 = processing_utils.encode_url_or_file_to_base64( - "gradio/test_data/test_image.png" - ) - assert output_base64 == deepcopy(media_data.BASE64_IMAGE) - - def test_encode_file_to_base64(self): - output_base64 = processing_utils.encode_file_to_base64( - "gradio/test_data/test_image.png" - ) - assert output_base64 == deepcopy(media_data.BASE64_IMAGE) - - @pytest.mark.flaky - def test_encode_url_to_base64(self): - output_base64 = processing_utils.encode_url_to_base64( - "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png" - ) - assert output_base64 == deepcopy(media_data.BASE64_IMAGE) - def test_encode_plot_to_base64(self): plt.plot([1, 2, 3, 4]) output_base64 = processing_utils.encode_plot_to_base64(plt) @@ -135,18 +116,6 @@ def test_convert_to_16_bit_wav(self): class TestOutputPreprocessing: - def test_decode_base64_to_binary(self): - binary = processing_utils.decode_base64_to_binary( - deepcopy(media_data.BASE64_IMAGE) - ) - assert deepcopy(media_data.BINARY_IMAGE) == binary - - def test_decode_base64_to_file(self): - temp_file = processing_utils.decode_base64_to_file( - deepcopy(media_data.BASE64_IMAGE) - ) - assert isinstance(temp_file, tempfile._TemporaryFileWrapper) - float_dtype_list = [ float, float, @@ -245,12 +214,3 @@ def test_video_conversion_returns_original_video_if_fails( ) # If the conversion succeeded it'd be .mp4 assert Path(playable_vid).suffix == ".avi" - - -def test_download_private_file(): - url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg" - access_token = "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" # Intentionally revealing this key for testing purposes - file = processing_utils.download_tmp_copy_of_file( - url_path=url_path, access_token=access_token - ) - assert file.name.endswith(".jpg") diff --git a/test/test_utils.py b/test/test_utils.py 
index da67d26fadaf3..7aef377cd5682 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -37,7 +36,6 @@
     sagemaker_check,
     sanitize_list_for_csv,
     sanitize_value_for_csv,
-    strip_invalid_filename_characters,
     validate_url,
     version_check,
 )
@@ -577,24 +576,6 @@ def test_later_suffix(self):
         assert append_unique_suffix(name, list_of_names) == "test_4"


-@pytest.mark.parametrize(
-    "orig_filename, new_filename",
-    [
-        ("abc", "abc"),
-        ("$$AAabc&3", "AAabc3"),
-        ("$$AAabc&3", "AAabc3"),
-        ("$$AAa..b-c&3_", "AAa..b-c3_"),
-        ("$$AAa..b-c&3_", "AAa..b-c3_"),
-        (
-            "ゆかりです。私、こんなかわいい服は初めて着ました…。なんだかうれしくって、楽しいです。歌いたくなる気分って、初めてです。これがアイドルってことなのかもしれませんね",
-            "ゆかりです私こんなかわいい服は初めて着ましたなんだかうれしくって楽しいです歌いたくなる気分って初めてですこれがアイドルってことなの",
-        ),
-    ],
-)
-def test_strip_invalid_filename_characters(orig_filename, new_filename):
-    assert strip_invalid_filename_characters(orig_filename) == new_filename
-
-
 class TestAbspath:
     def test_abspath_no_symlink(self):
         resolved_path = str(abspath("../gradio/gradio/test_data/lion.jpg"))
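The deletions in the last few test files track helpers that moved out of `gradio.processing_utils` and `gradio.utils` into the client package; `test_components.py` above now imports them as `client_utils`. A minimal sketch of the relocated API, using the image path from the removed tests (return types are inferred from the old assertions):

```python
from gradio_client import utils as client_utils

# The base64 encode/decode helpers now live in gradio_client.utils.
b64 = client_utils.encode_url_or_file_to_base64("gradio/test_data/test_image.png")

# Per the removed test, decode_base64_to_file returns a temporary file wrapper.
tmp = client_utils.decode_base64_to_file(b64)
print(tmp.name)
```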