From 97c6aa988dd5b100a8984476403e871475a67c08 Mon Sep 17 00:00:00 2001 From: Rhys Yang Date: Mon, 11 Dec 2023 19:55:03 +0800 Subject: [PATCH] Add `Parameter Extractor` node #24 --- __init__.py | 1 + js/extractorDisplay.js | 38 +++++++++++++++ nodes.py | 103 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 js/extractorDisplay.js diff --git a/__init__.py b/__init__.py index 1733ba4..3647136 100644 --- a/__init__.py +++ b/__init__.py @@ -17,6 +17,7 @@ "parameterDisplay.js", "seedGen.js", "loaderDisplay.js", + "extractorDisplay.js", ] for file in files_to_copy: diff --git a/js/extractorDisplay.js b/js/extractorDisplay.js new file mode 100644 index 0000000..4e87693 --- /dev/null +++ b/js/extractorDisplay.js @@ -0,0 +1,38 @@ +import {app} from "../../scripts/app.js"; +import {ComfyWidgets} from "../../scripts/widgets.js"; + +// Create a read-only string widget +function createWidget(app, node, widgetName, type) { + const widget = ComfyWidgets[type](node, widgetName, ["STRING", {multiline: true}], app).widget; + widget.inputEl.readOnly = true; + widget.inputEl.style.textAlign = "center"; + widget.inputEl.style.fontSize = "0.75rem"; + return widget; +} + +app.registerExtension({ + name: "sd_prompt_reader.extractorDisplay", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === "SDParameterExtractor") { + const onNodeCreated = nodeType.prototype.onNodeCreated; + + nodeType.prototype.onNodeCreated = function () { + const result = onNodeCreated?.apply(this, arguments); + + // Create widgets + const value_display = createWidget(app, this, "value_display", "STRING"); + }; + + // Update widgets + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (message) { + onExecuted?.apply(this, arguments); + this.widgets.find(obj => obj.name === "value_display").value = message.text[1] + this.widgets.find(obj => obj.name === "parameter").options.values = message.text[0] + if (this.widgets.find(obj => obj.name === "parameter").value === "parameters not loaded") { + this.widgets.find(obj => obj.name === "parameter").value = message.text[0][0] + } + }; + } + }, +}); \ No newline at end of file diff --git a/nodes.py b/nodes.py index 0734a22..d4376a3 100644 --- a/nodes.py +++ b/nodes.py @@ -12,6 +12,7 @@ import torch import json +import re import numpy as np from pathlib import Path from PIL import Image, ImageOps @@ -936,6 +937,106 @@ def VALIDATE_INPUTS( return True +class SDParameterExtractor: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "settings": ( + "STRING", + {"default": "", "multiline": True, "forceInput": True}, + ) + }, + "optional": { + "parameter": ( + ["parameters not loaded"], + {"default": "parameters not loaded"}, + ), + "value_type": (["STRING", "INT", "FLOAT"], {"default": "STRING"}), + "parameter_index": ( + "INT", + {"default": 0, "min": 0, "max": 255, "step": 1}, + ), + }, + } + + RETURN_TYPES = (any_type,) + + RETURN_NAMES = ("VALUE",) + FUNCTION = "extract_param" + CATEGORY = "SD Prompt Reader" + + def extract_param( + self, + settings: str = "", + parameter: str = "", + value_type: str = "STRING", + parameter_index: int = 0, + ): + setting_dict = self.parse_setting(settings) + if not settings or not parameter or parameter == "parameters not loaded": + return { + "ui": { + "text": (list(setting_dict.keys()), ""), + }, + "result": ("",), + } + + result = setting_dict.get(parameter) + + try: + if isinstance(result, tuple): + result = result[parameter_index] + if value_type == "INT": + result = int(result) + elif value_type == "FLOAT": + result = float(result) + except IndexError: + return { + "ui": { + "text": (list(setting_dict.keys()), "Parameter index out of range"), + }, + "result": ("",), + } + except (ValueError, TypeError): + return { + "ui": { + "text": ( + list(setting_dict.keys()), + f"{parameter}: {result}\n" + f"{result} is not a valid number; it will be output as STRING", + ), + }, + "result": (result,), + } + return { + "ui": { + "text": (list(setting_dict.keys()), f"{parameter}: {result}"), + }, + "result": (result,), + } + + @staticmethod + def parse_setting(settings): + pattern = re.compile(r"([^:,]+):\s*\(([^)]+)\)|([^:,]+):\s*([^,]+)") + + matches = pattern.findall(settings) + + result = {} + for match in matches: + key, value_paren, key_nonparen, value_nonparen = match + if key: + key = key.strip() + value = value_paren.strip() + value = tuple(v.strip() for v in value.split(",")) + else: + key = key_nonparen.strip() + value = value_nonparen.strip() + result[key] = value + + return result + + NODE_CLASS_MAPPINGS = { "SDPromptReader": SDPromptReader, "SDPromptSaver": SDPromptSaver, @@ -943,6 +1044,7 @@ def VALIDATE_INPUTS( "SDPromptMerger": SDPromptMerger, "SDTypeConverter": SDTypeConverter, "SDBatchLoader": SDBatchLoader, + "SDParameterExtractor": SDParameterExtractor, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -952,4 +1054,5 @@ def VALIDATE_INPUTS( "SDPromptMerger": "SD Prompt Merger", "SDTypeConverter": "SD Type Converter", "SDBatchLoader": "SD Batch Loader", + "SDParameterExtractor": "SD Parameter Extractor", }