Commit b78719f: improve usage and remove cache

nicoloboschi committed Aug 1, 2024
1 parent bc79582
Showing 6 changed files with 74 additions and 64 deletions.
7 changes: 6 additions & 1 deletion src/backend/base/langflow/api/utils.py
@@ -122,7 +122,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"


async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
async def build_graph_from_db_no_cache(flow_id: str, session: Session):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
@@ -139,6 +139,11 @@ async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
graph.set_run_id(run_id)
graph.set_run_name()
await graph.initialize_run()

return graph

async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
graph = await build_graph_from_db_no_cache(flow_id, session)
await chat_service.set_cache(flow_id, graph)
return graph

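The refactor splits graph construction from cache population: build_graph_from_db_no_cache does the pure build, while build_graph_from_db survives as a thin wrapper that builds and then writes to the cache. A minimal sketch of the pattern, with simplified stand-in types rather than the real langflow signatures (which take a SQLModel Session and a ChatService):

# Sketch of the split introduced above: a pure, cache-free builder plus a
# thin caching wrapper. Types are simplified stand-ins, not langflow's.
import asyncio


async def build_no_cache(flow_id: str, db: dict) -> dict:
    # Pure construction: no writes to any shared cache.
    flow = db[flow_id]
    return {"id": flow_id, "vertices": list(flow["nodes"])}


async def build_cached(flow_id: str, db: dict, cache: dict) -> dict:
    # Wrapper: reuse the pure builder, then populate the cache.
    graph = await build_no_cache(flow_id, db)
    cache[flow_id] = graph
    return graph


async def main() -> None:
    db = {"f1": {"nodes": ["a", "b"]}}
    cache: dict = {}
    g1 = await build_no_cache("f1", db)       # new endpoint path: no cache I/O
    g2 = await build_cached("f1", db, cache)  # legacy path keeps caching
    assert g1 == g2 and "f1" in cache


asyncio.run(main())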
90 changes: 43 additions & 47 deletions src/backend/base/langflow/api/v1/chat.py
@@ -4,6 +4,7 @@
import traceback
import typing
import uuid
from asyncio import QueueEmpty
from typing import TYPE_CHECKING, Annotated, Optional

from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request
@@ -19,15 +20,15 @@
format_elapsed_time,
format_exception_message,
get_top_level_vertices,
parse_exception,
parse_exception, build_graph_from_db_no_cache,
)
from langflow.api.v1.schemas import (
FlowDataRequest,
InputValueRequest,
ResultDataResponse,
StreamData,
VertexBuildResponse,
VerticesOrderResponse,
VerticesOrderResponse
)
from langflow.exceptions.component import ComponentBuildException
from langflow.graph.graph.base import Graph
@@ -148,9 +149,9 @@ async def retrieve_vertices_order(
@router.post("/build/{flow_id}/flow")
async def build_flow(
background_tasks: BackgroundTasks,
request: Request,
flow_id: uuid.UUID,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
data: Annotated[Optional[FlowDataRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
@@ -161,12 +162,15 @@ async def build_flow(

):

async def get_vertices_order() -> VerticesOrderResponse:
async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
start_time = time.perf_counter()
components_count = None
try:
flow_id_str = str(flow_id)
graph = await build_graph_from_db(flow_id=flow_id_str, session=session, chat_service=chat_service)
if not data:
graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
else:
graph = Graph.from_payload(data.model_dump(), flow_id_str)
graph.validate_stream()
if stop_component_id or start_component_id:
try:
@@ -185,7 +189,6 @@ async def get_vertices_order() -> VerticesOrderResponse:
# and return the same structure but only with the ids
components_count = len(graph.vertices)
vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
await chat_service.set_cache(str(flow_id), graph)
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
@@ -194,7 +197,7 @@
playgroundSuccess=True,
),
)
return VerticesOrderResponse(ids=first_layer, run_id=graph._run_id, vertices_to_run=vertices_to_run)
return first_layer, vertices_to_run, graph
except Exception as exc:
background_tasks.add_task(
telemetry_service.log_package_playground,
@@ -212,26 +215,15 @@ async def get_vertices_order() -> VerticesOrderResponse:
raise HTTPException(status_code=500, detail=str(exc)) from exc


async def _build_vertex(vertex_id: str) -> VertexBuildResponse:
async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
flow_id_str = str(flow_id)

next_runnable_vertices = []
top_level_vertices = []
start_time = time.perf_counter()
error_message = None
try:
cache = await chat_service.get_cache(flow_id_str)
if not cache:
# If there's no cache
logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
graph: "Graph" = await build_graph_from_db(
flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service
)
else:
graph = cache.get("result")
await graph.initialize_run()
vertex = graph.get_vertex(vertex_id)

try:
lock = chat_service._async_cache_locks[flow_id_str]
(
@@ -241,7 +233,7 @@ async def _build_vertex(vertex_id: str) -> VertexBuildResponse:
artifacts,
vertex,
) = await graph.build_vertex(
chat_service=chat_service,
chat_service=None,
vertex_id=vertex_id,
user_id=current_user.id,
inputs_dict=inputs.model_dump() if inputs else {},
@@ -267,9 +259,6 @@ async def _build_vertex(vertex_id: str) -> VertexBuildResponse:
result_data_response = ResultDataResponse(results={}, outputs=outputs)
artifacts = {}
background_tasks.add_task(graph.end_all_traces, error=exc)
# If there's an error building the vertex
# we need to clear the cache
await chat_service.clear_cache(flow_id_str)

result_data_response.message = artifacts

@@ -293,9 +282,6 @@ async def _build_vertex(vertex_id: str) -> VertexBuildResponse:
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()

await chat_service.set_cache(flow_id_str, graph)

# graph.stop_vertex tells us if the user asked
# to stop the build of the graph at a certain vertex
# if it is in next_vertices_ids, we need to remove other
@@ -341,16 +327,17 @@ async def _build_vertex(vertex_id: str) -> VertexBuildResponse:
raise HTTPException(status_code=500, detail=message) from exc


async def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
json_data = {
"event": event_type,
"data": value
}
logger.debug(f"sending event {event_type}")
await queue.put(json.dumps(json_data))
str_data = json.dumps(json_data)
queue.put_nowait((str_data, time.time()))

async def build_vertices(vertex_id: str, queue: asyncio.Queue) -> None:
build_task = asyncio.create_task(_build_vertex(vertex_id))
async def build_vertices(vertex_id: str, graph: "Graph", queue: asyncio.Queue) -> None:
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
try:
await build_task
except asyncio.CancelledError:
@@ -359,14 +346,14 @@ async def build_vertices(vertex_id: str, queue: asyncio.Queue) -> None:

vertex_build_response: VertexBuildResponse = build_task.result()
# send built event or error event
await send_event("end_vertex", {
send_event("end_vertex", {
"build_data": json.loads(vertex_build_response.model_dump_json())},
queue)
if vertex_build_response.valid:
if vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(build_vertices(next_vertex_id, queue))
task = asyncio.create_task(build_vertices(next_vertex_id, graph, queue))
tasks.append(task)
try:
await asyncio.gather(*tasks)
@@ -377,43 +364,52 @@ async def build_vertices(vertex_id: str, queue: asyncio.Queue) -> None:


async def event_generator(queue: asyncio.Queue) -> None:
logger.debug("Starting event generator")
order_response = await get_vertices_order()
await send_event("vertices_sorted",
if not data:
# using another thread since the DB query is I/O bound
vertices_task = asyncio.create_task(await asyncio.to_thread(build_graph_and_get_order))
try:
await vertices_task
except asyncio.CancelledError:
vertices_task.cancel()
return

ids, vertices_to_run, graph = vertices_task.result()
else:
ids, vertices_to_run, graph = await build_graph_and_get_order()
send_event("vertices_sorted",
{
"ids": order_response.ids,
"to_run": order_response.vertices_to_run
"ids": ids,
"to_run": vertices_to_run
}, queue)
to_build_ids = order_response.ids
tasks = []
for vertex_id in to_build_ids:
task = asyncio.create_task(build_vertices(vertex_id, queue))
for vertex_id in ids:
task = asyncio.create_task(build_vertices(vertex_id, graph, queue))
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return
await send_event("end", {}, queue)
await queue.put(None)
send_event("end", {}, queue)
await queue.put((None, time.time()))

async def consume_and_yield(queue: asyncio.Queue) -> None:
while True:
value = await queue.get()
value, put_time = await queue.get()
if value is None:
break
get_time = time.time()
delay = get_time - put_time
logger.debug(f"consumed event with delay: {delay:.4f}")
yield value + "\n\n"

asyncio_queue = asyncio.Queue()
main_task = asyncio.create_task(event_generator(asyncio_queue))
def on_disconnect():
logger.debug("Client disconnected, closing tasks")
main_task.cancel()

return DisconnectHandlerStreamingResponse(consume_and_yield(asyncio_queue),
media_type="application/x-ndjson",
on_disconnect=on_disconnect)
return DisconnectHandlerStreamingResponse(consume_and_yield(asyncio_queue), media_type="application/x-ndjson", on_disconnect=on_disconnect)

class DisconnectHandlerStreamingResponse(StreamingResponse):

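Taken together, the chat.py changes replace per-request cache round-trips with a graph object threaded through the build tasks, plus a producer/consumer queue: send_event enqueues (payload, timestamp) tuples with put_nowait, and consume_and_yield streams them out as NDJSON while logging delivery delay. A self-contained sketch of that queue pattern, with illustrative names and the FastAPI wiring omitted:

# Sketch of the timestamped producer/consumer streaming pattern used by
# build_flow; illustrative only, not the langflow implementation.
import asyncio
import json
import time


def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
    # Synchronous enqueue: put_nowait never blocks on an unbounded queue.
    payload = json.dumps({"event": event_type, "data": value})
    queue.put_nowait((payload, time.time()))


async def producer(queue: asyncio.Queue) -> None:
    for i in range(3):
        send_event("end_vertex", {"build_data": {"id": f"v{i}"}}, queue)
        await asyncio.sleep(0.01)  # stand-in for per-vertex build work
    send_event("end", {}, queue)
    await queue.put((None, time.time()))  # sentinel stops the consumer


async def consume_and_yield(queue: asyncio.Queue):
    while True:
        value, put_time = await queue.get()
        if value is None:
            break
        # Delay between enqueue and dequeue approximates consumer lag.
        print(f"consumed event with delay: {time.time() - put_time:.4f}s")
        yield value + "\n\n"  # NDJSON framing: one JSON document per record


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    producer_task = asyncio.create_task(producer(queue))
    async for chunk in consume_and_yield(queue):  # StreamingResponse stand-in
        print(chunk, end="")
    await producer_task


asyncio.run(main())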
1 change: 1 addition & 0 deletions src/backend/base/langflow/api/v1/schemas.py
@@ -304,6 +304,7 @@ class InputValueRequest(BaseModel):
)



class SimplifiedAPIRequest(BaseModel):
input_value: Optional[str] = Field(default=None, description="The input value")
input_type: Optional[InputType] = Field(default="chat", description="The input type")
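chat.py above now accepts an optional FlowDataRequest body. The model's definition falls outside the hunks shown in this diff; based on data.model_dump() feeding Graph.from_payload and the frontend posting { nodes, edges } (see buildUtils.ts below), a plausible shape, offered only as an assumption:

# Assumed shape of the FlowDataRequest model imported by chat.py; its
# definition is not shown in this diff.
from typing import Any

from pydantic import BaseModel


class FlowDataRequest(BaseModel):
    # buildUtils.ts posts the editor state as plain node/edge dicts.
    nodes: list[dict[str, Any]]
    edges: list[dict[str, Any]]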
10 changes: 6 additions & 4 deletions src/backend/base/langflow/graph/graph/base.py
@@ -851,7 +851,7 @@ def get_root_of_group_node(self, vertex_id: str) -> Vertex:

async def build_vertex(
self,
chat_service: ChatService,
chat_service: Optional[ChatService],
vertex_id: str,
inputs_dict: Optional[Dict[str, str]] = None,
files: Optional[list[str]] = None,
@@ -881,12 +881,13 @@ async def build_vertex(
params = ""
if vertex.frozen:
# Check the cache for the vertex
cached_result = await chat_service.get_cache(key=vertex.id)
cached_result = await chat_service.get_cache(key=vertex.id) if chat_service else CacheMiss
if isinstance(cached_result, CacheMiss):
await vertex.build(
user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files
)
await chat_service.set_cache(key=vertex.id, data=vertex)
if chat_service:
await chat_service.set_cache(key=vertex.id, data=vertex)
else:
cached_vertex = cached_result["result"]
# Now update the vertex with the cached vertex
@@ -903,7 +904,8 @@ async def build_vertex(
await vertex.build(
user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files
)
await chat_service.set_cache(key=vertex.id, data=vertex)
if chat_service:
await chat_service.set_cache(key=vertex.id, data=vertex)

if vertex.result is not None:
params = f"{vertex._built_object_repr()}{params}"
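With chat_service now Optional, frozen-vertex caching degrades gracefully when no cache backend is attached. A reduced sketch of the sentinel-guarded lookup, using stand-ins for langflow's CacheMiss and ChatService cache API:

# Sketch of the optional-cache guard in Graph.build_vertex; CacheMiss and
# Cache are stand-ins, not langflow's real classes.
from typing import Any, Optional


class CacheMiss:
    """Sentinel distinguishing 'not cached' from a cached falsy value."""


class Cache:
    def __init__(self) -> None:
        self._store: dict[str, Any] = {}

    def get(self, key: str) -> Any:
        return self._store.get(key, CacheMiss())

    def set(self, key: str, data: Any) -> None:
        self._store[key] = data


def build_frozen_vertex(vertex_id: str, cache: Optional[Cache]) -> str:
    # With no cache attached, fall straight through to a fresh build.
    cached = cache.get(vertex_id) if cache else CacheMiss()
    if isinstance(cached, CacheMiss):
        result = f"built:{vertex_id}"
        if cache:  # write back only when a cache backend exists
            cache.set(vertex_id, result)
        return result
    return cached  # reuse the frozen vertex's cached result


print(build_frozen_vertex("v1", None))    # no cache round-trips at all
shared = Cache()
print(build_frozen_vertex("v1", shared))  # builds, then caches
print(build_frozen_vertex("v1", shared))  # served from the cache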
6 changes: 2 additions & 4 deletions src/frontend/src/stores/flowStore.ts
@@ -30,9 +30,7 @@ import {
} from "../types/flow";
import { FlowStoreType, VertexLayerElementType } from "../types/zustand/flow";
import {
buildFlowVertices,
buildFlowVerticesWithFallback,
buildVertices,
} from "../utils/buildUtils";
import {
checkChatInput,
@@ -665,8 +663,8 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
useFlowStore.getState().updateBuildStatus(idList, BuildStatus.BUILDING);
},
onValidateNodes: validateSubgraph,
nodes: !get().onFlowPage ? get().nodes : undefined,
edges: !get().onFlowPage ? get().edges : undefined,
nodes: get().onFlowPage ? get().nodes : undefined,
edges: get().onFlowPage ? get().edges : undefined,
});
get().setIsBuilding(false);
get().setLockChat(false);
24 changes: 16 additions & 8 deletions src/frontend/src/utils/buildUtils.ts
@@ -154,31 +154,37 @@ export async function buildFlowVertices({
if (stopNodeId) {
url = `${url}&stop_component_id=${stopNodeId}`;
}
const postData = {};
const postData = {}
if (typeof input_value !== "undefined") {
postData["inputs"] = { input_value: input_value };
}
if (files) {
postData["files"] = files;
}
if (nodes) {
postData["data"] = {
nodes,
edges,
};
}

const buildResults: Array<boolean> = [];

const verticesStartTimeMs: Map<string, number> = new Map();

const onEvent = async (type, data): Promise<boolean> => {
const onStartVertex = (id: string) => {
verticesStartTimeMs.set(id, Date.now());
useFlowStore.getState().updateBuildStatus([id], BuildStatus.TO_BUILD);

if (onBuildStart) onBuildStart([{ id: id, reference: id }]);
const onStartVertices = (ids: Array<string>) => {
useFlowStore.getState().updateBuildStatus(ids, BuildStatus.TO_BUILD);
if (onBuildStart) onBuildStart(ids.map(id => ({ id: id, reference: id })));
ids.forEach((id) => verticesStartTimeMs.set(id, Date.now()));
};
switch (type) {
case "vertices_sorted": {
console.log("got vertices_sorted event");
const verticesToRun = data.to_run;
const verticesIds = data.ids;

verticesIds.forEach(onStartVertex);
onStartVertices(verticesIds)

let verticesLayers: Array<Array<VertexLayerElementType>> =
verticesIds.map((id: string) => {
@@ -246,7 +252,9 @@ export async function buildFlowVertices({
buildResults.push(true);
}
}
buildData.next_vertices_ids?.forEach(onStartVertex);
if (buildData.next_vertices_ids) {
onStartVertices(buildData.next_vertices_ids)
}
return true;
}
case "end": {
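On the wire, the endpoint now emits newline-delimited JSON events (vertices_sorted, end_vertex, end) that the frontend handlers above dispatch on. A hypothetical Python client for exercising the stream; the URL, port, flow id, and any payload fields beyond those visible in this diff are assumptions:

# Hypothetical client for the /build/{flow_id}/flow NDJSON stream; the URL,
# port, flow id, and request payload are assumptions, not documented API.
import json

import httpx

FLOW_ID = "00000000-0000-0000-0000-000000000000"  # placeholder flow id
URL = f"http://localhost:7860/api/v1/build/{FLOW_ID}/flow"

with httpx.Client(timeout=None) as client:
    with client.stream("POST", URL, json={"inputs": {"input_value": "hi"}}) as resp:
        for line in resp.iter_lines():
            if not line.strip():
                continue  # events are separated by blank lines
            event = json.loads(line)
            if event["event"] == "vertices_sorted":
                print("will build:", event["data"]["ids"])
            elif event["event"] == "end_vertex":
                print("vertex built:", event["data"]["build_data"].get("id"))
            elif event["event"] == "end":
                break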
