diff --git a/CHANGELOG.md b/CHANGELOG.md index c1cb61049dd82..e472419714c22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## New Features: -No changes to highlight. +* Make `Blocks.load` behave like other event listeners (allows chaining `then` off of it) [@anentropic](https://github.com/anentropic/) in [PR 4304](https://github.com/gradio-app/gradio/pull/4304) ## Bug Fixes: diff --git a/gradio/analytics.py b/gradio/analytics.py index c1c1a908bbb1e..c4b710aa26687 100644 --- a/gradio/analytics.py +++ b/gradio/analytics.py @@ -114,7 +114,10 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None: for x in blocks.dependencies: targets_telemetry = targets_telemetry + [ - str(blocks.blocks[y]) for y in x["targets"] + # Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks + str(blocks.blocks[y]) + for y in x["targets"] + if y in blocks.blocks ] inputs_telemetry = inputs_telemetry + [ str(blocks.blocks[y]) for y in x["inputs"] diff --git a/gradio/blocks.py b/gradio/blocks.py index 7b35cd846840f..9731fb739ddb6 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1482,7 +1482,9 @@ def get_time(): name=name, src=src, hf_token=api_key, alias=alias, **kwargs ) else: - return self_or_cls.set_event_trigger( + from gradio.events import Dependency + + dep, dep_index = self_or_cls.set_event_trigger( event_name="load", fn=fn, inputs=inputs, @@ -1498,7 +1500,8 @@ def get_time(): max_batch_size=max_batch_size, every=every, no_target=True, - )[0] + ) + return Dependency(self_or_cls, dep, dep_index) def clear(self): """Resets the layout of the Blocks object.""" diff --git a/gradio/routes.py b/gradio/routes.py index aad4270d16b8d..753c55c7c634b 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -411,7 +411,7 @@ async def run_predict( dependency = app.get_blocks().dependencies[fn_index_inferred] target = dependency["targets"][0] if len(dependency["targets"]) else None event_data = EventData( - app.get_blocks().blocks[target] if target else None, + app.get_blocks().blocks.get(target) if target else None, body.event_data, ) batch = dependency["batch"] diff --git a/test/test_events.py b/test/test_events.py index 79f8b63b87828..2c72af136cf67 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -1,8 +1,12 @@ +import os + import pytest from fastapi.testclient import TestClient import gradio as gr +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + class TestEvent: def test_clear_event(self): @@ -69,6 +73,47 @@ def clear(): assert not parent.config["dependencies"][2]["trigger_only_on_success"] assert parent.config["dependencies"][3]["trigger_only_on_success"] + def test_load_chaining(self): + calls = 0 + + def increment(): + nonlocal calls + calls += 1 + return str(calls) + + with gr.Blocks() as demo: + out = gr.Textbox(label="Call counter") + demo.load(increment, inputs=None, outputs=out).then( + increment, inputs=None, outputs=out + ) + + assert demo.config["dependencies"][0]["trigger"] == "load" + assert demo.config["dependencies"][0]["trigger_after"] is None + assert demo.config["dependencies"][1]["trigger"] == "then" + assert demo.config["dependencies"][1]["trigger_after"] == 0 + + def test_load_chaining_reuse(self): + calls = 0 + + def increment(): + nonlocal calls + calls += 1 + return str(calls) + + with gr.Blocks() as demo: + out = gr.Textbox(label="Call counter") + demo.load(increment, inputs=None, outputs=out).then( + increment, inputs=None, outputs=out + ) + + with gr.Blocks() as demo2: + demo.render() + + assert demo2.config["dependencies"][0]["trigger"] == "load" + assert demo2.config["dependencies"][0]["trigger_after"] is None + assert demo2.config["dependencies"][1]["trigger"] == "then" + assert demo2.config["dependencies"][1]["trigger_after"] == 0 + class TestEventErrors: def test_event_defined_invalid_scope(self):