Skip to content

Commit

Permalink
Remove code pertaining to old iterators (#2267)
Browse files Browse the repository at this point in the history
* No more iterator jank

* remove some more stuff

* Some more things i missed

* update tests

* lint

* Remove Z indexes

* update tests

---------

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>
  • Loading branch information
joeyballentine and RunDevelopment authored Oct 16, 2023
1 parent cfd0e14 commit 91eae95
Show file tree
Hide file tree
Showing 48 changed files with 190 additions and 4,014 deletions.
8 changes: 2 additions & 6 deletions backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,11 @@ def inner_wrapper(wrapped_func: T) -> T:

run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(
wrapped_func, node_type, p_inputs, p_outputs
),
lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
)
run_check(
NAME_CHECK_LEVEL,
lambda fix: check_naming_conventions(
wrapped_func, node_type, name, fix
),
lambda fix: check_naming_conventions(wrapped_func, name, fix),
)

if decorators is not None:
Expand Down
12 changes: 2 additions & 10 deletions backend/src/chain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,8 @@ def get_cache_strategies(chain: Chain) -> Dict[NodeId, CacheStrategy]:

for node in chain.nodes.values():
out_edges = chain.edges_from(node.id)
connected_to_child_node = any(
chain.nodes[e.target.id].parent for e in out_edges
)

strategy: CacheStrategy
if node.parent is None and connected_to_child_node:
# free nodes that are connected to child nodes need to live as the execution
strategy = StaticCaching
else:
strategy = CacheStrategy(len(out_edges))

strategy: CacheStrategy = CacheStrategy(len(out_edges))

result[node.id] = strategy

Expand Down
35 changes: 1 addition & 34 deletions backend/src/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class FunctionNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.parent: Union[NodeId, None] = None
self.is_helper: bool = False

def get_node(self) -> NodeData:
return registry.get_node(self.schema_id)
Expand All @@ -29,21 +27,6 @@ def has_side_effects(self) -> bool:
return self.get_node().side_effects


class IteratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
self.schema_id: str = schema_id
self.parent: None = None
self.__node = None

def get_node(self) -> NodeData:
if self.__node is None:
node = registry.get_node(self.schema_id)
assert node.type == "iterator", "Invalid iterator node"
self.__node = node
return self.__node


class NewIteratorNode:
def __init__(self, node_id: NodeId, schema_id: str):
self.id: NodeId = node_id
Expand Down Expand Up @@ -82,7 +65,7 @@ def has_side_effects(self) -> bool:
return self.get_node().side_effects


Node = Union[FunctionNode, IteratorNode, NewIteratorNode, CollectorNode]
Node = Union[FunctionNode, NewIteratorNode, CollectorNode]


class EdgeSource:
Expand Down Expand Up @@ -137,19 +120,3 @@ def remove_node(self, node_id: NodeId):
self.__edges_by_target[e.target.id].remove(e)
for e in self.__edges_by_target.pop(node_id, []):
self.__edges_by_source[e.source.id].remove(e)

if isinstance(node, IteratorNode):
# remove all child nodes
for n in list(self.nodes.values()):
if n.parent == node_id:
self.remove_node(n.id)


class SubChain:
def __init__(self, chain: Chain, iterator_id: NodeId):
self.nodes: Dict[NodeId, FunctionNode] = {}
self.iterator_id = iterator_id

for node in chain.nodes.values():
if node.parent is not None and node.parent == iterator_id:
self.nodes[node.id] = node
7 changes: 1 addition & 6 deletions backend/src/chain/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
EdgeSource,
EdgeTarget,
FunctionNode,
IteratorNode,
NewIteratorNode,
)
from .input import EdgeInput, Input, InputMap, ValueInput
Expand Down Expand Up @@ -54,16 +53,12 @@ def parse_json(json: List[JsonNode]) -> Tuple[Chain, InputMap]:
index_edges: List[IndexEdge] = []

for json_node in json:
if json_node["nodeType"] == "iterator":
node = IteratorNode(json_node["id"], json_node["schemaId"])
elif json_node["nodeType"] == "newIterator":
if json_node["nodeType"] == "newIterator":
node = NewIteratorNode(json_node["id"], json_node["schemaId"])
elif json_node["nodeType"] == "collector":
node = CollectorNode(json_node["id"], json_node["schemaId"])
else:
node = FunctionNode(json_node["id"], json_node["schemaId"])
node.parent = json_node["parent"]
node.is_helper = json_node["nodeType"] == "iteratorHelper"
chain.add_node(node)

inputs: List[Input] = []
Expand Down
35 changes: 2 additions & 33 deletions backend/src/chain/optimize.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,12 @@
from sanic.log import logger

from .chain import Chain, EdgeSource, IteratorNode, Node
from .chain import Chain, Node


def __has_side_effects(node: Node) -> bool:
if isinstance(node, IteratorNode) or node.is_helper:
# we assume that both iterators and their helper nodes always have side effects
return True
return node.has_side_effects()


def __outline_child_nodes(chain: Chain) -> bool:
"""
If a child node of an iterator is not downstream of any iterator helper node,
then this child node can be lifted out of the iterator (outlined) to be a free node.
"""
changed = False

for node in chain.nodes.values():
# we try to outline child nodes that are not iterator helper nodes
if node.parent is not None and not node.is_helper:

def has_no_parent(source: EdgeSource) -> bool:
n = chain.nodes.get(source.id)
assert n is not None
return n.parent is None

# we can only outline if all of its inputs are independent of the iterator
can_outline = all(has_no_parent(n.source) for n in chain.edges_to(node.id))
if can_outline:
node.parent = None
changed = True
logger.debug(
f"Chain optimization: Outlined {node.schema_id} node {node.id}"
)

return changed


def __removed_dead_nodes(chain: Chain) -> bool:
"""
If a node does not have side effects and has no downstream nodes, then it can be removed.
Expand All @@ -57,4 +26,4 @@ def __removed_dead_nodes(chain: Chain) -> bool:
def optimize(chain: Chain):
changed = True
while changed:
changed = __removed_dead_nodes(chain) or __outline_child_nodes(chain)
changed = __removed_dead_nodes(chain)
4 changes: 1 addition & 3 deletions backend/src/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

RunFn = Callable[..., Any]

NodeType = Literal[
"regularNode", "iterator", "iteratorHelper", "newIterator", "collector"
]
NodeType = Literal["regularNode", "newIterator", "collector"]

UpdateProgressFn = Callable[[str, float, Union[float, None]], Awaitable[None]]
17 changes: 1 addition & 16 deletions backend/src/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Dict, List, Literal, Optional, TypedDict, Union
from typing import Dict, Literal, Optional, TypedDict, Union

from base_types import InputId, NodeId, OutputId
from nodes.base_input import ErrorValue
Expand Down Expand Up @@ -36,15 +36,6 @@ class NodeStartData(TypedDict):
nodeId: NodeId


class IteratorProgressUpdateData(TypedDict):
percent: float
index: int
total: int
eta: float
iteratorId: NodeId
running: Optional[List[NodeId]]


class NodeProgressUpdateData(TypedDict):
percent: float
index: int
Expand Down Expand Up @@ -79,11 +70,6 @@ class NodeStartEvent(TypedDict):
data: NodeStartData


class IteratorProgressUpdateEvent(TypedDict):
event: Literal["iterator-progress-update"]
data: IteratorProgressUpdateData


class NodeProgressUpdateEvent(TypedDict):
event: Literal["node-progress-update"]
data: NodeProgressUpdateData
Expand All @@ -104,7 +90,6 @@ class BackendStateEvent(TypedDict):
ExecutionErrorEvent,
NodeFinishEvent,
NodeStartEvent,
IteratorProgressUpdateEvent,
NodeProgressUpdateEvent,
BackendStatusEvent,
BackendStateEvent,
Expand Down
22 changes: 1 addition & 21 deletions backend/src/node_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Set, Union, cast, get_args

from custom_types import NodeType
from nodes.base_input import BaseInput
from nodes.base_output import BaseOutput

Expand Down Expand Up @@ -174,7 +173,6 @@ def validate_return_type(return_type: _Ty, outputs: list[BaseOutput]):

def check_schema_types(
wrapped_func: Callable,
node_type: NodeType,
inputs: list[BaseInput],
outputs: list[BaseOutput],
):
Expand All @@ -195,19 +193,6 @@ def check_schema_types(
if not arg in ann:
raise CheckFailedError(f"Missing type annotation for '{arg}'")

if node_type == "iteratorHelper":
# iterator helpers have inputs that do not describe the arguments of the function, so we can't check them
return

if node_type == "iterator":
# the last argument of an iterator is the iterator context, so we have to account for that
context = [*ann.keys()][-1]
context_type = ann.pop(context)
if str(context_type) != "<class 'process.IteratorContext'>":
raise CheckFailedError(
f"Last argument of an iterator must be an IteratorContext, not '{context_type}'"
)

if arg_spec.varargs is not None:
if not arg_spec.varargs in ann:
raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'")
Expand Down Expand Up @@ -255,23 +240,18 @@ def check_schema_types(

def check_naming_conventions(
wrapped_func: Callable,
node_type: NodeType,
name: str,
fix: bool,
):
expected_name = (
name.lower()
.replace(" (iterator)", "")
.replace(" ", "_")
.replace("-", "_")
.replace("(", "")
.replace(")", "")
.replace("&", "and")
)

if node_type == "iteratorHelper":
expected_name = "iterator_helper_" + expected_name

func_name = wrapped_func.__name__
file_path = pathlib.Path(inspect.getfile(wrapped_func))
file_name = file_path.stem
Expand All @@ -289,7 +269,7 @@ def check_naming_conventions(
file_path.write_text(fixed_code, encoding="utf-8")

# check file name
if node_type != "iteratorHelper" and file_name != expected_name:
if file_name != expected_name:
if not fix:
raise CheckFailedError(
f"File name is '{file_name}.py', but it should be '{expected_name}.py'"
Expand Down
5 changes: 0 additions & 5 deletions backend/src/nodes/properties/inputs/generic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,11 +408,6 @@ def make_optional(self):
raise ValueError("ColorInput cannot be made optional")


def IteratorInput():
"""Input for showing that an iterator automatically handles the input"""
return BaseInput("IteratorAuto", "Auto (Iterator)", has_handle=False)


class VideoContainer(Enum):
MKV = "mkv"
MP4 = "mp4"
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_ncnn/ncnn/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
FileNameOutput("Name", of_input=0).with_id(1),
],
see_also=[
"chainner:ncnn:model_file_iterator",
"chainner:ncnn:load_models",
],
)
def load_model_node(
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_onnx/onnx/io/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
FileNameOutput("Name", of_input=0).with_id(1),
],
see_also=[
"chainner:onnx:model_file_iterator",
"chainner:onnx:load_models",
],
)
def load_model_node(path: str) -> Tuple[OnnxModel, str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def parse_ckpt_state_dict(checkpoint: dict):
FileNameOutput("Name", of_input=0).with_id(1),
],
see_also=[
"chainner:pytorch:model_file_iterator",
"chainner:pytorch:load_models",
],
)
def load_model_node(path: str) -> Tuple[PyTorchModel, str, str]:
Expand Down
Loading

0 comments on commit 91eae95

Please sign in to comment.