Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to add a custom API route #10332

Merged
merged 13 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/sour-apples-boil.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"@gradio/client": patch
"@gradio/core": patch
"gradio": patch
"gradio_client": patch
---

fix:Allow users to add a custom API route
1 change: 1 addition & 0 deletions client/js/src/helpers/api_info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ export function get_type(
serializer: string,
signature_type: "return" | "parameter"
): string | undefined {
if (component === "Api") return type.type;
switch (type?.type) {
case "string":
return "string";
Expand Down
100 changes: 99 additions & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import concurrent.futures
import copy
import inspect
import json
import mimetypes
import os
Expand All @@ -19,7 +20,17 @@
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict
from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
TypedDict,
Union,
get_args,
get_origin,
get_type_hints,
)

import fsspec.asyn
import httpx
Expand Down Expand Up @@ -994,6 +1005,93 @@ def get_desc(v):
raise APIInfoParseError(f"Cannot parse schema {schema}")


def python_type_to_json_schema(type_hint: Any) -> dict:
try:
return _python_type_to_json_schema(type_hint)
except Exception:
return {}


def _python_type_to_json_schema(type_hint: Any) -> dict:
"""Convert a Python type hint to a JSON schema."""
if type_hint is type(None):
return {"type": "null"}
if type_hint is str:
return {"type": "string"}
if type_hint is int:
return {"type": "integer"}
if type_hint is float:
return {"type": "number"}
if type_hint is bool:
return {"type": "boolean"}

origin = get_origin(type_hint)

if origin is Literal:
literal_values = get_args(type_hint)
if len(literal_values) == 1:
return {"const": literal_values[0]}
return {"enum": list(literal_values)}

if origin is Union or str(origin) == "|":
types = get_args(type_hint)
if len(types) == 2 and type(None) in types:
other_type = next(t for t in types if t is not type(None))
schema = _python_type_to_json_schema(other_type)
if "type" in schema:
schema["type"] = [schema["type"], "null"]
else:
schema["oneOf"] = [{"type": "null"}, schema]
return schema
return {"anyOf": [_python_type_to_json_schema(t) for t in types]}

if origin is list:
item_type = get_args(type_hint)[0]
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
if origin is tuple:
types = get_args(type_hint)
return {
"type": "array",
"prefixItems": [_python_type_to_json_schema(t) for t in types],
"minItems": len(types),
"maxItems": len(types),
}

if origin is dict:
key_type, value_type = get_args(type_hint)
if key_type is not str:
raise ValueError("JSON Schema only supports string keys in objects")
schema = {
"type": "object",
"additionalProperties": _python_type_to_json_schema(value_type),
}
return schema

if inspect.isclass(type_hint) and hasattr(type_hint, "__annotations__"):
properties = {}
required = []

hints = get_type_hints(type_hint)
for field_name, field_type in hints.items():
properties[field_name] = _python_type_to_json_schema(field_type)
if hasattr(type_hint, "__total__"):
if type_hint.__total__:
required.append(field_name)
elif (
not hasattr(type_hint, "__dataclass_fields__")
or not type_hint.__dataclass_fields__[field_name].default
):
required.append(field_name)

schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema

if type_hint is Any:
return {}


def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
"""
Traverse a JSON object and apply a function to each element that satisfies the is_root condition.
Expand Down
1 change: 1 addition & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
RetryData,
SelectData,
UndoData,
api,
on,
)
from gradio.exceptions import Error
Expand Down
11 changes: 6 additions & 5 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,11 +779,12 @@ def set_event_trigger(
if fn is not None and not cancels:
check_function_inputs_match(fn, inputs, inputs_as_dict)

if _targets[0][1] in ["change", "key_up"] and trigger_mode is None:
trigger_mode = "always_last"
elif _targets[0][1] in ["stream"] and trigger_mode is None:
trigger_mode = "multiple"
elif trigger_mode is None:
if len(_targets) and trigger_mode is None:
if _targets[0][1] in ["change", "key_up"]:
trigger_mode = "always_last"
elif _targets[0][1] in ["stream"]:
trigger_mode = "multiple"
if trigger_mode is None:
trigger_mode = "once"
elif trigger_mode not in ["once", "multiple", "always_last"]:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion gradio/component_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def {{ event.event_name }}(self,
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: list of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: list of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
scroll_to_output: if True, will scroll to output component on completion
show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
queue: if True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
Expand Down
46 changes: 46 additions & 0 deletions gradio/components/api_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""gr.Api() component."""

from __future__ import annotations

from typing import Any

from gradio.components.base import Component


class Api(Component):
"""
A generic component that holds any value. Used for generating APIs with no actual frontend component.
"""

EVENTS = []

def __init__(
self,
value: Any,
_api_info: dict[str, str],
label: str = "API",
):
"""
Parameters:
value: default value.
"""
self._api_info = _api_info
super().__init__(value=value, label=label)

def preprocess(self, payload: Any) -> Any:
return payload

def postprocess(self, value: Any) -> Any:
return value

def api_info(self) -> dict[str, str]:
return self._api_info

def example_payload(self) -> Any:
return self.value if self.value is not None else "..."

def example_value(self) -> Any:
return self.value if self.value is not None else "..."

# def get_block_name(self) -> str:
# return "state" # so that it does not render in the frontend, just like state
126 changes: 124 additions & 2 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
from gradio.blocks import Block, BlockContext, Component
from gradio.components import Timer

from gradio_client.utils import python_type_to_json_schema

from gradio.context import get_blocks_context
from gradio.utils import get_cancelled_fn_indices
from gradio.utils import get_cancelled_fn_indices, get_function_params, get_return_types


def set_cancel_events(
Expand Down Expand Up @@ -760,7 +762,7 @@ def on(
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
scroll_to_output: If True, will scroll to output component on completion
show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
Expand Down Expand Up @@ -886,6 +888,126 @@ def inner(*args, **kwargs):
return Dependency(None, dep.get_config(), dep_index, fn)


@document()
def api(
fn: Callable | Literal["decorator"] = "decorator",
*,
api_name: str | None | Literal[False] = None,
queue: bool = True,
batch: bool = False,
max_batch_size: int = 4,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
show_api: bool = True,
time_limit: int | None = None,
stream_every: float = 0.5,
) -> Dependency:
"""
Sets up an API endpoint for a generic function that can be called via the gradio client. Derives its type from type-hints in the function signature.

Parameters:
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
show_api: whether to show this event in the "view API" page of the Gradio app, or in the ".view_api()" method of the Gradio clients. Unlike setting api_name to False, setting show_api to False will still allow downstream apps as well as the Clients to use this event. If fn is None, show_api will automatically be set to False.
time_limit: The time limit for the function to run. Parameter only used for the `.stream()` event.
stream_every: The latency (in seconds) at which stream chunks are sent to the backend. Defaults to 0.5 seconds. Parameter only used for the `.stream()` event.
Example:
import gradio as gr
with gr.Blocks() as demo:
with gr.Row():
input = gr.Textbox()
button = gr.Button("Submit")
output = gr.Textbox()
gr.on(
triggers=[button.click, input.submit],
fn=lambda x: x,
inputs=[input],
outputs=[output]
)
demo.launch()
"""
if fn == "decorator":

def wrapper(func):
api(
fn=func,
api_name=api_name,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
show_api=show_api,
time_limit=time_limit,
stream_every=stream_every,
)

@wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs)

return inner

return Dependency(None, {}, None, wrapper)

root_block = get_blocks_context()
if root_block is None:
raise Exception("Cannot call api() outside of a gradio.Blocks context.")

from gradio.components.api_component import Api

fn_params = get_function_params(fn)
return_types = get_return_types(fn)

def ordinal(n):
return f"{n}{'th' if 10 <= n % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')}"

if any(param[3] is None for param in fn_params):
raise ValueError(
"API endpoints must have type hints. Please specify a type hint for all parameters."
)
inputs = [
Api(
default_value if has_default else None,
python_type_to_json_schema(_type),
ordinal(i + 1),
)
for i, (_, has_default, default_value, _type) in enumerate(fn_params)
]
outputs = [
Api(None, python_type_to_json_schema(type), ordinal(i + 1))
for i, type in enumerate(return_types)
]

dep, dep_index = root_block.set_event_trigger(
[],
fn,
inputs,
outputs,
preprocess=False,
postprocess=False,
scroll_to_output=False,
show_progress="hidden",
api_name=api_name,
js=None,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
show_api=show_api,
trigger_mode=None,
time_limit=time_limit,
stream_every=stream_every,
)
return Dependency(None, dep.get_config(), dep_index, fn)


class Events:
change = EventListener(
"change",
Expand Down
Loading
Loading