From f56c1e0c485a8db14831f972ae03172d2ba4608a Mon Sep 17 00:00:00 2001 From: Rhys Yang Date: Fri, 24 Nov 2023 00:07:20 +0800 Subject: [PATCH] Add `Batch Loader` node #13 --- js/loaderDisplay.js | 36 ++++++++++++++++ nodes.py | 103 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 132 insertions(+), 7 deletions(-) create mode 100644 js/loaderDisplay.js diff --git a/js/loaderDisplay.js b/js/loaderDisplay.js new file mode 100644 index 0000000..fe77f22 --- /dev/null +++ b/js/loaderDisplay.js @@ -0,0 +1,36 @@ +import {app} from "../../scripts/app.js"; +import {ComfyWidgets} from "../../scripts/widgets.js"; + +// Create a read-only string widget with opacity set +function createWidget(app, node, widgetName) { + const widget = ComfyWidgets["STRING"](node, widgetName, ["STRING", {multiline: true}], app).widget; + widget.inputEl.readOnly = true; + widget.inputEl.style.opacity = 0.7; + return widget; +} + +// Displays file list on the node +app.registerExtension({ + name: "sd_prompt_reader.loaderDisplay", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === "SDBatchLoader") { + const onNodeCreated = nodeType.prototype.onNodeCreated; + + nodeType.prototype.onNodeCreated = function () { + const result = onNodeCreated?.apply(this, arguments); + + // Create prompt and setting widgets + const fileList = createWidget(app, this, "fileList"); + return result; + }; + + // Update widgets + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (message) { + onExecuted?.apply(this, arguments); + this.widgets.find(obj => obj.name === "fileList").value = message.text[0]; + + }; + } + }, +}); \ No newline at end of file diff --git a/nodes.py b/nodes.py index 4f36c1b..f8531a4 100644 --- a/nodes.py +++ b/nodes.py @@ -53,18 +53,34 @@ def output_to_terminal(text: str): output_to_terminal("Core version: " + CORE_VERSION) +class AnyType(str): + """A special type that can be connected to any other types. Credit to pythongosssss""" + + def __ne__(self, __value: object) -> bool: + return False + + +any_type = AnyType("*") + + class SDPromptReader: + files = [] + @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - files = [ - f - for f in os.listdir(input_dir) - if os.path.isfile(os.path.join(input_dir, f)) - ] + SDPromptReader.files = sorted( + [ + f + for f in os.listdir(input_dir) + if os.path.isfile(os.path.join(input_dir, f)) + ] + ) return { "required": { - "image": (sorted(files), {"image_upload": True}), + "image": (SDPromptReader.files, {"image_upload": True}), + }, + "optional": { "parameter_index": ( "INT", {"default": 0, "min": 0, "max": 255, "step": 1}, @@ -104,7 +120,10 @@ def INPUT_TYPES(s): OUTPUT_NODE = True def load_image(self, image, parameter_index): - image_path = folder_paths.get_annotated_filepath(image) + if image in SDPromptReader.files: + image_path = folder_paths.get_annotated_filepath(image) + else: + image_path = image i = Image.open(image_path) i = ImageOps.exif_transpose(i) image = i.convert("RGB") @@ -778,12 +797,81 @@ def convert_string( ) +class SDBatchLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "path": ("STRING", {"default": "./input/"}), + }, + "optional": { + "image_load_limit": ("INT", {"default": 0, "min": 0, "step": 1}), + "start_index": ("INT", {"default": 0, "min": 0, "step": 1}), + }, + } + + RETURN_TYPES = (any_type,) + + RETURN_NAMES = ("IMAGE",) + OUTPUT_IS_LIST = (True,) + FUNCTION = "load_path" + CATEGORY = "SD Prompt Reader" + + def load_path( + self, + path: str = "./input/", + image_load_limit: int = 0, + start_index: int = 0, + ): + if not Path(path).is_dir(): + raise FileNotFoundError(f"Invalid directory: {path}") + + files = list( + filter(lambda file: file.suffix in SUPPORTED_FORMATS, Path(path).iterdir()) + ) + + files = ( + sorted(files)[start_index : start_index + image_load_limit] + if image_load_limit > 0 + else sorted(files)[start_index:] + ) + + files_str = list(map(str, files)) + return { + "ui": { + "text": ("\n".join(files_str),), + }, + "result": (files_str,), + } + + @classmethod + def IS_CHANGED( + s, + path, + image_load_limit, + start_index, + ): + return os.listdir(path) + + @classmethod + def VALIDATE_INPUTS( + s, + path, + image_load_limit, + start_index, + ): + if not Path(path).is_dir(): + return f"Invalid directory: {path}" + return True + + NODE_CLASS_MAPPINGS = { "SDPromptReader": SDPromptReader, "SDPromptSaver": SDPromptSaver, "SDParameterGenerator": SDParameterGenerator, "SDPromptMerger": SDPromptMerger, "SDTypeConverter": SDTypeConverter, + "SDBatchLoader": SDBatchLoader, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -792,4 +880,5 @@ def convert_string( "SDParameterGenerator": "SD Parameter Generator", "SDPromptMerger": "SD Prompt Merger", "SDTypeConverter": "SD Type Converter", + "SDBatchLoader": "SD Batch Loader", }