Skip to content

Commit

Permalink
Add broadcasts for sequences (#2971)
Browse files Browse the repository at this point in the history
* start of allowing generators to broadcast length data

* allow output lengths to be set

* Fixed typing and broadcasting issue for sequence types (#2972)

* Fixed typing and broadcasting issue for sequence types

* Fixed errors

* Sequence item types for Load Video (#2974)

* Sequence item types for Load Video

* Ignore type error

* Don't cache individually run generators and delete any caches on new run

---------

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>
  • Loading branch information
joeyballentine and RunDevelopment authored Jul 15, 2024
1 parent fb7b8da commit a2135a0
Show file tree
Hide file tree
Showing 21 changed files with 436 additions and 200 deletions.
1 change: 1 addition & 0 deletions backend/src/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .api import *
from .group import *
from .input import *
from .iter import *
from .lazy import *
from .node_context import *
from .node_data import *
Expand Down
64 changes: 0 additions & 64 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Any,
Awaitable,
Callable,
Generic,
Iterable,
TypeVar,
)
Expand Down Expand Up @@ -504,66 +503,3 @@ def add_package(
dependencies=dependencies or [],
)
)


I = TypeVar("I")
L = TypeVar("L")


@dataclass
class Generator(Generic[I]):
supplier: Callable[[], Iterable[I | Exception]]
expected_length: int
fail_fast: bool = True

def with_fail_fast(self, fail_fast: bool) -> Generator[I]:
return Generator(self.supplier, self.expected_length, fail_fast=fail_fast)

@staticmethod
def from_iter(
supplier: Callable[[], Iterable[I | Exception]], expected_length: int
) -> Generator[I]:
return Generator(supplier, expected_length)

@staticmethod
def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]:
"""
Creates a new generator from a list that is mapped using the given
function. The iterable will be equivalent to `map(map_fn, l)`.
"""

def supplier():
for i, x in enumerate(l):
try:
yield map_fn(x, i)
except Exception as e:
yield e

return Generator(supplier, len(l))

@staticmethod
def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]:
"""
Creates a new generator the given number of items where each item is
lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
"""
assert count >= 0

def supplier():
for i in range(count):
try:
yield map_fn(i)
except Exception as e:
yield e

return Generator(supplier, count)


N = TypeVar("N")
R = TypeVar("R")


@dataclass
class Collector(Generic[N, R]):
on_iterate: Callable[[N], None]
on_complete: Callable[[], R]
72 changes: 72 additions & 0 deletions backend/src/api/iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Generic, Iterable, TypeVar

I = TypeVar("I")
L = TypeVar("L")


@dataclass
class Generator(Generic[I]):
supplier: Callable[[], Iterable[I | Exception]]
expected_length: int
fail_fast: bool = True
metadata: object | None = None

def with_fail_fast(self, fail_fast: bool):
self.fail_fast = fail_fast
return self

def with_metadata(self, metadata: object):
self.metadata = metadata
return self

@staticmethod
def from_iter(
supplier: Callable[[], Iterable[I | Exception]], expected_length: int
) -> Generator[I]:
return Generator(supplier, expected_length)

@staticmethod
def from_list(l: list[L], map_fn: Callable[[L, int], I]) -> Generator[I]:
"""
Creates a new generator from a list that is mapped using the given
function. The iterable will be equivalent to `map(map_fn, l)`.
"""

def supplier():
for i, x in enumerate(l):
try:
yield map_fn(x, i)
except Exception as e:
yield e

return Generator(supplier, len(l))

@staticmethod
def from_range(count: int, map_fn: Callable[[int], I]) -> Generator[I]:
"""
Creates a new generator the given number of items where each item is
lazily evaluated. The iterable will be equivalent to `map(map_fn, range(count))`.
"""
assert count >= 0

def supplier():
for i in range(count):
try:
yield map_fn(i)
except Exception as e:
yield e

return Generator(supplier, count)


N = TypeVar("N")
R = TypeVar("R")


@dataclass
class Collector(Generic[N, R]):
on_iterate: Callable[[N], None]
on_complete: Callable[[], R]
36 changes: 35 additions & 1 deletion backend/src/api/node_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from dataclasses import dataclass
from enum import Enum
from typing import Any, Mapping
from typing import Any, Callable, Generic, Mapping, Protocol, TypeVar

import navi

from .group import NestedIdGroup
from .input import BaseInput
from .iter import Generator
from .output import BaseOutput
from .types import (
FeatureId,
Expand Down Expand Up @@ -42,6 +43,13 @@ def to_dict(self):
}


M_co = TypeVar("M_co", covariant=True)


class AnyConstructor(Protocol, Generic[M_co]):
def __call__(self, *args: Any, **kwargs: Any) -> M_co: ...


class IteratorOutputInfo:
def __init__(
self,
Expand All @@ -56,13 +64,39 @@ def __init__(
)
self.length_type: navi.ExpressionJson = length_type

self._metadata_constructor: Any | None = None
self._item_types_fn: (
Callable[[Any], Mapping[OutputId, navi.ExpressionJson]] | None
) = None

def with_item_types(
self,
class_: AnyConstructor[M_co],
fn: Callable[[M_co], Mapping[OutputId, navi.ExpressionJson]],
):
self._metadata_constructor = class_
self._item_types_fn = fn
return self

def to_dict(self):
return {
"id": self.id,
"outputs": self.outputs,
"sequenceType": navi.named("Sequence", {"length": self.length_type}),
}

def get_broadcast_sequence_type(self, generator: Generator) -> navi.ExpressionJson:
return navi.named("Sequence", {"length": generator.expected_length})

def get_broadcast_item_types(
self, generator: Generator
) -> Mapping[OutputId, navi.ExpressionJson]:
if self._item_types_fn is not None and self._metadata_constructor is not None:
metadata = generator.metadata
if isinstance(metadata, self._metadata_constructor):
return self._item_types_fn(metadata)
return {}


class KeyInfo:
def __init__(self, data: dict[str, Any]) -> None:
Expand Down
8 changes: 5 additions & 3 deletions backend/src/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from dataclasses import dataclass
from typing import Dict, Literal, TypedDict, Union

from api import ErrorValue, InputId, NodeId, OutputId
import navi
from api import BroadcastData, ErrorValue, InputId, IterOutputId, NodeId, OutputId

# General events

Expand Down Expand Up @@ -87,8 +88,9 @@ class NodeProgressUpdateEvent(TypedDict):

class NodeBroadcastData(TypedDict):
nodeId: NodeId
data: dict[OutputId, object]
types: dict[OutputId, object]
data: dict[OutputId, BroadcastData | None]
types: dict[OutputId, navi.ExpressionJson | None]
sequenceTypes: dict[IterOutputId, navi.ExpressionJson] | None


class NodeBroadcastEvent(TypedDict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,14 @@ def list_glob(directory: Path, globexpr: str, ext_filter: list[str]) -> list[Pat
DirectoryOutput("Directory", output_type="Input0"),
TextOutput("Subdirectory Path"),
TextOutput("Name"),
NumberOutput(
"Index",
output_type="if Input4 { min(uint, Input5 - 1) } else { uint }",
),
NumberOutput("Index", output_type="min(uint, max(0, IterOutput0.length - 1))"),
],
iterator_outputs=IteratorOutputInfo(
outputs=[0, 2, 3, 4],
length_type="if Input4 { min(uint, Input5) } else { uint }",
),
kind="generator",
side_effects=True,
)
def load_images_node(
directory: Path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import numpy as np

from api import Generator, IteratorOutputInfo, NodeContext
import navi
from api import Generator, IteratorOutputInfo, NodeContext, OutputId
from nodes.groups import Condition, if_group
from nodes.impl.ffmpeg import FFMpegEnv
from nodes.impl.video import VideoLoader
from nodes.impl.video import VideoLoader, VideoMetadata
from nodes.properties.inputs import BoolInput, NumberInput, VideoFileInput
from nodes.properties.outputs import (
AudioStreamOutput,
Expand All @@ -22,6 +23,14 @@
from .. import video_frames_group


def get_item_types(metadata: VideoMetadata):
return {
OutputId(0): navi.Image(
width=metadata.width, height=metadata.height, channels=3
),
}


@video_frames_group.register(
schema_id="chainner:image:load_video",
name="Load Video",
Expand All @@ -46,8 +55,8 @@
outputs=[
ImageOutput("Frame", channels=3),
NumberOutput(
"Frame Index",
output_type="if Input1 { min(uint, Input2 - 1) } else { uint }",
"Index",
output_type="min(uint, max(0, IterOutput0.length - 1))",
).with_docs("A counter that starts at 0 and increments by 1 for each frame."),
DirectoryOutput("Video Directory", of_input=0),
FileNameOutput("Name", of_input=0),
Expand All @@ -56,8 +65,9 @@
],
iterator_outputs=IteratorOutputInfo(
outputs=[0, 1], length_type="if Input1 { min(uint, Input2) } else { uint }"
),
).with_item_types(VideoMetadata, get_item_types), # type: ignore
node_context=True,
side_effects=True,
kind="generator",
)
def load_video_node(
Expand All @@ -83,7 +93,9 @@ def iterator():
break

return (
Generator.from_iter(supplier=iterator, expected_length=frame_count),
Generator.from_iter(
supplier=iterator, expected_length=frame_count
).with_metadata(loader.metadata),
video_dir,
video_name,
loader.metadata.fps,
Expand Down
Loading

0 comments on commit a2135a0

Please sign in to comment.