Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved pipeline validation #1045

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
15 changes: 12 additions & 3 deletions apps/pipelines/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
class PipelineBuildError(Exception):
pass
"""Exception to raise for errors detected at build time."""

def __init__(self, message: str, node_id: str = None, edge_ids: list[str] = None):
self.message = message
self.node_id = node_id
self.edge_ids = edge_ids

class PipelineNodeAttributeError(Exception):
pass
def to_json(self):
if self.node_id:
return {"node": {self.node_id: {"root": self.message}}, "edge": self.edge_ids}
return {"pipeline": self.message, "edge": self.edge_ids}


class PipelineNodeBuildError(Exception):
"""Exception to raise for errors related to bad parameters or
missing attributes that are detected during at runtime"""

pass


Expand Down
28 changes: 24 additions & 4 deletions apps/pipelines/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import defaultdict
from collections import Counter, defaultdict
from functools import cached_property, partial
from typing import Self

Expand Down Expand Up @@ -104,6 +104,7 @@ def build_runnable(self) -> CompiledStateGraph:
raise PipelineBuildError("There are no nodes in the graph")

self._validate_start_end_nodes()
self._validate_no_parallel_nodes()
if self._check_for_cycles():
raise PipelineBuildError("A cycle was detected")

Expand All @@ -116,10 +117,28 @@ def build_runnable(self) -> CompiledStateGraph:
self._add_nodes_to_graph(state_graph, reachable_nodes)
self._add_edges_to_graph(state_graph, reachable_nodes)

compiled_graph = state_graph.compile()
# compiled_graph.get_graph().print_ascii()
try:
compiled_graph = state_graph.compile()
except ValueError as e:
raise PipelineBuildError(str(e))
return compiled_graph

def _validate_no_parallel_nodes(self):
"""This is a simple check to ensure that no two edges are connected to the same output
which serves as a proxy for parallel nodes."""
outgoing_edges = defaultdict(list)
for edge in self.edges:
outgoing_edges[edge.source].append(edge)

for source, edges in outgoing_edges.items():
handles = Counter(edge.sourceHandle for edge in edges)
handle, count = handles.most_common(1)[0]
if count > 1:
edge_ids = [edge.id for edge in edges if edge.sourceHandle == handle]
raise PipelineBuildError(
"Multiple edges connected to the same output", node_id=source, edge_ids=edge_ids
)

def _check_for_cycles(self):
"""Detect cycles in a directed graph."""
adjacency_list = defaultdict(list)
Expand Down Expand Up @@ -162,7 +181,8 @@ def _add_nodes_to_graph(self, state_graph: StateGraph, nodes: list[Node]):
if self.end_node not in nodes:
raise PipelineBuildError(
f"{EndNode.model_config['json_schema_extra'].label} node is not reachable "
f"from {StartNode.model_config['json_schema_extra'].label} node"
f"from {StartNode.model_config['json_schema_extra'].label} node",
node_id=self.end_node.id,
)

for node in nodes:
Expand Down
12 changes: 11 additions & 1 deletion apps/pipelines/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from apps.custom_actions.form_utils import set_custom_actions
from apps.custom_actions.mixins import CustomActionOperationMixin
from apps.experiments.models import ExperimentSession, VersionsMixin, VersionsObjectManagerMixin
from apps.pipelines.exceptions import PipelineBuildError
from apps.pipelines.executor import patch_executor
from apps.pipelines.flow import Flow, FlowNode, FlowNodeData
from apps.pipelines.logging import PipelineLoggingCallbackHandler
Expand Down Expand Up @@ -132,6 +133,7 @@ def update_nodes_from_data(self) -> None:

def validate(self) -> dict:
"""Validate the pipeline nodes and return a dictionary of errors"""
from apps.pipelines.graph import PipelineGraph
from apps.pipelines.nodes import nodes as pipeline_nodes

errors = {}
Expand All @@ -142,7 +144,15 @@ def validate(self) -> dict:
node_class.model_validate(node.params)
except pydantic.ValidationError as e:
errors[node.flow_id] = {err["loc"][0]: err["msg"] for err in e.errors()}
return errors
if errors:
return {"node": errors}

try:
PipelineGraph.build_runnable_from_pipeline(self)
except PipelineBuildError as e:
return e.to_json()

return {}

@cached_property
def flow_data(self) -> dict:
Expand Down
6 changes: 5 additions & 1 deletion apps/pipelines/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.conf import settings
from django.core.mail import send_mail

from apps.pipelines.exceptions import PipelineBuildError
from apps.pipelines.models import Pipeline


Expand All @@ -20,4 +21,7 @@ def send_email_from_pipeline(recipient_list, subject, message):
@shared_task
def get_response_for_pipeline_test_message(pipeline_id: int, message_text: str, user_id: int):
pipeline = Pipeline.objects.get(id=pipeline_id)
return pipeline.simple_invoke(message_text, user_id=user_id)
try:
return pipeline.simple_invoke(message_text, user_id)
except PipelineBuildError as e:
return {"error": str(e)}
19 changes: 16 additions & 3 deletions assets/javascript/apps/pipeline/BoundaryNode.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import React, { ReactNode } from "react";

import { NodeProps, Position } from "reactflow";
import { NodeProps, NodeToolbar, Position } from "reactflow";
import { NodeData } from "./types/nodeParams";
import { nodeBorderClass } from "./utils";
import usePipelineManagerStore from "./stores/pipelineManagerStore";
import { BaseHandle } from "./nodes/BaseHandle";
import { HelpContent } from "./panel/ComponentHelp";

function BoundaryNode({
nodeProps,
Expand All @@ -16,10 +17,22 @@ function BoundaryNode({
children: ReactNode;
}) {
const { id, selected } = nodeProps;
const nodeErrors = usePipelineManagerStore((state) => state.errors[id]);
const nodeError = usePipelineManagerStore((state) => state.getNodeFieldError(id, "root"));
return (
<>
<div className={nodeBorderClass(nodeErrors, selected)}>
<NodeToolbar position={Position.Top} isVisible={!!nodeError}>
<div className="border border-primary join">
{nodeError && (
<div className="dropdown dropdown-top">
<button tabIndex={0} role="button" className="btn btn-xs join-item">
<i className="fa-solid fa-exclamation-triangle text-warning"></i>
</button>
<HelpContent><p>{nodeError}</p></HelpContent>
</div>
)}
</div>
</NodeToolbar>
<div className={nodeBorderClass(!!nodeError, selected)}>
<div className="px-4">
<div className="m-1 text-lg font-bold text-center">{label}</div>
</div>
Expand Down
67 changes: 36 additions & 31 deletions assets/javascript/apps/pipeline/Page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export default function Page() {
const savePipeline = usePipelineManagerStore((state) => state.savePipeline);
const dirty = usePipelineManagerStore((state) => state.dirty);
const isSaving = usePipelineManagerStore((state) => state.isSaving);
const error = usePipelineManagerStore((state) => state.getPipelineError());
const [name, setName] = useState(currentPipeline?.name);
const [editingName, setEditingName] = useState(false);
const handleNameChange = (event: ChangeEvent<HTMLInputElement>) => {
Expand All @@ -28,39 +29,43 @@ export default function Page() {
<div className="flex h-full overflow-hidden">
<div className="flex flex-1">
<div className="h-full w-full">
<div className="grid grid-cols-2">
<div className="flex gap-2">
{editingName ? (
<>
<input
type="text"
value={name}
onChange={handleNameChange}
className="input input-bordered input-sm"
placeholder="Edit pipeline name"
/>
<button className="btn btn-sm btn-primary" onClick={onClickSave}>
<i className="fa fa-check"></i>
</button>
</>
) : (
<>
<div className="text-lg font-bold">{name}</div>
<button className="btn btn-sm btn-ghost" onClick={() => setEditingName(true)}>
<i className="fa fa-pencil"></i>
</button>
</>
)}
<div className="tooltip tooltip-right" data-tip={dirty ? (isSaving ? "Saving ..." : "Preparing to Save") : "Saved"}>
<button className="btn btn-sm btn-circle no-animation self-center">
{dirty ?
(isSaving ? <div className="loader loader-sm ml-2"></div> :
<i className="fa fa-cloud-upload"></i>)
: <i className="fa fa-check"></i>
}
<div className="flex gap-2">
{editingName ? (
<>
<input
type="text"
value={name}
onChange={handleNameChange}
className="input input-bordered input-sm"
placeholder="Edit pipeline name"
/>
<button className="btn btn-sm btn-primary" onClick={onClickSave}>
<i className="fa fa-check"></i>
</button>
</div>
</>
) : (
<>
<div className="text-lg font-bold">{name}</div>
<button className="btn btn-sm btn-ghost" onClick={() => setEditingName(true)}>
<i className="fa fa-pencil"></i>
</button>
</>
)}
<div className="tooltip tooltip-right" data-tip={dirty ? (isSaving ? "Saving ..." : "Preparing to Save") : "Saved"}>
<button className="btn btn-sm btn-circle no-animation self-center">
{dirty ?
(isSaving ? <span className="loading loading-spinner loading-xs"></span> :
<i className="fa fa-cloud-upload"></i>)
: <i className="fa fa-check"></i>
}
</button>
</div>
{!isSaving && error && (
<div className="content-center">
<i className="fa fa-exclamation-triangle text-red-500 mr-2"></i>
<small className="text-red-500">{error}</small>
</div>
)}
</div>
<div id="react-flow-id" className="relative h-full w-full">
<Pipeline />
Expand Down
15 changes: 12 additions & 3 deletions assets/javascript/apps/pipeline/PipelineNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ export function PipelineNode(nodeProps: NodeProps<NodeData>) {
const openEditorForNode = useEditorStore((state) => state.openEditorForNode)
const setNode = usePipelineStore((state) => state.setNode);
const deleteNode = usePipelineStore((state) => state.deleteNode);
const nodeErrors = usePipelineManagerStore((state) => state.errors[id]);
const hasErrors = usePipelineManagerStore((state) => state.nodeHasErrors(id));
const nodeError = usePipelineManagerStore((state) => state.getNodeFieldError(id, "root"));
const {nodeSchemas} = getCachedData();
const nodeSchema = nodeSchemas.get(data.type)!;
const schemaProperties = Object.getOwnPropertyNames(nodeSchema.properties);
Expand Down Expand Up @@ -45,7 +46,7 @@ export function PipelineNode(nodeProps: NodeProps<NodeData>) {

return (
<>
<NodeToolbar position={Position.Top}>
<NodeToolbar position={Position.Top} isVisible={hasErrors}>
<div className="border border-primary join">
<button
className="btn btn-xs join-item"
Expand All @@ -65,9 +66,17 @@ export function PipelineNode(nodeProps: NodeProps<NodeData>) {
<HelpContent><p>{nodeSchema.description}</p></HelpContent>
</div>
)}
{nodeError && (
<div className="dropdown dropdown-top">
<button tabIndex={0} role="button" className="btn btn-xs join-item">
<i className="fa-solid fa-exclamation-triangle text-warning"></i>
</button>
<HelpContent><p>{nodeError}</p></HelpContent>
</div>
)}
</div>
</NodeToolbar>
<div className={nodeBorderClass(nodeErrors, selected)}>
<div className={nodeBorderClass(hasErrors, selected)}>
<div className="m-1 text-lg font-bold text-center">{nodeSchema["ui:label"]}</div>

<NodeInput />
Expand Down
6 changes: 3 additions & 3 deletions assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ export const getInputWidget = (params: InputWidgetParams) => {
return
}

const getFieldError = usePipelineManagerStore((state) => state.getFieldError);
const getNodeFieldError = usePipelineManagerStore((state) => state.getNodeFieldError);
const widgetOrType = params.schema["ui:widget"] || params.schema.type;
if (widgetOrType == 'none') {
return <></>;
}

const Widget = getWidget(widgetOrType)
let fieldError = getFieldError(params.id, params.name);
let fieldError = getNodeFieldError(params.id, params.name);
const paramValue = params.params[params.name];
if (params.required && (paramValue === null || paramValue === undefined)) {
fieldError = "This field is required";
Expand All @@ -82,7 +82,7 @@ export const getInputWidget = (params: InputWidgetParams) => {
schema={params.schema}
nodeParams={params.params}
required={params.required}
getFieldError={getFieldError}
getNodeFieldError={getNodeFieldError}
/>
)
};
4 changes: 2 additions & 2 deletions assets/javascript/apps/pipeline/nodes/widgets.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ interface WidgetParams {
schema: PropertySchema
nodeParams: NodeParams
required: boolean,
getFieldError: (nodeId: string, fieldName: string) => string | undefined;
getNodeFieldError: (nodeId: string, fieldName: string) => string | undefined;
}

function DefaultWidget(props: WidgetParams) {
Expand Down Expand Up @@ -596,7 +596,7 @@ export function HistoryTypeWidget(props: WidgetParams) {
const options = getSelectOptions(props.schema);
const historyType = concatenate(props.paramValue);
const historyName = concatenate(props.nodeParams["history_name"]);
const historyNameError = props.getFieldError(props.nodeId, "history_name");
const historyNameError = props.getNodeFieldError(props.nodeId, "history_name");
return (
<>
<div className="flex join">
Expand Down
12 changes: 8 additions & 4 deletions assets/javascript/apps/pipeline/panel/TestMessageBox.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,15 @@ export default function TestMessageBox({
response.result &&
typeof response.result !== "string"
) {
// The task finished succesfully and we receive the response
// The task finished successfully and we receive the response
const result = response.result;
setResponseMessage(result.messages[result.messages.length - 1]);
for (const [nodeId, nodeOutput] of Object.entries(result.outputs)) {
setEdgeLabel(nodeId, nodeOutput.output_handle, nodeOutput.message);
if (result.error) {
setErrorMessage(result.error);
} else {
setResponseMessage(result.messages[result.messages.length - 1]);
for (const [nodeId, nodeOutput] of Object.entries(result.outputs)) {
setEdgeLabel(nodeId, nodeOutput.output_handle, nodeOutput.message);
}
}
setLoading(false);
polling = false;
Expand Down
Loading
Loading