Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ui build in one single http request #3020

Merged
merged 9 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
import warnings
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Dict

from fastapi import HTTPException
from sqlmodel import Session
Expand Down Expand Up @@ -122,12 +122,9 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"


async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
async def build_graph_from_data(flow_id: str, payload: Dict, **kwargs):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
graph = Graph.from_payload(flow.data, flow_id, flow_name=flow.name, user_id=str(flow.user_id))
graph = Graph.from_payload(payload, flow_id, **kwargs)
for vertex_id in graph._has_session_id_vertices:
vertex = graph.get_vertex(vertex_id)
if vertex is None:
Expand All @@ -139,6 +136,19 @@ async def build_graph_from_db(flow_id: str, session: Session, chat_service: "Cha
graph.set_run_id(run_id)
graph.set_run_name()
await graph.initialize_run()
return graph


async def build_graph_from_db_no_cache(flow_id: str, session: Session):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))


async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
graph = await build_graph_from_db_no_cache(flow_id, session)
await chat_service.set_cache(flow_id, graph)
return graph

Expand Down
298 changes: 298 additions & 0 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import asyncio
import json
import time
import traceback
import typing
import uuid
from typing import TYPE_CHECKING, Annotated, Optional

from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
from starlette.types import Receive

from langflow.api.utils import (
build_and_cache_graph_from_data,
Expand All @@ -14,6 +20,8 @@
format_exception_message,
get_top_level_vertices,
parse_exception,
build_graph_from_db_no_cache,
build_graph_from_data,
)
from langflow.api.v1.schemas import (
FlowDataRequest,
Expand Down Expand Up @@ -140,6 +148,296 @@ async def retrieve_vertices_order(
raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.post("/build/{flow_id}/flow")
async def build_flow(
background_tasks: BackgroundTasks,
flow_id: uuid.UUID,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
data: Annotated[Optional[FlowDataRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
session=Depends(get_session),
):
async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
start_time = time.perf_counter()
components_count = None
try:
flow_id_str = str(flow_id)
if not data:
graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
else:
graph = await build_graph_from_data(flow_id_str, data.model_dump())
graph.validate_stream()
if stop_component_id or start_component_id:
try:
first_layer = graph.sort_vertices(stop_component_id, start_component_id)
except Exception as exc:
logger.error(exc)
first_layer = graph.sort_vertices()
else:
first_layer = graph.sort_vertices()

for vertex_id in first_layer:
graph.run_manager.add_to_vertices_being_run(vertex_id)

# Now vertices is a list of lists
# We need to get the id of each vertex
# and return the same structure but only with the ids
components_count = len(graph.vertices)
vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
playgroundSeconds=int(time.perf_counter() - start_time),
playgroundComponentCount=components_count,
playgroundSuccess=True,
),
)
return first_layer, vertices_to_run, graph
except Exception as exc:
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
playgroundSeconds=int(time.perf_counter() - start_time),
playgroundComponentCount=components_count,
playgroundSuccess=False,
playgroundErrorMessage=str(exc),
),
)
if "stream or streaming set to True" in str(exc):
raise HTTPException(status_code=400, detail=str(exc))
logger.error(f"Error checking build status: {exc}")
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc

async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
flow_id_str = str(flow_id)

next_runnable_vertices = []
top_level_vertices = []
start_time = time.perf_counter()
error_message = None
try:
vertex = graph.get_vertex(vertex_id)
try:
lock = chat_service._async_cache_locks[flow_id_str]
(
result_dict,
params,
valid,
artifacts,
vertex,
) = await graph.build_vertex(
chat_service=None,
vertex_id=vertex_id,
user_id=current_user.id,
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
)
next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
except Exception as exc:
if isinstance(exc, ComponentBuildException):
params = exc.message
tb = exc.formatted_traceback
else:
tb = traceback.format_exc()
logger.exception(f"Error building Component: {exc}")
params = format_exception_message(exc)
message = {"errorMessage": params, "stackTrace": tb}
valid = False
error_message = params
output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
outputs = {output_label: OutputValue(message=message, type="error")}
result_data_response = ResultDataResponse(results={}, outputs=outputs)
artifacts = {}
background_tasks.add_task(graph.end_all_traces, error=exc)

result_data_response.message = artifacts

# Log the vertex build
if not vertex.will_stream:
background_tasks.add_task(
log_vertex_build,
flow_id=flow_id_str,
vertex_id=vertex_id.split("-")[0],
valid=valid,
params=params,
data=result_data_response,
artifacts=artifacts,
)

timedelta = time.perf_counter() - start_time
duration = format_elapsed_time(timedelta)
result_data_response.duration = duration
result_data_response.timedelta = timedelta
vertex.add_build_time(timedelta)
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
# graph.stop_vertex tells us if the user asked
# to stop the build of the graph at a certain vertex
# if it is in next_vertices_ids, we need to remove other
# vertices from next_vertices_ids
if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
next_runnable_vertices = [graph.stop_vertex]

if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
background_tasks.add_task(graph.end_all_traces)

build_response = VertexBuildResponse(
inactivated_vertices=list(set(inactivated_vertices)),
next_vertices_ids=list(set(next_runnable_vertices)),
top_level_vertices=list(set(top_level_vertices)),
valid=valid,
params=params,
id=vertex.id,
data=result_data_response,
)
background_tasks.add_task(
telemetry_service.log_package_component,
ComponentPayload(
componentName=vertex_id.split("-")[0],
componentSeconds=int(time.perf_counter() - start_time),
componentSuccess=valid,
componentErrorMessage=error_message,
),
)
return build_response
except Exception as exc:
background_tasks.add_task(
telemetry_service.log_package_component,
ComponentPayload(
componentName=vertex_id.split("-")[0],
componentSeconds=int(time.perf_counter() - start_time),
componentSuccess=False,
componentErrorMessage=str(exc),
),
)
logger.error(f"Error building Component: \n\n{exc}")
logger.exception(exc)
message = parse_exception(exc)
raise HTTPException(status_code=500, detail=message) from exc

def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
json_data = {"event": event_type, "data": value}
event_id = uuid.uuid4()
logger.debug(f"sending event {event_id}: {event_type}")
str_data = json.dumps(json_data) + "\n\n"
queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))

async def build_vertices(
vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue
) -> None:
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
try:
await build_task
except asyncio.CancelledError:
build_task.cancel()
return

vertex_build_response: VertexBuildResponse = build_task.result()
# send built event or error event
send_event("end_vertex", {"build_data": json.loads(vertex_build_response.model_dump_json())}, queue)
await client_consumed_queue.get()
if vertex_build_response.valid:
if vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(build_vertices(next_vertex_id, graph, queue, client_consumed_queue))
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return

async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
if not data:
# using another thread since the DB query is I/O bound
vertices_task = asyncio.create_task(await asyncio.to_thread(build_graph_and_get_order))
try:
await vertices_task
except asyncio.CancelledError:
vertices_task.cancel()
return

ids, vertices_to_run, graph = vertices_task.result()
else:
ids, vertices_to_run, graph = await build_graph_and_get_order()
send_event("vertices_sorted", {"ids": ids, "to_run": vertices_to_run}, queue)
await client_consumed_queue.get()

tasks = []
for vertex_id in ids:
task = asyncio.create_task(build_vertices(vertex_id, graph, queue, client_consumed_queue))
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return
send_event("end", {}, queue)
await queue.put((None, None, time.time))

async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator:
while True:
event_id, value, put_time = await queue.get()
if value is None:
break
get_time = time.time()
yield value
get_time_yield = time.time()
client_consumed_queue.put_nowait(event_id)
logger.debug(
f"consumed event {str(event_id)} (time in queue, {get_time - put_time:.4f}, client {get_time_yield - get_time:.4f})"
)

asyncio_queue: asyncio.Queue = asyncio.Queue()
asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue()
main_task = asyncio.create_task(event_generator(asyncio_queue, asyncio_queue_client_consumed))

def on_disconnect():
logger.debug("Client disconnected, closing tasks")
main_task.cancel()

return DisconnectHandlerStreamingResponse(
consume_and_yield(asyncio_queue, asyncio_queue_client_consumed),
media_type="application/x-ndjson",
on_disconnect=on_disconnect,
)


class DisconnectHandlerStreamingResponse(StreamingResponse):
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
on_disconnect: Optional[typing.Callable] = None,
):
super().__init__(content, status_code, headers, media_type, background)
self.on_disconnect = on_disconnect

async def listen_for_disconnect(self, receive: Receive) -> None:
while True:
message = await receive()
if message["type"] == "http.disconnect":
if self.on_disconnect:
await self.on_disconnect()
break


@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
flow_id: uuid.UUID,
Expand Down
Loading
Loading