From a2135a060b52a08da46d18e04b5f0313c66d6da5 Mon Sep 17 00:00:00 2001 From: Joey Ballentine <34788790+joeyballentine@users.noreply.github.com> Date: Mon, 15 Jul 2024 16:17:50 -0400 Subject: [PATCH] Add broadcasts for sequences (#2971) * start of allowing generators to broadcast length data * allow output lengths to be set * Fixed typing and broadcasting issue for sequence types (#2972) * Fixed typing and broadcasting issue for sequence types * Fixed errors * Sequence item types for Load Video (#2974) * Sequence item types for Load Video * Ignore type error * Don't cache individually run generators and delete any caches on new run --------- Co-authored-by: Michael Schmidt --- backend/src/api/__init__.py | 1 + backend/src/api/api.py | 64 ----------- backend/src/api/iter.py | 72 +++++++++++++ backend/src/api/node_data.py | 36 ++++++- backend/src/events.py | 8 +- .../image/batch_processing/load_images.py | 6 +- .../image/video_frames/load_video.py | 24 +++-- backend/src/process.py | 102 +++++++++++++----- backend/src/server.py | 43 +++++++- src/common/Backend.ts | 2 + src/common/common-types.ts | 9 +- src/common/nodes/TypeState.ts | 45 +++++--- src/common/types/function.ts | 21 ++-- src/main/cli/run.ts | 11 +- .../NodeDocumentation/NodeExample.tsx | 1 + src/renderer/components/node/NodeOutputs.tsx | 17 ++- src/renderer/contexts/ExecutionContext.tsx | 24 ++++- src/renderer/contexts/GlobalNodeState.tsx | 76 +++++-------- src/renderer/hooks/useAutomaticFeatures.ts | 5 +- src/renderer/hooks/useOutputDataStore.ts | 9 +- src/renderer/hooks/useTypeMap.ts | 60 +++++++++++ 21 files changed, 436 insertions(+), 200 deletions(-) create mode 100644 backend/src/api/iter.py create mode 100644 src/renderer/hooks/useTypeMap.ts diff --git a/backend/src/api/__init__.py b/backend/src/api/__init__.py index 548a1a7642..89c30d4f54 100644 --- a/backend/src/api/__init__.py +++ b/backend/src/api/__init__.py @@ -1,6 +1,7 @@ from .api import * from .group import * from .input import * +from .iter import * from .lazy import * from .node_context import * from .node_data import * diff --git a/backend/src/api/api.py b/backend/src/api/api.py index a6cce94d27..08a4db4663 100644 --- a/backend/src/api/api.py +++ b/backend/src/api/api.py @@ -7,7 +7,6 @@ Any, Awaitable, Callable, - Generic, Iterable, TypeVar, ) @@ -504,66 +503,3 @@ def add_package( dependencies=dependencies or [], ) ) - - -I = TypeVar("I") -L = TypeVar("L") - - -@dataclass -class Generator(Generic[I]): - supplier: Callable[[], Iterable[I | Exception]] - expected_length: int - fail_fast: bool = True - - def with_fail_fast(self, fail_fast: bool) -> Generator[I]: - return Generator(self.supplier, self.expected_length, fail_fast=fail_fast) - - @staticmethod - def from_iter( - supplier: Callable[[], Iterable[I | Exception]], expected_length: int - ) -> Generator[I]: - return Generator(supplier, expected_length) - - @staticmethod - def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]: - """ - Creates a new generator from a list that is mapped using the given - function. The iterable will be equivalent to `map(map_fn, l)`. - """ - - def supplier(): - for i, x in enumerate(l): - try: - yield map_fn(x, i) - except Exception as e: - yield e - - return Generator(supplier, len(l)) - - @staticmethod - def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]: - """ - Creates a new generator the given number of items where each item is - lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`. 
- """ - assert count >= 0 - - def supplier(): - for i in range(count): - try: - yield map_fn(i) - except Exception as e: - yield e - - return Generator(supplier, count) - - -N = TypeVar("N") -R = TypeVar("R") - - -@dataclass -class Collector(Generic[N, R]): - on_iterate: Callable[[N], None] - on_complete: Callable[[], R] diff --git a/backend/src/api/iter.py b/backend/src/api/iter.py new file mode 100644 index 0000000000..1add2e0ccb --- /dev/null +++ b/backend/src/api/iter.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Generic, Iterable, TypeVar + +I = TypeVar("I") +L = TypeVar("L") + + +@dataclass +class Generator(Generic[I]): + supplier: Callable[[], Iterable[I | Exception]] + expected_length: int + fail_fast: bool = True + metadata: object | None = None + + def with_fail_fast(self, fail_fast: bool): + self.fail_fast = fail_fast + return self + + def with_metadata(self, metadata: object): + self.metadata = metadata + return self + + @staticmethod + def from_iter( + supplier: Callable[[], Iterable[I | Exception]], expected_length: int + ) -> Generator[I]: + return Generator(supplier, expected_length) + + @staticmethod + def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]: + """ + Creates a new generator from a list that is mapped using the given + function. The iterable will be equivalent to `map(map_fn, l)`. + """ + + def supplier(): + for i, x in enumerate(l): + try: + yield map_fn(x, i) + except Exception as e: + yield e + + return Generator(supplier, len(l)) + + @staticmethod + def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]: + """ + Creates a new generator the given number of items where each item is + lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`. + """ + assert count >= 0 + + def supplier(): + for i in range(count): + try: + yield map_fn(i) + except Exception as e: + yield e + + return Generator(supplier, count) + + +N = TypeVar("N") +R = TypeVar("R") + + +@dataclass +class Collector(Generic[N, R]): + on_iterate: Callable[[N], None] + on_complete: Callable[[], R] diff --git a/backend/src/api/node_data.py b/backend/src/api/node_data.py index a42c706129..475ce9fd99 100644 --- a/backend/src/api/node_data.py +++ b/backend/src/api/node_data.py @@ -2,12 +2,13 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Mapping +from typing import Any, Callable, Generic, Mapping, Protocol, TypeVar import navi from .group import NestedIdGroup from .input import BaseInput +from .iter import Generator from .output import BaseOutput from .types import ( FeatureId, @@ -42,6 +43,13 @@ def to_dict(self): } +M_co = TypeVar("M_co", covariant=True) + + +class AnyConstructor(Protocol, Generic[M_co]): + def __call__(self, *args: Any, **kwargs: Any) -> M_co: ... 
+ + class IteratorOutputInfo: def __init__( self, @@ -56,6 +64,20 @@ def __init__( ) self.length_type: navi.ExpressionJson = length_type + self._metadata_constructor: Any | None = None + self._item_types_fn: ( + Callable[[Any], Mapping[OutputId, navi.ExpressionJson]] | None + ) = None + + def with_item_types( + self, + class_: AnyConstructor[M_co], + fn: Callable[[M_co], Mapping[OutputId, navi.ExpressionJson]], + ): + self._metadata_constructor = class_ + self._item_types_fn = fn + return self + def to_dict(self): return { "id": self.id, @@ -63,6 +85,18 @@ def to_dict(self): "sequenceType": navi.named("Sequence", {"length": self.length_type}), } + def get_broadcast_sequence_type(self, generator: Generator) -> navi.ExpressionJson: + return navi.named("Sequence", {"length": generator.expected_length}) + + def get_broadcast_item_types( + self, generator: Generator + ) -> Mapping[OutputId, navi.ExpressionJson]: + if self._item_types_fn is not None and self._metadata_constructor is not None: + metadata = generator.metadata + if isinstance(metadata, self._metadata_constructor): + return self._item_types_fn(metadata) + return {} + class KeyInfo: def __init__(self, data: dict[str, Any]) -> None: diff --git a/backend/src/events.py b/backend/src/events.py index 6e978c2950..5c20302ff4 100644 --- a/backend/src/events.py +++ b/backend/src/events.py @@ -5,7 +5,8 @@ from dataclasses import dataclass from typing import Dict, Literal, TypedDict, Union -from api import ErrorValue, InputId, NodeId, OutputId +import navi +from api import BroadcastData, ErrorValue, InputId, IterOutputId, NodeId, OutputId # General events @@ -87,8 +88,9 @@ class NodeProgressUpdateEvent(TypedDict): class NodeBroadcastData(TypedDict): nodeId: NodeId - data: dict[OutputId, object] - types: dict[OutputId, object] + data: dict[OutputId, BroadcastData | None] + types: dict[OutputId, navi.ExpressionJson | None] + sequenceTypes: dict[IterOutputId, navi.ExpressionJson] | None class NodeBroadcastEvent(TypedDict): diff --git a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py index 4ab28ff859..3c4c000164 100644 --- a/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py +++ b/backend/src/packages/chaiNNer_standard/image/batch_processing/load_images.py @@ -94,16 +94,14 @@ def list_glob(directory: Path, globexpr: str, ext_filter: list[str]) -> list[Pat DirectoryOutput("Directory", output_type="Input0"), TextOutput("Subdirectory Path"), TextOutput("Name"), - NumberOutput( - "Index", - output_type="if Input4 { min(uint, Input5 - 1) } else { uint }", - ), + NumberOutput("Index", output_type="min(uint, max(0, IterOutput0.length - 1))"), ], iterator_outputs=IteratorOutputInfo( outputs=[0, 2, 3, 4], length_type="if Input4 { min(uint, Input5) } else { uint }", ), kind="generator", + side_effects=True, ) def load_images_node( directory: Path, diff --git a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py index b546d545fb..cac1484837 100644 --- a/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py +++ b/backend/src/packages/chaiNNer_standard/image/video_frames/load_video.py @@ -5,10 +5,11 @@ import numpy as np -from api import Generator, IteratorOutputInfo, NodeContext +import navi +from api import Generator, IteratorOutputInfo, NodeContext, OutputId from nodes.groups import Condition, if_group from 
nodes.impl.ffmpeg import FFMpegEnv -from nodes.impl.video import VideoLoader +from nodes.impl.video import VideoLoader, VideoMetadata from nodes.properties.inputs import BoolInput, NumberInput, VideoFileInput from nodes.properties.outputs import ( AudioStreamOutput, @@ -22,6 +23,14 @@ from .. import video_frames_group +def get_item_types(metadata: VideoMetadata): + return { + OutputId(0): navi.Image( + width=metadata.width, height=metadata.height, channels=3 + ), + } + + @video_frames_group.register( schema_id="chainner:image:load_video", name="Load Video", @@ -46,8 +55,8 @@ outputs=[ ImageOutput("Frame", channels=3), NumberOutput( - "Frame Index", - output_type="if Input1 { min(uint, Input2 - 1) } else { uint }", + "Index", + output_type="min(uint, max(0, IterOutput0.length - 1))", ).with_docs("A counter that starts at 0 and increments by 1 for each frame."), DirectoryOutput("Video Directory", of_input=0), FileNameOutput("Name", of_input=0), @@ -56,8 +65,9 @@ ], iterator_outputs=IteratorOutputInfo( outputs=[0, 1], length_type="if Input1 { min(uint, Input2) } else { uint }" - ), + ).with_item_types(VideoMetadata, get_item_types), # type: ignore node_context=True, + side_effects=True, kind="generator", ) def load_video_node( @@ -83,7 +93,9 @@ def iterator(): break return ( - Generator.from_iter(supplier=iterator, expected_length=frame_count), + Generator.from_iter( + supplier=iterator, expected_length=frame_count + ).with_metadata(loader.metadata), video_dir, video_name, loader.metadata.fps, diff --git a/backend/src/process.py b/backend/src/process.py index 4b0b427a9f..8cc9dc308c 100644 --- a/backend/src/process.py +++ b/backend/src/process.py @@ -13,13 +13,17 @@ from sanic.log import logger +import navi from api import ( BaseInput, BaseOutput, + BroadcastData, Collector, ExecutionOptions, Generator, InputId, + IteratorOutputInfo, + IterOutputId, Lazy, NodeContext, NodeData, @@ -31,7 +35,7 @@ from chain.cache import CacheStrategy, OutputCache, StaticCaching, get_cache_strategies from chain.chain import Chain, CollectorNode, FunctionNode, GeneratorNode, Node from chain.input import EdgeInput, Input, InputMap -from events import EventConsumer, InputsDict +from events import EventConsumer, InputsDict, NodeBroadcastData from progress_controller import Aborted, ProgressController, ProgressToken from util import combine_sets, timed_supplier @@ -143,7 +147,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut assert isinstance( raw_output, Generator ), "Expected the output to be a generator" - return GeneratorOutput(generator=raw_output, partial_output=partial) + return GeneratorOutput( + info=generator_output, + generator=raw_output, + partial_output=partial, + ) assert l > len(generator_output.outputs) assert isinstance(raw_output, (tuple, list)) @@ -159,7 +167,11 @@ def enforce_generator_output(raw_output: object, node: NodeData) -> GeneratorOut if o.id not in generator_output.outputs: partial[i] = o.enforce(rest.pop(0)) - return GeneratorOutput(generator=iterator, partial_output=partial) + return GeneratorOutput( + info=generator_output, + generator=iterator, + partial_output=partial, + ) def run_node( @@ -287,8 +299,8 @@ def add(self): def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]): - data: dict[OutputId, object] = {} - types: dict[OutputId, object] = {} + data: dict[OutputId, BroadcastData | None] = {} + types: dict[OutputId, navi.ExpressionJson | None] = {} for index, node_output in enumerate(node_outputs): try: value = 
output[index] @@ -300,6 +312,21 @@ def compute_broadcast(output: Output, node_outputs: Iterable[BaseOutput]): return data, types +def compute_sequence_broadcast( + generators: Iterable[Generator], node_iter_outputs: Iterable[IteratorOutputInfo] +): + sequence_types: dict[IterOutputId, navi.ExpressionJson] = {} + item_types: dict[OutputId, navi.ExpressionJson] = {} + for g, iter_output in zip(generators, node_iter_outputs): + try: + sequence_types[iter_output.id] = iter_output.get_broadcast_sequence_type(g) + for output_id, type in iter_output.get_broadcast_item_types(g).items(): + item_types[output_id] = type + except Exception as e: + logger.error(f"Error broadcasting output: {e}") + return sequence_types, item_types + + class NodeExecutionError(Exception): def __init__( self, @@ -321,6 +348,7 @@ class RegularOutput: @dataclass(frozen=True) class GeneratorOutput: + info: IteratorOutputInfo generator: Generator partial_output: Output @@ -424,15 +452,18 @@ def __init__( self._storage_dir = storage_dir - async def process(self, node_id: NodeId) -> NodeOutput | CollectorOutput: + async def process( + self, node_id: NodeId, perform_cache: bool = True + ) -> NodeOutput | CollectorOutput: # Return cached output value from an already-run node if that cached output exists - cached = self.node_cache.get(node_id) - if cached is not None: - return cached + if perform_cache: + cached = self.node_cache.get(node_id) + if cached is not None: + return cached node = self.chain.nodes[node_id] try: - return await self.__process(node) + return await self.__process(node, perform_cache) except Aborted: raise except NodeExecutionError: @@ -450,14 +481,16 @@ async def process_regular_node(self, node: FunctionNode) -> RegularOutput: assert isinstance(result, RegularOutput) return result - async def process_generator_node(self, node: GeneratorNode) -> GeneratorOutput: + async def process_generator_node( + self, node: GeneratorNode, perform_cache: bool = True + ) -> GeneratorOutput: """ Processes the given iterator node. This will **not** iterate the returned generator. Only `node-start` and `node-broadcast` events will be sent. """ - result = await self.process(node.id) + result = await self.process(node.id, perform_cache) assert isinstance(result, GeneratorOutput) return result @@ -571,7 +604,9 @@ def __get_node_context(self, node: Node) -> _ExecutorNodeContext: return context - async def __process(self, node: Node) -> NodeOutput | CollectorOutput: + async def __process( + self, node: Node, perform_cache: bool = True + ) -> NodeOutput | CollectorOutput: """ Process a single node. 
@@ -617,11 +652,15 @@ def get_lazy_evaluation_time(): await self.__send_node_broadcast(node, output.output) self.__send_node_finish(node, execution_time) elif isinstance(output, GeneratorOutput): - await self.__send_node_broadcast(node, output.partial_output) + await self.__send_node_broadcast( + node, + output.partial_output, + generators=[output.generator], + ) # TODO: execution time # Cache the output of the node - if not isinstance(output, CollectorOutput): + if perform_cache and not isinstance(output, CollectorOutput): self.node_cache.set(node.id, output, self.cache_strategy[node.id]) await self.progress.suspend() @@ -1011,12 +1050,19 @@ async def __send_node_broadcast( self, node: Node, output: Output, + generators: Iterable[Generator] | None = None, ): def compute_broadcast_data(): if self.progress.aborted: # abort the broadcast if the chain was aborted return None - return compute_broadcast(output, node.data.outputs) + foo = compute_broadcast(output, node.data.outputs) + if generators is None: + return (*foo, {}, {}) + return ( + *foo, + *compute_sequence_broadcast(generators, node.data.iterable_outputs), + ) async def send_broadcast(): # TODO: Add the time it takes to compute the broadcast data to the execution time @@ -1024,17 +1070,19 @@ async def send_broadcast(): if result is None or self.progress.aborted: return - data, types = result - self.queue.put( - { - "event": "node-broadcast", - "data": { - "nodeId": node.id, - "data": data, - "types": types, - }, - } - ) + data, types, sequence_types, item_types = result + + # assign item types + for output_id, type in item_types.items(): + types[output_id] = type + + evant_data: NodeBroadcastData = { + "nodeId": node.id, + "data": data, + "types": types, + "sequenceTypes": sequence_types, + } + self.queue.put({"event": "node-broadcast", "data": evant_data}) # Only broadcast the output if the node has outputs if self.send_broadcast_data and len(node.data.outputs) > 0: diff --git a/backend/src/server.py b/backend/src/server.py index 1a52bb2a9a..6c5148e11c 100644 --- a/backend/src/server.py +++ b/backend/src/server.py @@ -28,7 +28,7 @@ NodeId, ) from chain.cache import OutputCache -from chain.chain import Chain, FunctionNode +from chain.chain import Chain, FunctionNode, GeneratorNode from chain.json import JsonNode, parse_json from chain.optimize import optimize from dependencies.store import installed_packages @@ -207,6 +207,16 @@ async def run(request: Request): chain = parse_json(full_data["data"]) optimize(chain) + # Remove all Generator values from the cache for each new run + # Otherwise, their state will cause them to resume from where they left off + schema_data = api.registry.nodes + for node in chain.nodes.values(): + node_schema = schema_data.get(node.schema_id) + if node_schema: + node_data, _ = node_schema + if node_data.kind == "generator" and ctx.cache.get(node.id): + ctx.cache.pop(node.id) + logger.info("Running new executor...") executor = Executor( id=executor_id, @@ -287,7 +297,22 @@ async def run_individual(request: Request): node_id = full_data["id"] ctx.cache.pop(node_id, None) - node = FunctionNode(node_id, full_data["schemaId"]) + schema_data = api.registry.nodes.get(full_data["schemaId"]) + + if schema_data is None: + raise ValueError( + f"Invalid node {full_data['schemaId']} attempted to run individually" + ) + + node_data, _ = schema_data + if node_data.kind == "generator": + node = GeneratorNode(node_id, full_data["schemaId"]) + elif node_data.kind == "regularNode": + node = FunctionNode(node_id, 
full_data["schemaId"]) + else: + raise ValueError( + f"Invalid node kind {node_data.kind} attempted to run individually" + ) chain = Chain() chain.add_node(node) @@ -319,8 +344,18 @@ async def run_individual(request: Request): old_executor.kill() ctx.individual_executors[execution_id] = executor - output = await executor.process_regular_node(node) - ctx.cache[node_id] = output + if node_data.kind == "generator": + assert isinstance(node, GeneratorNode) + output = await executor.process_generator_node(node) + elif node_data.kind == "regularNode": + assert isinstance(node, FunctionNode) + output = await executor.process_regular_node(node) + else: + raise ValueError( + f"Invalid node kind {node_data.kind} attempted to run individually" + ) + if not isinstance(node, GeneratorNode): + ctx.cache[node_id] = output except Aborted: pass finally: diff --git a/src/common/Backend.ts b/src/common/Backend.ts index 7b37d70f2e..1c9dcafc87 100644 --- a/src/common/Backend.ts +++ b/src/common/Backend.ts @@ -6,6 +6,7 @@ import { FeatureState, InputId, InputValue, + IterOutputTypes, NodeSchema, OutputData, OutputTypes, @@ -345,6 +346,7 @@ export interface BackendEventMap { nodeId: string; data: OutputData; types: OutputTypes; + sequenceTypes?: IterOutputTypes | null; }; 'backend-status': { message: string; diff --git a/src/common/common-types.ts b/src/common/common-types.ts index ea18447aba..4322c24f81 100644 --- a/src/common/common-types.ts +++ b/src/common/common-types.ts @@ -16,8 +16,8 @@ export interface Size { export type SchemaId = string & { readonly __schemaId: never }; export type InputId = number & { readonly __inputId: never }; export type OutputId = number & { readonly __outputId: never }; -export type IteratorInputId = number & { readonly __iteratorInputId: never }; -export type IteratorOutputId = number & { readonly __iteratorOutputId: never }; +export type IterInputId = number & { readonly __iteratorInputId: never }; +export type IterOutputId = number & { readonly __iteratorOutputId: never }; export type GroupId = number & { readonly __groupId: never }; export type PackageId = string & { readonly __packageId: never }; export type FeatureId = string & { readonly __featureId: never }; @@ -279,14 +279,15 @@ export type InputHeight = Readonly>; export type OutputData = Readonly>; export type OutputHeight = Readonly>; export type OutputTypes = Readonly>>; +export type IterOutputTypes = Readonly>>; export interface IteratorInputInfo { - readonly id: IteratorInputId; + readonly id: IterInputId; readonly inputs: readonly InputId[]; readonly sequenceType: ExpressionJson; } export interface IteratorOutputInfo { - readonly id: IteratorOutputId; + readonly id: IterOutputId; readonly outputs: readonly OutputId[]; readonly sequenceType: ExpressionJson; } diff --git a/src/common/nodes/TypeState.ts b/src/common/nodes/TypeState.ts index afdb26fa48..06612f0d4c 100644 --- a/src/common/nodes/TypeState.ts +++ b/src/common/nodes/TypeState.ts @@ -1,5 +1,5 @@ import { EvaluationError, NonNeverType, Type, isSameType } from '@chainner/navi'; -import { EdgeData, InputId, NodeData, OutputId, SchemaId } from '../common-types'; +import { EdgeData, InputId, IterOutputId, NodeData, OutputId, SchemaId } from '../common-types'; import { log } from '../log'; import { PassthroughMap } from '../PassthroughMap'; import { @@ -22,23 +22,38 @@ const assignmentErrorEquals = ( isSameType(a.inputType, b.inputType) ); }; +const mapEqual = >( + a: ReadonlyMap, + b: ReadonlyMap, + eq: (a: V, b: V) => boolean +): boolean => { + if (a.size !== 
b.size) return false; + for (const [key, value] of a) { + const otherValue = b.get(key); + if (otherValue === undefined || !eq(value, otherValue)) return false; + } + return true; +}; +const arrayEqual = ( + a: ReadonlyArray, + b: ReadonlyArray, + eq: (a: T, b: T) => boolean +): boolean => { + if (a.length !== b.length) return false; + for (let i = 0; i < a.length; i += 1) { + if (!eq(a[i], b[i])) return false; + } + return true; +}; const instanceEqual = (a: FunctionInstance, b: FunctionInstance): boolean => { if (a.definition !== b.definition) return false; - for (const [key, value] of a.inputs) { - const otherValue = b.inputs.get(key); - if (!otherValue || !isSameType(value, otherValue)) return false; - } - - for (const [key, value] of a.outputs) { - const otherValue = b.outputs.get(key); - if (!otherValue || !isSameType(value, otherValue)) return false; - } + if (!mapEqual(a.inputs, b.inputs, isSameType)) return false; + if (!mapEqual(a.inputSequence, b.inputSequence, isSameType)) return false; + if (!mapEqual(a.outputs, b.outputs, isSameType)) return false; + if (!mapEqual(a.outputSequence, b.outputSequence, isSameType)) return false; - if (a.inputErrors.length !== b.inputErrors.length) return false; - for (let i = 0; i < a.inputErrors.length; i += 1) { - if (!assignmentErrorEquals(a.inputErrors[i], b.inputErrors[i])) return false; - } + if (!arrayEqual(a.inputErrors, b.inputErrors, assignmentErrorEquals)) return false; return true; }; @@ -66,6 +81,7 @@ export class TypeState { nodesMap: ReadonlyMap>, rawEdges: readonly Edge[], outputNarrowing: ReadonlyMap>, + sequenceOutputNarrowing: ReadonlyMap>, functionDefinitions: ReadonlyMap, passthrough?: PassthroughMap, previousTypeState?: TypeState @@ -127,6 +143,7 @@ export class TypeState { return undefined; }, outputNarrowing.get(n.id), + sequenceOutputNarrowing.get(n.id), passthroughInfo ); } catch (error) { diff --git a/src/common/types/function.ts b/src/common/types/function.ts index bc4b28e866..4c2dc9a082 100644 --- a/src/common/types/function.ts +++ b/src/common/types/function.ts @@ -19,9 +19,9 @@ import { Input, InputId, InputSchemaValue, - IteratorInputId, + IterInputId, + IterOutputId, IteratorInputInfo, - IteratorOutputId, IteratorOutputInfo, NodeSchema, Output, @@ -56,9 +56,9 @@ const getParamRefs = ( }; export const getInputParamName = (inputId: InputId) => `Input${inputId}` as const; -export const getIterInputParamName = (id: IteratorInputId) => `IterInput${id}` as const; +export const getIterInputParamName = (id: IterInputId) => `IterInput${id}` as const; export const getOutputParamName = (outputId: OutputId) => `Output${outputId}` as const; -export const getIterOutputParamName = (id: IteratorOutputId) => `IterOutput${id}` as const; +export const getIterOutputParamName = (id: IterOutputId) => `IterOutput${id}` as const; interface BaseDesc
<P extends string>
{ readonly type: P; @@ -592,6 +592,7 @@ export class FunctionInstance { definition: FunctionDefinition, partialInputs: (inputId: InputId) => NonNeverType | undefined, outputNarrowing: ReadonlyMap = EMPTY_MAP, + sequenceOutputNarrowing: ReadonlyMap = EMPTY_MAP, passthrough?: PassthroughInfo ): FunctionInstance { const inputErrors: FunctionInputAssignmentError[] = []; @@ -689,11 +690,12 @@ export class FunctionInstance { type = item.default; } - if (item.type === 'Output') { - const narrowing = outputNarrowing.get(item.output.id); - if (narrowing) { - type = intersect(narrowing, type); - } + const narrowing = + item.type === 'Output' + ? outputNarrowing.get(item.output.id) + : sequenceOutputNarrowing.get(item.iterOutput.id); + if (narrowing) { + type = intersect(narrowing, type); } if (type.type === 'never') { @@ -705,6 +707,7 @@ export class FunctionInstance { if (item.type === 'Output') { outputs.set(item.output.id, type); + scope.assignParameter(getOutputParamName(item.output.id), type); } else { for (const id of item.iterOutput.outputs) { outputLengths.set(id, type); diff --git a/src/main/cli/run.ts b/src/main/cli/run.ts index 4b77a05dde..f33789159c 100644 --- a/src/main/cli/run.ts +++ b/src/main/cli/run.ts @@ -16,7 +16,7 @@ import { SchemaMap } from '../../common/SchemaMap'; import { ChainnerSettings } from '../../common/settings/settings'; import { FunctionDefinition } from '../../common/types/function'; import { ProgressController, ProgressMonitor, ProgressToken } from '../../common/ui/progress'; -import { assertNever, delay } from '../../common/util'; +import { EMPTY_MAP, assertNever, delay } from '../../common/util'; import { RunArguments } from '../arguments'; import { BackendProcess } from '../backend/process'; import { setupBackend } from '../backend/setup'; @@ -143,7 +143,14 @@ const ensureStaticCorrectness = ( } const byId = new Map(nodes.map((n) => [n.id, n])); - const typeState = TypeState.create(byId, edges, new Map(), functionDefinitions, passthrough); + const typeState = TypeState.create( + byId, + edges, + EMPTY_MAP, + EMPTY_MAP, + functionDefinitions, + passthrough + ); const chainLineage = new ChainLineage(schemata, nodes, edges); const invalidNodes = nodes.flatMap((node) => { diff --git a/src/renderer/components/NodeDocumentation/NodeExample.tsx b/src/renderer/components/NodeDocumentation/NodeExample.tsx index 0be7a3a336..906e36a700 100644 --- a/src/renderer/components/NodeDocumentation/NodeExample.tsx +++ b/src/renderer/components/NodeDocumentation/NodeExample.tsx @@ -119,6 +119,7 @@ export const NodeExample = memo(({ selectedSchema }: NodeExampleProps) => { new Map([[nodeId, node]]), EMPTY_ARRAY, EMPTY_MAP, + EMPTY_MAP, functionDefinitions, PassthroughMap.EMPTY ); diff --git a/src/renderer/components/node/NodeOutputs.tsx b/src/renderer/components/node/NodeOutputs.tsx index 63de6ad3bb..ef168c95ff 100644 --- a/src/renderer/components/node/NodeOutputs.tsx +++ b/src/renderer/components/node/NodeOutputs.tsx @@ -58,7 +58,7 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { } = nodeState; const { functionDefinitions } = useContext(BackendContext); - const { setManualOutputType } = useContext(GlobalContext); + const { setManualOutputType, setManualSequenceOutputType } = useContext(GlobalContext); const outputDataEntry = useContextSelector(GlobalVolatileContext, (c) => c.outputDataMap.get(id) ); @@ -80,6 +80,7 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { ); const currentTypes = stale ? 
undefined : outputDataEntry?.types; + const currentSequenceTypes = stale ? undefined : outputDataEntry?.sequenceTypes; const { isAutomatic } = useAutomaticFeatures(id, schemaId); @@ -89,8 +90,20 @@ export const NodeOutputs = memo(({ nodeState, animated }: NodeOutputProps) => { const type = evalExpression(currentTypes?.[output.id]); setManualOutputType(id, output.id, type); } + for (const iterOutput of schema.iteratorOutputs) { + const type = evalExpression(currentSequenceTypes?.[iterOutput.id]); + setManualSequenceOutputType(id, iterOutput.id, type); + } } - }, [id, currentTypes, schema, setManualOutputType, isAutomatic]); + }, [ + id, + currentTypes, + currentSequenceTypes, + schema, + setManualOutputType, + setManualSequenceOutputType, + isAutomatic, + ]); const isCollapsed = useIsCollapsedNode(); if (isCollapsed) { diff --git a/src/renderer/contexts/ExecutionContext.tsx b/src/renderer/contexts/ExecutionContext.tsx index 5a665578a7..828e522be9 100644 --- a/src/renderer/contexts/ExecutionContext.tsx +++ b/src/renderer/contexts/ExecutionContext.tsx @@ -167,7 +167,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> outputDataActions, getInputHash, setManualOutputType, - clearManualOutputTypes, + clearManualTypes, } = useContext(GlobalContext); const { schemata, @@ -262,6 +262,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> let broadcastData; let types; let progress; + let sequenceTypes; for (const { type, data } of events) { if (type === 'node-start') { @@ -272,6 +273,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> } else if (type === 'node-broadcast') { broadcastData = data.data; types = data.types; + sequenceTypes = data.sequenceTypes ?? undefined; } else { progress = data; } @@ -281,14 +283,26 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> setNodeStatus(executionStatus, [nodeId]); } - if (executionTime !== undefined || broadcastData !== undefined || types !== undefined) { + if ( + executionTime !== undefined || + broadcastData !== undefined || + types !== undefined || + sequenceTypes !== undefined + ) { // TODO: This is incorrect. The inputs of the node might have changed since // the chain started running. However, sending the then current input hashes // of the chain to the backend along with the rest of its data and then making // the backend send us those hashes is incorrect too because of iterators, I // think. 
const inputHash = getInputHash(nodeId); - outputDataActions.set(nodeId, executionTime, inputHash, broadcastData, types); + outputDataActions.set( + nodeId, + executionTime, + inputHash, + broadcastData, + types, + sequenceTypes + ); } if (progress) { @@ -492,7 +506,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> nodeEventBacklog.processAll(); clearNodeStatusMap(); setStatus(ExecutionStatus.READY); - clearManualOutputTypes(iteratorNodeIds); + clearManualTypes(iteratorNodeIds); } }, [ getNodes, @@ -509,7 +523,7 @@ export const ExecutionProvider = memo(({ children }: React.PropsWithChildren<{}> packageSettings, clearNodeStatusMap, nodeEventBacklog, - clearManualOutputTypes, + clearManualTypes, ]); const resume = useCallback(async () => { diff --git a/src/renderer/contexts/GlobalNodeState.tsx b/src/renderer/contexts/GlobalNodeState.tsx index 6a538e587d..42230dd74d 100644 --- a/src/renderer/contexts/GlobalNodeState.tsx +++ b/src/renderer/contexts/GlobalNodeState.tsx @@ -1,4 +1,4 @@ -import { Expression, Type, evaluate } from '@chainner/navi'; +import { Expression } from '@chainner/navi'; import { dirname, parse } from 'path'; import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { @@ -18,6 +18,7 @@ import { InputId, InputKind, InputValue, + IterOutputId, Mutable, NodeData, OutputId, @@ -83,6 +84,7 @@ import { } from '../hooks/useOutputDataStore'; import { getSessionStorageOrDefault, useSessionStorage } from '../hooks/useSessionStorage'; import { useSettings } from '../hooks/useSettings'; +import { useTypeMap } from '../hooks/useTypeMap'; import { ipcRenderer } from '../safeIpc'; import { AlertBoxContext, AlertType } from './AlertBoxContext'; import { BackendContext } from './BackendContext'; @@ -138,7 +140,12 @@ interface Global { exportViewportScreenshot: () => void; exportViewportScreenshotToClipboard: () => void; setManualOutputType: (nodeId: string, outputId: OutputId, type: Expression | undefined) => void; - clearManualOutputTypes: (nodes: Iterable) => void; + setManualSequenceOutputType: ( + nodeId: string, + iterOutputId: IterOutputId, + type: Expression | undefined + ) => void; + clearManualTypes: (nodes: Iterable) => void; typeStateRef: Readonly>; chainLineageRef: Readonly>; outputDataActions: OutputDataActions; @@ -203,52 +210,22 @@ export const GlobalProvider = memo( [addEdgeChanges] ); - const [manualOutputTypes, setManualOutputTypes] = useState(() => ({ - map: new Map>(), - })); - const setManualOutputType = useCallback( - (nodeId: string, outputId: OutputId, expression: Expression | undefined): void => { - const getType = () => { - if (expression === undefined) { - return undefined; - } - - try { - return evaluate(expression, scope); - } catch (error) { - log.error(error); - return undefined; - } - }; - - setManualOutputTypes(({ map }) => { - let inner = map.get(nodeId); - const type = getType(); - if (type) { - if (!inner) { - inner = new Map(); - map.set(nodeId, inner); - } - - inner.set(outputId, type); - } else { - inner?.delete(outputId); - } - return { map }; - }); - }, - [setManualOutputTypes, scope] - ); - const clearManualOutputTypes = useCallback( - (nodes: Iterable): void => { - setManualOutputTypes(({ map }) => { - for (const nodeId of nodes) { - map.delete(nodeId); - } - return { map }; - }); + const [manualOutputTypes, setManualOutputType, clearManualOutputTypes] = useTypeMap< + string, + OutputId + >(scope); + const [ + manualSequenceOutputTypes, + setManualSequenceOutputType, + 
clearManualSequenceOutputTypes, + ] = useTypeMap(scope); + + const clearManualTypes = useCallback( + (nodes: Iterable) => { + clearManualOutputTypes(nodes); + clearManualSequenceOutputTypes(nodes); }, - [setManualOutputTypes] + [clearManualOutputTypes, clearManualSequenceOutputTypes] ); const [typeState, setTypeState] = useState(TypeState.empty); @@ -272,6 +249,7 @@ export const GlobalProvider = memo( nodeMap, getEdges(), manualOutputTypes.map, + manualSequenceOutputTypes.map, functionDefinitions, passthrough, typeStateRef.current @@ -288,6 +266,7 @@ export const GlobalProvider = memo( nodeChanges, edgeChanges, manualOutputTypes, + manualSequenceOutputTypes, functionDefinitions, schemata, passthrough, @@ -1321,7 +1300,8 @@ export const GlobalProvider = memo( exportViewportScreenshot, exportViewportScreenshotToClipboard, setManualOutputType, - clearManualOutputTypes, + setManualSequenceOutputType, + clearManualTypes, typeStateRef, chainLineageRef, outputDataActions, diff --git a/src/renderer/hooks/useAutomaticFeatures.ts b/src/renderer/hooks/useAutomaticFeatures.ts index bda4d31f76..6fbb53b464 100644 --- a/src/renderer/hooks/useAutomaticFeatures.ts +++ b/src/renderer/hooks/useAutomaticFeatures.ts @@ -17,16 +17,13 @@ export const useAutomaticFeatures = (id: string, schemaId: SchemaId) => { const hasIncomingConnections = thisNode && getIncomers(thisNode, getNodes(), getEdges()).length > 0; - // If the node is a generator, it should not use automatic features - const isGenerator = schema.kind === 'generator'; // Same if it has any static input values const hasStaticValueInput = schema.inputs.some((i) => i.kind === 'static'); // We should only use automatic features if the node has side effects const { hasSideEffects } = schema; return { - isAutomatic: - hasSideEffects && !hasIncomingConnections && !isGenerator && !hasStaticValueInput, + isAutomatic: hasSideEffects && !hasIncomingConnections && !hasStaticValueInput, hasIncomingConnections, }; }; diff --git a/src/renderer/hooks/useOutputDataStore.ts b/src/renderer/hooks/useOutputDataStore.ts index 81b95cb2aa..f55ac9d3b1 100644 --- a/src/renderer/hooks/useOutputDataStore.ts +++ b/src/renderer/hooks/useOutputDataStore.ts @@ -1,6 +1,6 @@ import isDeepEqual from 'fast-deep-equal/react'; import { useCallback, useState } from 'react'; -import { OutputData, OutputTypes } from '../../common/common-types'; +import { IterOutputTypes, OutputData, OutputTypes } from '../../common/common-types'; import { EMPTY_MAP } from '../../common/util'; import { useMemoObject } from './useMemo'; @@ -9,6 +9,7 @@ export interface OutputDataEntry { lastExecutionTime: number | undefined; data: OutputData | undefined; types: OutputTypes | undefined; + sequenceTypes: IterOutputTypes | undefined; } export interface OutputDataActions { @@ -17,7 +18,8 @@ export interface OutputDataActions { executionTime: number | undefined, nodeInputHash: string, data: OutputData | undefined, - types: OutputTypes | undefined + types: OutputTypes | undefined, + sequenceTypes: IterOutputTypes | undefined ): void; delete(nodeId: string): void; clear(): void; @@ -28,7 +30,7 @@ export const useOutputDataStore = () => { const actions: OutputDataActions = { set: useCallback( - (nodeId, executionTime, inputHash, data, types) => { + (nodeId, executionTime, inputHash, data, types, sequenceTypes) => { setMap((prev) => { const existingEntry = prev.get(nodeId); @@ -36,6 +38,7 @@ export const useOutputDataStore = () => { const entry: OutputDataEntry = { data: useExisting ? 
existingEntry.data : data, types: useExisting ? existingEntry.types : types, + sequenceTypes: useExisting ? existingEntry.sequenceTypes : sequenceTypes, inputHash: useExisting ? existingEntry.inputHash : inputHash, lastExecutionTime: executionTime ?? existingEntry?.lastExecutionTime, }; diff --git a/src/renderer/hooks/useTypeMap.ts b/src/renderer/hooks/useTypeMap.ts new file mode 100644 index 0000000000..8cd6cec2ca --- /dev/null +++ b/src/renderer/hooks/useTypeMap.ts @@ -0,0 +1,60 @@ +import { Expression, Scope, Type, evaluate } from '@chainner/navi'; +import { useCallback, useState } from 'react'; +import { log } from '../../common/log'; + +/** + * A map of types that can be used as either a ref-like object or a state-like value. + */ +export const useTypeMap = (scope: Scope) => { + const [types, setTypes] = useState(() => ({ + map: new Map>(), + })); + + const setType = useCallback( + (nodeId: N, outputId: I, expression: Expression | undefined): void => { + const getType = () => { + if (expression === undefined) { + return undefined; + } + + try { + return evaluate(expression, scope); + } catch (error) { + log.error(error); + return undefined; + } + }; + + setTypes(({ map }) => { + let inner = map.get(nodeId); + const type = getType(); + if (type) { + if (!inner) { + inner = new Map(); + map.set(nodeId, inner); + } + + inner.set(outputId, type); + } else { + inner?.delete(outputId); + } + return { map }; + }); + }, + [setTypes, scope] + ); + + const clear = useCallback( + (nodes: Iterable): void => { + setTypes(({ map }) => { + for (const nodeId of nodes) { + map.delete(nodeId); + } + return { map }; + }); + }, + [setTypes] + ); + + return [types, setType, clear] as const; +};
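
For context, a minimal sketch of the usage pattern this patch establishes on the backend: a generator node attaches runtime metadata to its Generator, and its IteratorOutputInfo declares how that metadata narrows the per-item output types. This is a hedged illustration, not part of the patch; `ImageMetadata`, `load_image`, `make_generator`, and the literal output/length values are hypothetical placeholders, while `Generator`, `IteratorOutputInfo`, `OutputId`, and `navi` are the APIs shown in the diff above.

# Sketch only (assumed placeholders noted below); mirrors the Load Video wiring above.
from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping

import navi
from api import Generator, IteratorOutputInfo, OutputId


@dataclass
class ImageMetadata:
    # Hypothetical metadata known before iteration starts.
    width: int
    height: int


def get_item_types(meta: ImageMetadata) -> Mapping[OutputId, navi.ExpressionJson]:
    # Called via IteratorOutputInfo.get_broadcast_item_types when the generator's
    # attached metadata is an instance of ImageMetadata.
    return {
        OutputId(0): navi.Image(width=meta.width, height=meta.height, channels=3),
    }


# Passed as iterator_outputs=... when registering the node; the broadcast
# Sequence type gets its length from Generator.expected_length at run time.
iterator_outputs = IteratorOutputInfo(outputs=[0], length_type="uint").with_item_types(
    ImageMetadata, get_item_types
)


def make_generator(paths: list[str], meta: ImageMetadata) -> Generator:
    def load_image(path: str, index: int):
        ...  # decode and return the image for `path` (placeholder)

    # expected_length is len(paths); with_metadata feeds get_item_types above.
    return Generator.from_list(paths, load_image).with_metadata(meta)

At run time, compute_sequence_broadcast (process.py above) turns expected_length into the broadcast Sequence type and merges the item types returned by get_item_types into the node-broadcast event, which the renderer then applies through setManualSequenceOutputType.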