
Commit

feat: ui build in one single http request
nicoloboschi committed Aug 1, 2024
1 parent 4fb96d6 commit 75328f1
Showing 9 changed files with 574 additions and 31 deletions.
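In effect, the commit replaces the old one-request-per-vertex build flow with a single POST to /build/{flow_id}/flow that streams build events back as NDJSON. A minimal client sketch, assuming the router is mounted under /api/v1, bearer-token auth, and the third-party httpx library (none of which this diff defines):

import json

import httpx

def consume_build_stream(base_url: str, flow_id: str, token: str) -> None:
    # Hypothetical client; the endpoint prefix and auth scheme are assumptions.
    url = f"{base_url}/api/v1/build/{flow_id}/flow"
    headers = {"Authorization": f"Bearer {token}"}
    with httpx.stream("POST", url, headers=headers, json={}, timeout=None) as response:
        for line in response.iter_lines():
            if not line.strip():
                continue  # events are separated by blank lines ("\n\n")
            event = json.loads(line)
            if event["event"] == "vertices_sorted":
                print("build order:", event["data"]["ids"])
            elif event["event"] == "end_vertex":
                print("built vertex:", event["data"]["build_data"]["id"])
            elif event["event"] == "end":
                break  # server has finished streaming the build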
7 changes: 6 additions & 1 deletion src/backend/base/langflow/api/utils.py
@@ -122,7 +122,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
        return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"


-async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
+async def build_graph_from_db_no_cache(flow_id: str, session: Session):
    """Build and cache the graph."""
    flow: Optional[Flow] = session.get(Flow, flow_id)
    if not flow or not flow.data:
@@ -139,6 +139,11 @@ async def build_graph_from_db(flow_id: str, session: Session, chat_service: "Cha
    graph.set_run_id(run_id)
    graph.set_run_name()
    await graph.initialize_run()

    return graph


async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
    graph = await build_graph_from_db_no_cache(flow_id, session)
    await chat_service.set_cache(flow_id, graph)
    return graph
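The split separates graph construction from cache population: the new streaming endpoint builds without writing to the flow cache, while existing callers keep the caching behaviour. A hypothetical caller-side sketch using the names from this diff (flow_id_str, session, and chat_service come from the calling endpoint):

async def load_graph(flow_id_str, session, chat_service=None):
    # Hypothetical helper contrasting the two entry points above.
    if chat_service is None:
        # streaming-build path: construct the graph only, no cache write
        return await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
    # legacy path: construct the graph, then populate the flow cache
    return await build_graph_from_db(flow_id_str, session, chat_service)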

301 changes: 298 additions & 3 deletions src/backend/base/langflow/api/v1/chat.py
@@ -1,27 +1,34 @@
import asyncio
import json
import time
import traceback
import typing
import uuid
from asyncio import QueueEmpty

Check failure (GitHub Actions / Ruff Style Check (3.12)): Ruff (F401) at src/backend/base/langflow/api/v1/chat.py:7:21: `asyncio.QueueEmpty` imported but unused
from typing import TYPE_CHECKING, Annotated, Optional

-from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
+from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request

Check failure (GitHub Actions / Ruff Style Check (3.12)): Ruff (F401) at src/backend/base/langflow/api/v1/chat.py:10:79: `fastapi.Request` imported but unused
from fastapi.responses import StreamingResponse
from loguru import logger
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
from starlette.types import Receive

from langflow.api.utils import (
    build_and_cache_graph_from_data,
    build_graph_from_db,
    format_elapsed_time,
    format_exception_message,
    get_top_level_vertices,
-    parse_exception,
+    parse_exception, build_graph_from_db_no_cache,
)
from langflow.api.v1.schemas import (
    FlowDataRequest,
    InputValueRequest,
    ResultDataResponse,
    StreamData,
    VertexBuildResponse,
-    VerticesOrderResponse,
+    VerticesOrderResponse
)
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.graph.base import Graph
@@ -139,6 +146,294 @@ async def retrieve_vertices_order(
        logger.exception(exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

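# New in this commit: build an entire flow over one HTTP request. The endpoint
# resolves the build order, builds vertices concurrently, and streams NDJSON
# events ("vertices_sorted", one "end_vertex" per vertex, then "end").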
@router.post("/build/{flow_id}/flow")
async def build_flow(
background_tasks: BackgroundTasks,
flow_id: uuid.UUID,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
data: Annotated[Optional[FlowDataRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
session=Depends(get_session),

):

    async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
        start_time = time.perf_counter()
        components_count = None
        try:
            flow_id_str = str(flow_id)
            if not data:
                graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
            else:
                graph = Graph.from_payload(data.model_dump(), flow_id_str)
            graph.validate_stream()
            if stop_component_id or start_component_id:
                try:
                    first_layer = graph.sort_vertices(stop_component_id, start_component_id)
                except Exception as exc:
                    logger.error(exc)
                    first_layer = graph.sort_vertices()
            else:
                first_layer = graph.sort_vertices()

            for vertex_id in first_layer:
                graph.run_manager.add_to_vertices_being_run(vertex_id)

            # Now vertices is a list of lists
            # We need to get the id of each vertex
            # and return the same structure but only with the ids
            components_count = len(graph.vertices)
            vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
            background_tasks.add_task(
                telemetry_service.log_package_playground,
                PlaygroundPayload(
                    playgroundSeconds=int(time.perf_counter() - start_time),
                    playgroundComponentCount=components_count,
                    playgroundSuccess=True,
                ),
            )
            return first_layer, vertices_to_run, graph
        except Exception as exc:
            background_tasks.add_task(
                telemetry_service.log_package_playground,
                PlaygroundPayload(
                    playgroundSeconds=int(time.perf_counter() - start_time),
                    playgroundComponentCount=components_count,
                    playgroundSuccess=False,
                    playgroundErrorMessage=str(exc),
                ),
            )
            if "stream or streaming set to True" in str(exc):
                raise HTTPException(status_code=400, detail=str(exc))
            logger.error(f"Error checking build status: {exc}")
            logger.exception(exc)
            raise HTTPException(status_code=500, detail=str(exc)) from exc


    async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
        flow_id_str = str(flow_id)

        next_runnable_vertices = []
        top_level_vertices = []
        start_time = time.perf_counter()
        error_message = None
        try:
            vertex = graph.get_vertex(vertex_id)
            try:
                lock = chat_service._async_cache_locks[flow_id_str]
                (
                    result_dict,
                    params,
                    valid,
                    artifacts,
                    vertex,
                ) = await graph.build_vertex(
                    chat_service=None,
                    vertex_id=vertex_id,
                    user_id=current_user.id,
                    inputs_dict=inputs.model_dump() if inputs else {},
                    files=files,
                )
                next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
                top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

                result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
            except Exception as exc:
                if isinstance(exc, ComponentBuildException):
                    params = exc.message
                    tb = exc.formatted_traceback
                else:
                    tb = traceback.format_exc()
                    logger.exception(f"Error building Component: {exc}")
                    params = format_exception_message(exc)
                message = {"errorMessage": params, "stackTrace": tb}
                valid = False
                error_message = params
                output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
                outputs = {output_label: OutputValue(message=message, type="error")}
                result_data_response = ResultDataResponse(results={}, outputs=outputs)
                artifacts = {}
                background_tasks.add_task(graph.end_all_traces, error=exc)

            result_data_response.message = artifacts

            # Log the vertex build
            if not vertex.will_stream:
                background_tasks.add_task(
                    log_vertex_build,
                    flow_id=flow_id_str,
                    vertex_id=vertex_id.split("-")[0],
                    valid=valid,
                    params=params,
                    data=result_data_response,
                    artifacts=artifacts,
                )

            timedelta = time.perf_counter() - start_time
            duration = format_elapsed_time(timedelta)
            result_data_response.duration = duration
            result_data_response.timedelta = timedelta
            vertex.add_build_time(timedelta)
            inactivated_vertices = list(graph.inactivated_vertices)
            graph.reset_inactivated_vertices()
            graph.reset_activated_vertices()
            # graph.stop_vertex tells us if the user asked
            # to stop the build of the graph at a certain vertex
            # if it is in next_vertices_ids, we need to remove other
            # vertices from next_vertices_ids
            if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
                next_runnable_vertices = [graph.stop_vertex]

            if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
                background_tasks.add_task(graph.end_all_traces)

            build_response = VertexBuildResponse(
                inactivated_vertices=list(set(inactivated_vertices)),
                next_vertices_ids=list(set(next_runnable_vertices)),
                top_level_vertices=list(set(top_level_vertices)),
                valid=valid,
                params=params,
                id=vertex.id,
                data=result_data_response,
            )
            background_tasks.add_task(
                telemetry_service.log_package_component,
                ComponentPayload(
                    componentName=vertex_id.split("-")[0],
                    componentSeconds=int(time.perf_counter() - start_time),
                    componentSuccess=valid,
                    componentErrorMessage=error_message,
                ),
            )
            return build_response
        except Exception as exc:
            background_tasks.add_task(
                telemetry_service.log_package_component,
                ComponentPayload(
                    componentName=vertex_id.split("-")[0],
                    componentSeconds=int(time.perf_counter() - start_time),
                    componentSuccess=False,
                    componentErrorMessage=str(exc),
                ),
            )
            logger.error(f"Error building Component: \n\n{exc}")
            logger.exception(exc)
            message = parse_exception(exc)
            raise HTTPException(status_code=500, detail=message) from exc


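    # Frame one event as a JSON line; the queue carries (event id, encoded
    # payload, enqueue time) so the consumer can log how long each event
    # waited before the client read it.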
    def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
        json_data = {
            "event": event_type,
            "data": value
        }
        event_id = uuid.uuid4()
        logger.debug(f"sending event {event_id}: {event_type}")
        str_data = json.dumps(json_data) + "\n\n"
        queue.put_nowait((event_id, str_data.encode('utf-8'), time.time()))

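    # Recursive fan-out: build one vertex, stream its result, wait for the
    # client to consume it, then spawn one task per newly runnable vertex.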
    async def build_vertices(vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
        build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
        try:
            await build_task
        except asyncio.CancelledError:
            build_task.cancel()
            return

        vertex_build_response: VertexBuildResponse = build_task.result()
        # send built event or error event
        send_event("end_vertex", {
            "build_data": json.loads(vertex_build_response.model_dump_json())},
            queue)
        await client_consumed_queue.get()
        if vertex_build_response.valid:
            if vertex_build_response.next_vertices_ids:
                tasks = []
                for next_vertex_id in vertex_build_response.next_vertices_ids:
                    task = asyncio.create_task(build_vertices(next_vertex_id, graph, queue, client_consumed_queue))
                    tasks.append(task)
                try:
                    await asyncio.gather(*tasks)
                except asyncio.CancelledError:
                    for task in tasks:
                        task.cancel()
                    return


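    # Producer: resolves the build order, announces it to the client, then
    # drives the first layer of builds; deeper layers are spawned recursively
    # by build_vertices above.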
    async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
        if not data:
            # using another thread since the DB query is I/O bound
            vertices_task = asyncio.create_task(await asyncio.to_thread(build_graph_and_get_order))
            try:
                await vertices_task
            except asyncio.CancelledError:
                vertices_task.cancel()
                return

            ids, vertices_to_run, graph = vertices_task.result()
        else:
            ids, vertices_to_run, graph = await build_graph_and_get_order()
        send_event("vertices_sorted",
                   {
                       "ids": ids,
                       "to_run": vertices_to_run
                   }, queue)
        await client_consumed_queue.get()

        tasks = []
        for vertex_id in ids:
            task = asyncio.create_task(build_vertices(vertex_id, graph, queue, client_consumed_queue))
            tasks.append(task)
        try:
            await asyncio.gather(*tasks)
        except asyncio.CancelledError:
            for task in tasks:
                task.cancel()
            return
        send_event("end", {}, queue)
        await queue.put((None, None, time.time()))

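    # Consumer: drains the event queue into the HTTP response and acks each
    # event on client_consumed_queue, which paces the producer coroutines.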
    async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncIterator[bytes]:
        while True:
            event_id, value, put_time = await queue.get()
            if value is None:
                break
            get_time = time.time()
            yield value
            get_time_yield = time.time()
            client_consumed_queue.put_nowait(event_id)
            logger.debug(
                f"consumed event {str(event_id)} "
                f"(time in queue, {get_time - put_time:.4f}, client {get_time_yield - get_time:.4f})"
            )

    asyncio_queue = asyncio.Queue()
    asyncio_queue_client_consumed = asyncio.Queue()
    main_task = asyncio.create_task(event_generator(asyncio_queue, asyncio_queue_client_consumed))

    def on_disconnect():
        logger.debug("Client disconnected, closing tasks")
        main_task.cancel()

    return DisconnectHandlerStreamingResponse(
        consume_and_yield(asyncio_queue, asyncio_queue_client_consumed),
        media_type="application/x-ndjson",
        on_disconnect=on_disconnect,
    )

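# StreamingResponse that watches the ASGI receive channel for
# "http.disconnect" and invokes a callback, so in-flight build tasks are
# cancelled when the client goes away.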
class DisconnectHandlerStreamingResponse(StreamingResponse):

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,
        on_disconnect: Optional[typing.Callable] = None,
    ):
        super().__init__(content, status_code, headers, media_type, background)
        self.on_disconnect = on_disconnect

    async def listen_for_disconnect(self, receive: Receive) -> None:
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                if self.on_disconnect:
                    # on_disconnect is a plain function above, so it must not be awaited
                    self.on_disconnect()
                break



@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
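The machinery above reduces to a producer/consumer pair joined by two queues: events flow one way and consumption acks flow back, so builds never race ahead of what the client has actually read. A stripped-down, self-contained sketch of that handshake (illustrative names only, not Langflow code):

import asyncio

async def producer(events: asyncio.Queue, acks: asyncio.Queue) -> None:
    for i in range(3):
        await events.put(f"event {i}")
        await acks.get()  # block until the consumer has yielded the event
    await events.put(None)  # sentinel: no more events

async def consumer(events: asyncio.Queue, acks: asyncio.Queue) -> None:
    while (item := await events.get()) is not None:
        print("yielding", item)  # stands in for yielding bytes to the response
        acks.put_nowait(item)  # ack so the producer may continue

async def main() -> None:
    events, acks = asyncio.Queue(), asyncio.Queue()
    await asyncio.gather(producer(events, acks), consumer(events, acks))

asyncio.run(main())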
1 change: 1 addition & 0 deletions src/backend/base/langflow/api/v1/schemas.py
@@ -304,6 +304,7 @@ class InputValueRequest(BaseModel):
    )



class SimplifiedAPIRequest(BaseModel):
    input_value: Optional[str] = Field(default=None, description="The input value")
    input_type: Optional[InputType] = Field(default="chat", description="The input type")
10 changes: 6 additions & 4 deletions src/backend/base/langflow/graph/graph/base.py
@@ -851,7 +851,7 @@ def get_root_of_group_node(self, vertex_id: str) -> Vertex:

    async def build_vertex(
        self,
-        chat_service: ChatService,
+        chat_service: Optional[ChatService],
        vertex_id: str,
        inputs_dict: Optional[Dict[str, str]] = None,
        files: Optional[list[str]] = None,
@@ -881,12 +881,13 @@ async def build_vertex(
params = ""
if vertex.frozen:
# Check the cache for the vertex
cached_result = await chat_service.get_cache(key=vertex.id)
cached_result = await chat_service.get_cache(key=vertex.id) if chat_service else CacheMiss
if isinstance(cached_result, CacheMiss):
await vertex.build(
user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files
)
await chat_service.set_cache(key=vertex.id, data=vertex)
if chat_service:
await chat_service.set_cache(key=vertex.id, data=vertex)
else:
cached_vertex = cached_result["result"]
# Now set update the vertex with the cached vertex
@@ -903,7 +904,8 @@ async def build_vertex(
                await vertex.build(
                    user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files
                )
-                await chat_service.set_cache(key=vertex.id, data=vertex)
+                if chat_service:
+                    await chat_service.set_cache(key=vertex.id, data=vertex)

            if vertex.result is not None:
                params = f"{vertex._built_object_repr()}{params}"
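The base.py change makes the cache layer optional: the streaming endpoint passes chat_service=None because it sends each result straight to the client, and CacheMiss serves as the "nothing cached" sentinel. A self-contained sketch of that guard pattern (CacheMiss and DictCache here are stand-ins, not Langflow's classes):

import asyncio

class CacheMiss:
    """Sentinel meaning 'no cached value'; keeps falsy results distinguishable."""

class DictCache:
    def __init__(self) -> None:
        self._store: dict = {}

    async def get(self, key):
        return self._store.get(key, CacheMiss())

    async def set(self, key, value) -> None:
        self._store[key] = value

async def build_with_optional_cache(key, build, cache=None):
    cached = await cache.get(key) if cache else CacheMiss()
    if isinstance(cached, CacheMiss):
        result = await build()
        if cache:  # only write back when a cache layer is attached
            await cache.set(key, result)
        return result
    return cached

async def main() -> None:
    async def build():
        return "built!"

    print(await build_with_optional_cache("v1", build))  # no cache at all
    cache = DictCache()
    print(await build_with_optional_cache("v1", build, cache))  # builds and caches
    print(await build_with_optional_cache("v1", build, cache))  # served from cache

asyncio.run(main())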