
Commit

[autofix.ci] apply automated fixes (attempt 2/3)
autofix-ci[bot] authored Aug 1, 2024
1 parent fca1425 commit ec1c871
Showing 3 changed files with 35 additions and 31 deletions.
2 changes: 2 additions & 0 deletions src/backend/base/langflow/api/utils.py
@@ -138,13 +138,15 @@ async def build_graph_from_data(flow_id: str, payload: Dict, **kwargs):
     await graph.initialize_run()
     return graph
 
+
 async def build_graph_from_db_no_cache(flow_id: str, session: Session):
     """Build and cache the graph."""
     flow: Optional[Flow] = session.get(Flow, flow_id)
     if not flow or not flow.data:
         raise ValueError("Invalid flow ID")
     return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))
 
+
 async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
     graph = await build_graph_from_db_no_cache(flow_id, session)
     await chat_service.set_cache(flow_id, graph)
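The added blank lines are PEP 8 spacing fixes, but the hunk also shows the relationship between the two builders: build_graph_from_db is a thin caching wrapper over build_graph_from_db_no_cache. A minimal usage sketch under that reading (names from the diff; session and chat_service are assumed to come from the caller's existing dependencies):

# Build a fresh graph without touching the cache:
graph = await build_graph_from_db_no_cache(flow_id, session)

# Build the same graph and also store it in the chat service cache:
graph = await build_graph_from_db(flow_id, session, chat_service)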
63 changes: 33 additions & 30 deletions src/backend/base/langflow/api/v1/chat.py
@@ -4,10 +4,9 @@
 import traceback
 import typing
 import uuid
-from asyncio import QueueEmpty
 from typing import TYPE_CHECKING, Annotated, Optional
 
-from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request
+from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
 from fastapi.responses import StreamingResponse
 from loguru import logger
 from starlette.background import BackgroundTask
@@ -20,15 +19,17 @@
     format_elapsed_time,
     format_exception_message,
     get_top_level_vertices,
-    parse_exception, build_graph_from_db_no_cache, build_graph_from_data,
+    parse_exception,
+    build_graph_from_db_no_cache,
+    build_graph_from_data,
 )
 from langflow.api.v1.schemas import (
     FlowDataRequest,
     InputValueRequest,
     ResultDataResponse,
     StreamData,
     VertexBuildResponse,
-    VerticesOrderResponse
+    VerticesOrderResponse,
 )
 from langflow.exceptions.component import ComponentBuildException
 from langflow.graph.graph.base import Graph
@@ -146,6 +147,7 @@ async def retrieve_vertices_order(
         logger.exception(exc)
         raise HTTPException(status_code=500, detail=str(exc)) from exc
 
+
 @router.post("/build/{flow_id}/flow")
 async def build_flow(
     background_tasks: BackgroundTasks,
@@ -159,9 +161,7 @@ async def build_flow(
     current_user=Depends(get_current_active_user),
     telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
     session=Depends(get_session),
-
 ):
-
     async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
         start_time = time.perf_counter()
         components_count = None
@@ -214,7 +214,6 @@ async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
             logger.exception(exc)
             raise HTTPException(status_code=500, detail=str(exc)) from exc
 
-
     async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
         flow_id_str = str(flow_id)
 
@@ -326,18 +325,16 @@ async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
             message = parse_exception(exc)
             raise HTTPException(status_code=500, detail=message) from exc
 
-
     def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
-        json_data = {
-            "event": event_type,
-            "data": value
-        }
+        json_data = {"event": event_type, "data": value}
         event_id = uuid.uuid4()
         logger.debug(f"sending event {event_id}: {event_type}")
         str_data = json.dumps(json_data) + "\n\n"
-        queue.put_nowait((event_id, str_data.encode('utf-8'), time.time()))
+        queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))
 
-    async def build_vertices(vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
+    async def build_vertices(
+        vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue
+    ) -> None:
         build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
         try:
             await build_task
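For context, send_event frames each event as a single JSON object ({"event": ..., "data": ...}) followed by a blank line, and build_flow streams those frames as application/x-ndjson (see the return statement further down). A hedged client-side sketch, not part of the commit, assuming the httpx library and a caller-supplied URL:

import json

import httpx  # assumed client library; the commit itself does not use httpx


async def consume_build_events(url: str) -> None:
    # Stream the NDJSON response and decode one event per non-empty line.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url) as response:
            async for line in response.aiter_lines():
                if not line.strip():
                    continue  # send_event appends "\n\n", so blank lines separate events
                event = json.loads(line)
                print(event["event"], event.get("data"))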
@@ -347,9 +344,7 @@ async def build_vertices(
 
         vertex_build_response: VertexBuildResponse = build_task.result()
         # send built event or error event
-        send_event("end_vertex", {
-            "build_data": json.loads(vertex_build_response.model_dump_json())},
-                   queue)
+        send_event("end_vertex", {"build_data": json.loads(vertex_build_response.model_dump_json())}, queue)
         await client_consumed_queue.get()
         if vertex_build_response.valid:
             if vertex_build_response.next_vertices_ids:
@@ -364,7 +359,6 @@ async def build_vertices(
                     task.cancel()
                     return
 
-
     async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
         if not data:
             # using another thread since the DB query is I/O bound
@@ -378,11 +372,7 @@ async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
             ids, vertices_to_run, graph = vertices_task.result()
         else:
             ids, vertices_to_run, graph = await build_graph_and_get_order()
-        send_event("vertices_sorted",
-                   {
-                       "ids": ids,
-                       "to_run": vertices_to_run
-                   }, queue)
+        send_event("vertices_sorted", {"ids": ids, "to_run": vertices_to_run}, queue)
         await client_consumed_queue.get()
 
         tasks = []
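The await client_consumed_queue.get() after each send_event is an acknowledgment handshake: the generator does not produce the next event until consume_and_yield reports that the previous one was actually delivered to the client. A self-contained toy version of that pattern (a sketch for illustration, not the commit's code):

import asyncio


async def producer(out: asyncio.Queue, acks: asyncio.Queue) -> None:
    for i in range(3):
        out.put_nowait(i)  # like send_event(...)
        await acks.get()  # block until the consumer confirms delivery


async def consumer(out: asyncio.Queue, acks: asyncio.Queue) -> None:
    for _ in range(3):
        value = await out.get()
        print("delivered", value)  # like `yield value` in consume_and_yield
        acks.put_nowait(value)  # like client_consumed_queue.put_nowait(event_id)


async def main() -> None:
    out, acks = asyncio.Queue(), asyncio.Queue()
    await asyncio.gather(producer(out, acks), consumer(out, acks))


asyncio.run(main())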
@@ -407,21 +397,35 @@ async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue):
                 yield value
                 get_time_yield = time.time()
                 client_consumed_queue.put_nowait(event_id)
-                logger.debug(f"consumed event {str(event_id)} (time in queue, {get_time - put_time:.4f}, client {get_time_yield - get_time:.4f})")
+                logger.debug(
+                    f"consumed event {str(event_id)} (time in queue, {get_time - put_time:.4f}, client {get_time_yield - get_time:.4f})"
+                )
 
     asyncio_queue = asyncio.Queue()
     asyncio_queue_client_consumed = asyncio.Queue()
     main_task = asyncio.create_task(event_generator(asyncio_queue, asyncio_queue_client_consumed))
 
     def on_disconnect():
         logger.debug("Client disconnected, closing tasks")
         main_task.cancel()
-    return DisconnectHandlerStreamingResponse(consume_and_yield(asyncio_queue, asyncio_queue_client_consumed), media_type="application/x-ndjson", on_disconnect=on_disconnect)
 
-class DisconnectHandlerStreamingResponse(StreamingResponse):
+    return DisconnectHandlerStreamingResponse(
+        consume_and_yield(asyncio_queue, asyncio_queue_client_consumed),
+        media_type="application/x-ndjson",
+        on_disconnect=on_disconnect,
+    )
+
 
-    def __init__(self,
-                 content: ContentStream, status_code: int = 200, headers: typing.Mapping[str, str] | None = None,
-                 media_type: str | None = None, background: BackgroundTask | None = None, on_disconnect: Optional[typing.Callable] = None):
+class DisconnectHandlerStreamingResponse(StreamingResponse):
+    def __init__(
+        self,
+        content: ContentStream,
+        status_code: int = 200,
+        headers: typing.Mapping[str, str] | None = None,
+        media_type: str | None = None,
+        background: BackgroundTask | None = None,
+        on_disconnect: Optional[typing.Callable] = None,
+    ):
         super().__init__(content, status_code, headers, media_type, background)
         self.on_disconnect = on_disconnect
@@ -434,7 +438,6 @@ async def listen_for_disconnect(self, receive: Receive) -> None:
             break
 
 
-
 @router.post("/build/{flow_id}/vertices/{vertex_id}")
 async def build_vertex(
     flow_id: uuid.UUID,
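DisconnectHandlerStreamingResponse exists so that a dropped client cancels the producer task (main_task) instead of leaving it writing into a queue nobody reads. The diff shows only fragments of listen_for_disconnect; a sketch of the standard ASGI pattern such an override relies on (assumed body, modeled on Starlette's StreamingResponse, with the on_disconnect hook being the only new part):

from starlette.types import Receive


async def listen_for_disconnect(self, receive: Receive) -> None:
    # Poll the ASGI receive channel until the client drops the connection.
    while True:
        message = await receive()
        if message["type"] == "http.disconnect":
            if self.on_disconnect:
                self.on_disconnect()  # e.g. main_task.cancel() in build_flow
            break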
1 change: 0 additions & 1 deletion src/backend/base/langflow/api/v1/schemas.py
@@ -304,7 +304,6 @@ class InputValueRequest(BaseModel):
     )
 
 
-
 class SimplifiedAPIRequest(BaseModel):
     input_value: Optional[str] = Field(default=None, description="The input value")
     input_type: Optional[InputType] = Field(default="chat", description="The input type")
