Skip to content

Commit

Permalink
Fix some bugs related to Python client (#3721)
Browse files Browse the repository at this point in the history
* client format

* docs

* formatting

* fix tests

* fixed bug

* api endpoint changes

* fix tests

* fix tests

* formatting

* Add support for sessions [python client] (#3731)

* client

* add state and tests

* remove session param
  • Loading branch information
abidlabs authored Apr 2, 2023
1 parent da7d1df commit f46f5f9
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 66 deletions.
181 changes: 128 additions & 53 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from concurrent.futures import Future
from datetime import datetime
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Literal, Tuple

import huggingface_hub
import requests
Expand All @@ -32,7 +32,7 @@ def __init__(
"""
Parameters:
src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/pictionary") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
hf_token: The Hugging Face token to use to access private Spaces. If not provided, only public Spaces can be loaded.
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI.
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
"""
self.hf_token = hf_token
Expand All @@ -55,6 +55,7 @@ def __init__(
self.api_url = utils.API_URL.format(self.src)
self.ws_url = utils.WS_URL.format(self.src).replace("http", "ws", 1)
self.config = self._get_config()
self.session_hash = str(uuid.uuid4())

self.endpoints = [
Endpoint(self, fn_index, dependency)
Expand All @@ -71,25 +72,24 @@ def predict(
self,
*args,
api_name: str | None = None,
fn_index: int = 0,
fn_index: int | None = None,
result_callbacks: Callable | List[Callable] | None = None,
) -> Future:
"""
Parameters:
*args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app.
api_name: The name of the API endpoint to call. If not provided, the first API will be called. Takes precedence over fn_index.
fn_index: The index of the API endpoint to call. If not provided, the first API will be called.
api_name: The name of the API endpoint to call starting with a leading slash, e.g. "/predict". Does not need to be provided if the Gradio app has only one named API endpoint.
fn_index: The index of the API endpoint to call, e.g. 0. Both api_name and fn_index can be provided, but if they conflict, api_name will take precedence.
result_callbacks: A callback function, or list of callback functions, to be called when the result is ready. If a list of functions is provided, they will be called in order. The return values from the remote API are provided as separate parameters into the callback. If None, no callback will be called.
Returns:
A Job object that can be used to retrieve the status and result of the remote API call.
"""
if api_name:
fn_index = self._infer_fn_index(api_name)
inferred_fn_index = self._infer_fn_index(api_name, fn_index)

helper = None
if self.endpoints[fn_index].use_ws:
if self.endpoints[inferred_fn_index].use_ws:
helper = Communicator(Lock(), JobStatus())
end_to_end_fn = self.endpoints[fn_index].make_end_to_end_fn(helper)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)

job = Job(future, communicator=helper)
Expand All @@ -115,13 +115,15 @@ def fn(future):
def view_api(
self,
all_endpoints: bool | None = None,
return_info: bool = False,
) -> Dict | None:
print_info: bool = True,
return_format: Literal["dict", "str"] | None = None,
) -> Dict | str | None:
"""
Prints the usage info for the API. If the Gradio app has multiple API endpoints, the usage info for each endpoint will be printed separately.
Parameters:
all_endpoints: If True, prints information for both named and unnamed endpoints in the Gradio app. If False, will only print info about named endpoints. If None (default), will only print info about unnamed endpoints if there are no named endpoints.
return_info: If False (default), prints the usage info to the console. If True, returns the usage info as a dictionary that can be programmatically parsed (does not print), and *all endpoints are returned in the dictionary* regardless of the value of `all_endpoints`. The format of the dictionary is in the docstring of this method.
print_info: If True, prints the usage info to the console. If False, does not print the usage info.
return_format: If None, nothing is returned. If "str", returns the same string that would be printed to the console. If "dict", returns the usage info as a dictionary that can be programmatically parsed, and *all endpoints are returned in the dictionary* regardless of the value of `all_endpoints`. The format of the dictionary is in the docstring of this method.
Dictionary format:
{
"named_endpoints": {
Expand Down Expand Up @@ -153,9 +155,6 @@ def view_api(
else:
info["unnamed_endpoints"][endpoint.fn_index] = endpoint.get_info()

if return_info:
return info

num_named_endpoints = len(info["named_endpoints"])
num_unnamed_endpoints = len(info["unnamed_endpoints"])
if num_named_endpoints == 0 and all_endpoints is None:
Expand All @@ -175,7 +174,15 @@ def view_api(
if num_unnamed_endpoints > 0:
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(`all_endpoints=True`)\n"

print(human_info)
if print_info:
print(human_info)
if return_format == "str":
return human_info
elif return_format == "dict":
return info

def reset_session(self) -> None:
self.session_hash = str(uuid.uuid4())

def _render_endpoints_info(
self,
Expand All @@ -199,22 +206,26 @@ def _render_endpoints_info(
raise ValueError("name_or_index must be a string or integer")

human_info = f"\n - predict({rendered_parameters}{final_param}) -> {rendered_return_values}\n"
human_info += " Parameters:\n"
if endpoints_info["parameters"]:
human_info += " Parameters:\n"
for label, info in endpoints_info["parameters"].items():
human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n"
for label, info in endpoints_info["parameters"].items():
human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n"
else:
human_info += " - None\n"
human_info += " Returns:\n"
if endpoints_info["returns"]:
human_info += " Returns:\n"
for label, info in endpoints_info["returns"].items():
human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n"
for label, info in endpoints_info["returns"].items():
human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n"
else:
human_info += " - None\n"

return human_info

def __repr__(self):
return self.view_api()
return self.view_api(print_info=False, return_format="str")

def __str__(self):
return self.view_api()
return self.view_api(print_info=False, return_format="str")

def _telemetry_thread(self) -> None:
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
Expand All @@ -231,11 +242,34 @@ def _telemetry_thread(self) -> None:
except Exception:
pass

def _infer_fn_index(self, api_name: str) -> int:
for i, d in enumerate(self.config["dependencies"]):
if d.get("api_name") == api_name:
return i
raise ValueError(f"Cannot find a function with api_name: {api_name}")
def _infer_fn_index(self, api_name: str | None, fn_index: int | None) -> int:
inferred_fn_index = None
if api_name is not None:
for i, d in enumerate(self.config["dependencies"]):
config_api_name = d.get("api_name")
if config_api_name is None:
continue
if "/" + config_api_name == api_name:
inferred_fn_index = i
break
else:
error_message = f"Cannot find a function with `api_name`: {api_name}."
if not api_name.startswith("/"):
error_message += " Did you mean to use a leading slash?"
raise ValueError(error_message)
elif fn_index is not None:
inferred_fn_index = fn_index
else:
valid_endpoints = [
e for e in self.endpoints if e.is_valid and e.api_name is not None
]
if len(valid_endpoints) == 1:
inferred_fn_index = valid_endpoints[0].fn_index
else:
raise ValueError(
"This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`"
)
return inferred_fn_index

def __del__(self):
if hasattr(self, "executor"):
Expand Down Expand Up @@ -264,15 +298,15 @@ class Endpoint:
"""Helper class for storing all the information about a single API endpoint."""

def __init__(self, client: Client, fn_index: int, dependency: Dict):
self.api_url = client.api_url
self.ws_url = client.ws_url
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
self.api_name: str | None = dependency.get("api_name")
self.headers = client.headers
self.config = client.config
if self.api_name:
self.api_name = "/" + self.api_name
self.use_ws = self._use_websocket(self.dependency)
self.hf_token = client.hf_token
self.input_component_types = []
self.output_component_types = []
try:
self.serializers, self.deserializers = self._setup_serializers()
self.is_valid = self.dependency[
Expand All @@ -298,7 +332,7 @@ def get_info(self) -> Dict[str, Dict[str, List[str]]]:
"""
parameters = {}
for i, input in enumerate(self.dependency["inputs"]):
for component in self.config["components"]:
for component in self.client.config["components"]:
if component["id"] == input:
label = (
component["props"]
Expand All @@ -311,11 +345,13 @@ def get_info(self) -> Dict[str, Dict[str, List[str]]]:
else:
info = self.serializers[i].input_api_info()
info = list(info)
info.append(component.get("type", "component").capitalize())
parameters[label] = info
component_type = component.get("type", "component").capitalize()
info.append(component_type)
if not component_type.lower() == utils.STATE_COMPONENT:
parameters[label] = info
returns = {}
for o, output in enumerate(self.dependency["outputs"]):
for component in self.config["components"]:
for component in self.client.config["components"]:
if component["id"] == output:
label = (
component["props"]
Expand All @@ -328,11 +364,19 @@ def get_info(self) -> Dict[str, Dict[str, List[str]]]:
else:
info = self.deserializers[o].output_api_info()
info = list(info)
info.append(component.get("type", "component").capitalize())
returns[label] = list(info)
component_type = component.get("type", "component").capitalize()
info.append(component_type)
if not component_type.lower() == utils.STATE_COMPONENT:
returns[label] = info

return {"parameters": parameters, "returns": returns}

def __repr__(self):
return json.dumps(self.get_info(), indent=4)

def __str__(self):
return json.dumps(self.get_info(), indent=4)

def make_end_to_end_fn(self, helper: Communicator | None = None):

_predict = self.make_predict(helper)
Expand All @@ -343,23 +387,44 @@ def _inner(*data):
inputs = self.serialize(*data)
predictions = _predict(*inputs)
outputs = self.deserialize(*predictions)
if len(self.dependency["outputs"]) == 1:
if (
len(
[
oct
for oct in self.output_component_types
if not oct == utils.STATE_COMPONENT
]
)
== 1
):
return outputs[0]
return outputs

return _inner

def make_predict(self, helper: Communicator | None = None):
def _predict(*data) -> Tuple:
data = json.dumps({"data": data, "fn_index": self.fn_index})
data = json.dumps(
{
"data": data,
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
hash_data = json.dumps(
{"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())}
{
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)

if self.use_ws:
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
output = result["data"]
else:
response = requests.post(self.api_url, headers=self.headers, data=data)
response = requests.post(
self.client.api_url, headers=self.client.headers, data=data
)
result = json.loads(response.content.decode("utf-8"))
try:
output = result["data"]
Expand All @@ -383,6 +448,11 @@ def _predict_resolve(self, *data) -> Any:
return outputs

def serialize(self, *data) -> Tuple:
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type == utils.STATE_COMPONENT:
data = list(data)
data.insert(i, None)
data = tuple(data)
assert len(data) == len(
self.serializers
), f"Expected {len(self.serializers)} arguments, got {len(data)}"
Expand All @@ -394,8 +464,11 @@ def deserialize(self, *data) -> Tuple:
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
return tuple(
[
s.deserialize(d, hf_token=self.hf_token)
for s, d in zip(self.deserializers, data)
s.deserialize(d, hf_token=self.client.hf_token)
for s, d, oct in zip(
self.deserializers, data, self.output_component_types
)
if not oct == utils.STATE_COMPONENT
]
)

Expand All @@ -404,16 +477,17 @@ def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
serializers = []

for i in inputs:
for component in self.config["components"]:
for component in self.client.config["components"]:
if component["id"] == i:
component_name = component["type"]
self.input_component_types.append(component_name)
if component.get("serializer"):
serializer_name = component["serializer"]
assert (
serializer_name in serializing.SERIALIZER_MAPPING
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
serializer = serializing.SERIALIZER_MAPPING[serializer_name]
else:
component_name = component["type"]
assert (
component_name in serializing.COMPONENT_MAPPING
), f"Unknown component: {component_name}, you may need to update your gradio_client version."
Expand All @@ -423,16 +497,17 @@ def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
outputs = self.dependency["outputs"]
deserializers = []
for i in outputs:
for component in self.config["components"]:
for component in self.client.config["components"]:
if component["id"] == i:
component_name = component["type"]
self.output_component_types.append(component_name)
if component.get("serializer"):
serializer_name = component["serializer"]
assert (
serializer_name in serializing.SERIALIZER_MAPPING
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
else:
component_name = component["type"]
assert (
component_name in serializing.COMPONENT_MAPPING
), f"Unknown component: {component_name}, you may need to update your gradio_client version."
Expand All @@ -442,16 +517,16 @@ def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
return serializers, deserializers

def _use_websocket(self, dependency: Dict) -> bool:
queue_enabled = self.config.get("enable_queue", False)
queue_enabled = self.client.config.get("enable_queue", False)
queue_uses_websocket = version.parse(
self.config.get("version", "2.0")
self.client.config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue

async def _ws_fn(self, data, hash_data, helper: Communicator):
async with websockets.connect( # type: ignore
self.ws_url, open_timeout=10, extra_headers=self.headers
self.client.ws_url, open_timeout=10, extra_headers=self.client.headers
) as websocket:
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)

Expand Down
1 change: 1 addition & 0 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

API_URL = "{}/api/predict/"
WS_URL = "{}/queue/join"
STATE_COMPONENT = "state"

__version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()

Expand Down
Loading

0 comments on commit f46f5f9

Please sign in to comment.