How data gets processed in a Gradio Interface
Abubakar Abid edited this page Mar 16, 2022
1. Standard data flow (when a user provides a prediction `fn`, `inputs`, and `outputs` to construct an `Interface`)
Let's take this image classification Space (abidlabs/pytorch-image-classifier) as an example:
```python
import requests
from torchvision import transforms
import torch
import gradio as gr

# Load a pretrained ResNet-18 from PyTorch Hub and switch it to inference mode.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def predict(inp):
    # `inp` arrives as a PIL image because the input component uses type="pil".
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
    # A {label: confidence} dict is what the Label output component expects.
    return confidences

gr.Interface(fn=predict,
             inputs=gr.inputs.Image(type="pil"),
             outputs=gr.outputs.Label(num_top_classes=3)).launch()
```
The data flow is as follows:

- When a user uploads an image and clicks submit, the image is serialized into base64 format so that it can be sent to `/api/predict` on the server where the Gradio application is running.
- Based on the `type` of the input component, Gradio automatically applies preprocessing steps that convert the input image into the format the user's prediction `fn` expects. In this case, the base64 image is converted to a `PIL` image.
- The image is then run through `fn`, which in this case returns a dictionary mapping labels to confidences.
- The labels are postprocessed into the appropriate format based on the parameters of the `Label` component the user specifies. In this case, postprocessing identifies the 3 labels with the highest confidence. Depending on the output component (e.g. an `Image` output), the result may also need to be serialized.
- Finally, the serialized output is sent to the front end and displayed.
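To make the round trip above concrete, here is a minimal, standard-library-only sketch of the same pipeline. It is not Gradio's implementation: `frontend_serialize`, `preprocess_image`, `postprocess_label`, `handle_predict` and `fake_predict` are hypothetical names, the `data:image/png;base64,` prefix is an assumption about what the front end sends, raw bytes stand in for the `PIL` image, and the shape of the postprocessed label dict is only assumed to resemble what the `Label` component produced.

```python
import base64
import json


# Hypothetical helpers that mimic each stage of the standard data flow.
# None of these names are part of Gradio's API.

def frontend_serialize(image_bytes: bytes) -> str:
    """Front end: base64-encode the upload as a data URL inside a JSON payload."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return json.dumps({"data": ["data:image/png;base64," + encoded]})


def preprocess_image(data_url: str) -> bytes:
    """Server: strip the data-URL header and decode back to raw bytes.

    Gradio would additionally build a PIL image here when type="pil";
    raw bytes stand in for it so the sketch needs no third-party packages.
    """
    _header, encoded = data_url.split(",", 1)
    return base64.b64decode(encoded)


def postprocess_label(confidences: dict, num_top_classes: int = 3) -> dict:
    """Server: keep the num_top_classes highest-confidence labels, best first."""
    top = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)
    top = top[:num_top_classes]
    return {
        "label": top[0][0],
        "confidences": [{"label": k, "confidence": v} for k, v in top],
    }


def handle_predict(payload: str, fn) -> str:
    """Simulated /api/predict: deserialize, preprocess, run fn, postprocess, serialize."""
    request = json.loads(payload)
    inp = preprocess_image(request["data"][0])
    output = postprocess_label(fn(inp))
    return json.dumps({"data": [output]})


def fake_predict(raw: bytes) -> dict:
    """Hypothetical model: ignores its input and returns fixed scores."""
    return {"cat": 0.7, "dog": 0.2, "fox": 0.06, "owl": 0.04}


response = handle_predict(frontend_serialize(b"\x89PNG fake bytes"), fake_predict)
print(json.loads(response)["data"][0]["label"])  # cat
```

Keeping each stage in its own function mirrors the list: swapping `fake_predict` for a real model only changes the middle step, while serialization and postprocessing stay the same.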
2. Data flow when loading a model from the Hugging Face Hub

This special case applies when building a Gradio demo on top of a Space or model on the Hub that is queried through the Inference API endpoint. Take this image classification Space (abidlabs/vision-transformer) as an example:
```python
import gradio as gr
gr.Interface.load("huggingface/google/vit-base-patch16-224").launch()
```
The data flow is as follows:

- When a user uploads an image and clicks submit, the image is serialized into base64 format so that it can be sent to `/api/predict` on the server where the Gradio application is running.
- Based on the type of the inference API endpoint, Gradio automatically applies preprocessing steps to convert the input image into a particular format. In this case, the base64 image is converted to an RGB image file and the file path is returned.
- Based on the type of input component, Gradio serializes the data again, converting it into a format that can be sent to the API endpoint on the Hugging Face Hub.
- The image is then sent to the inference API endpoint, which returns a response.
- Gradio deserializes the response data to get a dictionary of labels.
- Based on the type of the inference API endpoint, Gradio automatically applies postprocessing steps to convert the dictionary of labels into a particular format. In this case, the label with the highest confidence is extracted and passed along with the rest of the dictionary of labels.
- Finally, this output is sent to the front end and displayed.
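The Hub variant adds an extra serialize/deserialize hop around the remote call. The sketch below walks through those steps using only the standard library, with a fake endpoint in place of the real Inference API so it runs offline. It is not Gradio's code: every function name is hypothetical, the `[{"label": ..., "score": ...}]` response shape is an assumption modeled on the Inference API's image classification output, and the RGB conversion step is skipped.

```python
import base64
import json
import os
import tempfile


def preprocess_to_file(data_url: str) -> str:
    """Decode the base64 upload, write it to a temporary file, return the path.

    (Gradio also converted the image to RGB at this step; omitted here.)
    """
    _header, encoded = data_url.split(",", 1)
    fd, path = tempfile.mkstemp(suffix=".png")
    with os.fdopen(fd, "wb") as f:
        f.write(base64.b64decode(encoded))
    return path


def serialize_for_api(path: str) -> bytes:
    """Serialize again: read the file back as raw bytes for the request body."""
    with open(path, "rb") as f:
        return f.read()


def fake_inference_api(body: bytes) -> bytes:
    """Hypothetical stand-in for the Hub endpoint; makes no network call.

    Returns JSON shaped like an image classification response (assumed).
    """
    return json.dumps([
        {"label": "tabby cat", "score": 0.81},
        {"label": "tiger cat", "score": 0.12},
        {"label": "Egyptian cat", "score": 0.05},
    ]).encode()


def deserialize_response(raw: bytes) -> dict:
    """Turn the API's list of {label, score} objects into a {label: score} dict."""
    return {item["label"]: item["score"] for item in json.loads(raw)}


def postprocess_label(scores: dict) -> dict:
    """Extract the top label and pass it along with the full dict of scores."""
    top = max(scores, key=scores.get)
    return {"label": top, "confidences": scores}


upload = "data:image/png;base64," + base64.b64encode(b"fake image bytes").decode("ascii")
path = preprocess_to_file(upload)
try:
    api_response = fake_inference_api(serialize_for_api(path))
finally:
    os.remove(path)  # clean up the temporary file
result = postprocess_label(deserialize_response(api_response))
print(result["label"])  # tabby cat
```

Writing to a temporary file before re-reading it looks redundant in isolation, but it matches the flow described above, where preprocessing hands back a file path and a separate serialization step prepares the request body.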