From 3810320f6e945936a431507690700690697d196e Mon Sep 17 00:00:00 2001 From: Antonio Cordero Balcazar Date: Sun, 20 Oct 2024 13:39:52 +0200 Subject: [PATCH] * Simpler versioning * Cache now takes wildcard definitions into consideration. * ComfyUI: node is not cached if wildcard processing is active. * Add a space with prefix and suffix if needed. * Checks requirements before installing them. --- README.md | 6 +- __init__.py | 10 -- install.py | 12 ++- ppp.py | 105 ++++++++++++++++++--- ppp_cache.py | 4 +- ppp_comfyui.py | 5 +- ppp_wildcards.py | 187 +++++++++++++++++++++++++++++++------- scripts/ppp_script.py | 13 ++- tests/tests.py | 7 +- tests/wildcards/test.yaml | 14 ++- 10 files changed, 281 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index a8d2eb5..ea183ae 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ The Prompt PostProcessor (PPP), formerly known as "sd-webui-sendtonegative", is an extension designed to process the prompt, possibly after other extensions have modified it. This extension is compatible with: * [AUTOMATIC1111 Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) -* [SD.Next](https://github.com/vladmandic/automatic). * [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) * [reForge](https://github.com/Panchovix/stable-diffusion-webui-reForge) +* [SD.Next](https://github.com/vladmandic/automatic). * ...and probably other forks * [ComfyUI](https://github.com/comfyanonymous/ComfyUI) @@ -49,6 +49,8 @@ On A1111 compatible webuis: 3. Click the Install button 4. Restart the webui +On SD.Next I recommend you disable the native wildcard processing. + On ComfyUI: 1. Go to Manager > Custom Nodes Manager @@ -130,7 +132,7 @@ The parameters, the filter, and the setting of a variable are optional. The para The wildcard identifier can contain globbing formatting, to read multiple wildcards and merge their choices. Note that if there are no parameters specified, the globbing will use the ones from the first wildcard that matches and have parameters (sorted by keys), so if you don't want that you might want to specify them. Also note that, unlike with Dynamic Prompts, the wildcard name has to be specified with its full path (unless you use globbing). -The filter can be used to filter specific choices from the wildcard. The filtering works before applying the choice conditions (if any). The surrounding quotes can be single or double. The filter is a comma separated list of an integer (positional choice index) or choice label. You can also compound them with "+". That is, the comma separated items act as an OR and the "+" inside them as an AND. Using labels can simplify the definitions of complex wildcards where you want to have direct access to specific choices on occasion (you don't need to create wildcards for each individual choice). There are some additional formats when using filters. You can specify "^wildcard" as a filter to use the filter of a previous wildcard in the chain. You can start the filter (regular or inherited) with "#" and it will not be applied to the current wildcard choices, but the filter will remain in memory to use by other descendant wildcards. You use "#" and "^" when you want to pass a filter to inner wildcards (see the test files). +The filter can be used to filter specific choices from the wildcard. The filtering works before applying the choice conditions (if any). The surrounding quotes can be single or double. The filter is a comma separated list of an integer (positional choice index; zero-based) or choice label. You can also compound them with "+". That is, the comma separated items act as an OR and the "+" inside them as an AND. Using labels can simplify the definitions of complex wildcards where you want to have direct access to specific choices on occasion (you don't need to create wildcards for each individual choice). There are some additional formats when using filters. You can specify "^wildcard" as a filter to use the filter of a previous wildcard in the chain. You can start the filter (regular or inherited) with "#" and it will not be applied to the current wildcard choices, but the filter will remain in memory to use by other descendant wildcards. You use "#" and "^" when you want to pass a filter to inner wildcards (see the test files). The variable value only applies during the evaluation of the selected choices and is discarded afterward (the variable keeps its original value if there was one). diff --git a/__init__.py b/__init__.py index 9fe3f43..fbcec4f 100644 --- a/__init__.py +++ b/__init__.py @@ -13,16 +13,6 @@ from .ppp_comfyui import PromptPostProcessorComfyUINode NODE_CLASS_MAPPINGS = {"ACBPromptPostProcessor": PromptPostProcessorComfyUINode} - NODE_DISPLAY_NAME_MAPPINGS = {"ACBPromptPostProcessor": "ACB Prompt Post Processor"} -MANIFEST = { - "name": "ACB Prompt Post Processor", - "version": PromptPostProcessorComfyUINode.VERSION, - "author": "ACB", - "project": "https://github.com/acorderob/sd-webui-prompt-postprocessor", - "description": "Node for processing prompts", - "license": "MIT", -} - __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/install.py b/install.py index 5fe51f0..abdd67c 100644 --- a/install.py +++ b/install.py @@ -1,5 +1,13 @@ import os -import launch requirements_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") -launch.run_pip(f'install -r "{requirements_filename}"', "requirements for Prompt Post-Processor") + +try: + from modules.launch_utils import requirements_met, run_pip # A1111 + + if not requirements_met(requirements_filename): + run_pip(f'install -r "{requirements_filename}"', "requirements for Prompt Post-Processor") +except ImportError: + import launch + + launch.run_pip(f'install -r "{requirements_filename}"', "requirements for Prompt Post-Processor") diff --git a/ppp.py b/ppp.py index 42ecf77..03cdb44 100644 --- a/ppp.py +++ b/ppp.py @@ -23,12 +23,12 @@ class PromptPostProcessor: # pylint: disable=too-few-public-methods,too-many-in """ @staticmethod - def get_version_from_pyproject() -> tuple: + def get_version_from_pyproject() -> str: """ Reads the version from the pyproject.toml file. Returns: - tuple: A tuple containing the version numbers. + str: The version string. """ version_str = "0.0.0" try: @@ -40,7 +40,7 @@ def get_version_from_pyproject() -> tuple: break except Exception as e: # pylint: disable=broad-exception-caught logging.getLogger().exception(e) - return tuple(map(int, version_str.split("."))) + return version_str NAME = "Prompt Post-Processor" VERSION = get_version_from_pyproject() @@ -175,6 +175,9 @@ def isComfyUI(self) -> bool: return self.env_info.get("app", "") == "comfyui" def __init_sysvars(self): + """ + Initializes the system variables. + """ self.system_variables = {} sdchecks = { "sd1": self.env_info.get("is_sd1", False), @@ -384,6 +387,16 @@ def __cleanup(self, text: str) -> str: return text def __processprompts(self, prompt, negative_prompt): + """ + Process the prompt and negative prompt. + + Args: + prompt (str): The prompt. + negative_prompt (str): The negative prompt. + + Returns: + tuple: A tuple containing the processed prompt and negative prompt. + """ self.user_variables = {} # Process prompt @@ -444,7 +457,7 @@ def process_prompt( seed: int = 0, ): """ - Process the prompt and negative prompt by moving content to the negative prompt, and cleaning up. + Initializes the random number generator and processes the prompt and negative prompt. Args: original_prompt (str): The original prompt. @@ -486,6 +499,18 @@ def process_prompt( return original_prompt, original_negative_prompt def parse_prompt(self, prompt_description: str, prompt: str, parser: lark.Lark, raise_parsing_error: bool = False): + """ + Parses a prompt using the specified parser. + + Args: + prompt_description (str): The description of the prompt. + prompt (str): The prompt to be parsed. + parser (lark.Lark): The parser to be used. + raise_parsing_error (bool): Whether to raise a parsing error. + + Returns: + Tree: The parsed prompt. + """ t1 = time.time() try: if self.debug_level == DEBUG_LEVEL.full: @@ -605,6 +630,16 @@ def __visit( return added_result def __get_original_node_content(self, node: lark.Tree | lark.Token, default=None) -> str: + """ + Get the original content of a node. + + Args: + node (Tree|Token): The node to get the content from. + default: The default value to return if the content is not found. + + Returns: + str: The original content of the node. + """ return ( node.meta.content if hasattr(node, "meta") and node.meta is not None and not node.meta.empty @@ -637,13 +672,35 @@ def __get_user_variable_value(self, name: str, evaluate=True, visit=False) -> st return v def __set_user_variable_value(self, name: str, value: str): + """ + Set the value of a user variable. + + Args: + name (str): The name of the user variable. + value (str): The value to be set. + """ self.__ppp.user_variables[name] = value def __remove_user_variable(self, name: str): + """ + Remove a user variable. + + Args: + name (str): The name of the user variable. + """ if name in self.__ppp.user_variables: del self.__ppp.user_variables[name] def __debug_end(self, construct: str, start_result: str, duration: float, info=None): + """ + Log the end of a construct processing. + + Args: + construct (str): The name of the construct. + start_result (str): The initial result. + duration (float): The duration of the processing. + info: Additional information to log. + """ if self.__ppp.debug_level == DEBUG_LEVEL.full: info = f"({info}) " if info is not None and info != "" else "" output = self.result[len(start_result) :] @@ -1208,6 +1265,8 @@ def __get_choices( if options.get("prefix", None) is not None else "" ) + if prefix != "" and re.match(r"\w", prefix[-1]): + prefix += " " for i, c in enumerate(selected_choices): t1 = time.time() choice_content_obj = c.get("content", c.get("text", None)) @@ -1227,12 +1286,23 @@ def __get_choices( if options.get("suffix", None) is not None else "" ) + if suffix != "" and re.match(r"\w", suffix[0]): + suffix = " " + suffix # remove comments results = [re.sub(r"\s*#[^\n]*(?:\n|$)", "", r, flags=re.DOTALL) for r in selected_choices_text] return prefix + separator.join(results) + suffix return "" def __convert_choices_options(self, options: Optional[lark.Tree]) -> dict: + """ + Convert the choices options to a dictionary. + + Args: + options (Tree): The choices options tree. + + Returns: + dict: The converted choices options. + """ if options is None: return None the_options = {} @@ -1263,6 +1333,15 @@ def __convert_choices_options(self, options: Optional[lark.Tree]) -> dict: return the_options def __convert_choice(self, choice: lark.Tree) -> dict: + """ + Convert the choice to a dictionary. + + Args: + choice (Tree): The choice tree. + + Returns: + dict: The converted choice. + """ the_choice = {} c_label_obj = choice.children[0] the_choice["labels"] = ( @@ -1276,6 +1355,12 @@ def __convert_choice(self, choice: lark.Tree) -> dict: return the_choice def __check_wildcard_initialization(self, wildcard: PPPWildcard): + """ + Initializes a wildcard if it hasn't been yet. + + Args: + wildcard (PPPWildcard): The wildcard to check. + """ choice_values = wildcard.choices options = wildcard.options if choice_values is None: @@ -1284,10 +1369,7 @@ def __check_wildcard_initialization(self, wildcard: PPPWildcard): n = 0 # we check the first choice to see if it is actually options if isinstance(wildcard.unprocessed_choices[0], dict): - if all( - k in ["sampler", "repeating", "count", "from", "to", "prefix", "suffix", "separator"] - for k in wildcard.unprocessed_choices[0].keys() - ): + if self.__ppp.wildcard_obj.is_dict_choices_options(wildcard.unprocessed_choices[0]): options = wildcard.unprocessed_choices[0] prefix = options.get("prefix", None) if prefix is not None and isinstance(prefix, str): @@ -1331,7 +1413,7 @@ def __check_wildcard_initialization(self, wildcard: PPPWildcard): # we process the choices for cv in wildcard.unprocessed_choices[n:]: if isinstance(cv, dict): - if all(k in ["labels", "weight", "if", "content", "text"] for k in cv.keys()): + if self.__ppp.wildcard_obj.is_dict_choice_options(cv): theif = cv.get("if", None) if theif is not None and isinstance(theif, str): try: @@ -1379,12 +1461,9 @@ def __check_wildcard_initialization(self, wildcard: PPPWildcard): wildcard.choices = choice_values t2 = time.time() if self.__ppp.debug_level == DEBUG_LEVEL.full: - self.__ppp.logger.debug( - f"Processed choices for wildcard '{wildcard.key}' ({t2-t1:.3f} seconds)" - ) + self.__ppp.logger.debug(f"Processed choices for wildcard '{wildcard.key}' ({t2-t1:.3f} seconds)") return (options, choice_values) - def wildcard(self, tree: lark.Tree): """ Process a wildcard construct in the tree. diff --git a/ppp_cache.py b/ppp_cache.py index d88473c..a77a87b 100644 --- a/ppp_cache.py +++ b/ppp_cache.py @@ -4,8 +4,8 @@ class PPPLRUCache: - ProcessInput = Tuple[int, str, str] - ProcessResult = Tuple[str, str] + ProcessInput = Tuple[int, int, str, str] # (seed, wildcards_hash, positive_prompt, negative_prompt) + ProcessResult = Tuple[str, str] # (positive_prompt, negative_prompt) def __init__(self, capacity: int): self.cache = OrderedDict() diff --git a/ppp_comfyui.py b/ppp_comfyui.py index e38245d..6a1a8fd 100644 --- a/ppp_comfyui.py +++ b/ppp_comfyui.py @@ -16,8 +16,6 @@ class PromptPostProcessorComfyUINode: - VERSION = PromptPostProcessor.VERSION - logger = None def __init__(self): @@ -27,6 +25,7 @@ def __init__(self): with open(grammar_filename, "r", encoding="utf-8") as file: self.grammar_content = file.read() self.wildcards_obj = PPPWildcards(lf.log) + self.logger.info(f"{PromptPostProcessor.NAME} {PromptPostProcessor.VERSION} initialized") class SmartType(str): def __ne__(self, other): @@ -354,6 +353,8 @@ def IS_CHANGED( cleanup_merge_attention, remove_extranetwork_tags, ): + if wc_process_wildcards: + return float("NaN") # since we can't detect changes in wildcards we assume they are always changed when enabled new_run = { # everything except debug_level "model": model, "modelname": modelname, diff --git a/ppp_wildcards.py b/ppp_wildcards.py index 3c018e0..57112d9 100644 --- a/ppp_wildcards.py +++ b/ppp_wildcards.py @@ -8,7 +8,37 @@ from ppp_logging import DEBUG_LEVEL +def deep_freeze(obj): + """ + Deep freeze an object. + + Args: + obj (object): The object to freeze. + + Returns: + object: The frozen object. + """ + if isinstance(obj, dict): + return tuple((k, deep_freeze(v)) for k, v in sorted(obj.items())) + elif isinstance(obj, list): + return tuple(deep_freeze(i) for i in obj) + elif isinstance(obj, set): + return tuple(deep_freeze(i) for i in sorted(obj)) + else: + return obj + + class PPPWildcard: + """ + A wildcard object. + + Attributes: + key (str): The key of the wildcard. + file (str): The path to the file where the wildcard is defined. + unprocessed_choices (list[str]): The unprocessed choices of the wildcard. + choices (list[dict]): The processed choices of the wildcard. + options (dict): The options of the wildcard. + """ def __init__(self, fullpath: str, key: str, choices: list[str]): self.key: str = key @@ -17,45 +47,67 @@ def __init__(self, fullpath: str, key: str, choices: list[str]): self.choices: list[dict] = None self.options: dict = None + def __hash__(self) -> int: + t = (self.key, deep_freeze(self.unprocessed_choices)) + return hash(t) + class PPPWildcards: + """ + A class to manage wildcards. + + Attributes: + wildcards (dict[str, PPPWildcard]): The wildcards. + """ DEFAULT_WILDCARDS_FOLDER = "wildcards" def __init__(self, logger): - self.logger: logging.Logger = logger - self.debug_level = DEBUG_LEVEL.none - self.wildcards_folders = [] + self.__logger: logging.Logger = logger + self.__debug_level = DEBUG_LEVEL.none + self.__wildcards_folders = [] + self.__wildcard_files = {} self.wildcards: dict[str, PPPWildcard] = {} - self.wildcard_files = {} + + def __hash__(self) -> int: + return hash(deep_freeze(self.wildcards)) def refresh_wildcards(self, debug_level: DEBUG_LEVEL, wildcards_folders: Optional[list[str]]): """ Initialize the wildcards. """ - self.debug_level = debug_level - self.wildcards_folders = wildcards_folders + self.__debug_level = debug_level + self.__wildcards_folders = wildcards_folders if wildcards_folders is not None: # if self.debug_level != DEBUG_LEVEL.none: # self.logger.info("Initializing wildcards...") # t1 = time.time() - for fullpath in list(self.wildcard_files.keys()): + for fullpath in list(self.__wildcard_files.keys()): path = os.path.dirname(fullpath) if not os.path.exists(fullpath) or not any( - os.path.commonpath([path, folder]) == folder for folder in self.wildcards_folders + os.path.commonpath([path, folder]) == folder for folder in self.__wildcards_folders ): self.__remove_wildcards_from_file(fullpath) - for f in self.wildcards_folders: + for f in self.__wildcards_folders: self.__get_wildcards_in_directory(f, f) # t2 = time.time() # if self.debug_level != DEBUG_LEVEL.none: # self.logger.info(f"Wildcards init time: {t2 - t1:.3f} seconds") else: - self.wildcards_folders = [] + self.__wildcards_folders = [] self.wildcards = {} - self.wildcard_files = {} + self.__wildcard_files = {} def get_wildcards(self, key: str) -> list[PPPWildcard]: + """ + Get all wildcards that match a key. + + Args: + key (str): The key to match. + + Returns: + list: A list of all wildcards that match the key. + """ keys = sorted(fnmatch.filter(self.wildcards.keys(), key)) return [self.wildcards[k] for k in keys] @@ -105,11 +157,11 @@ def __remove_wildcards_from_file(self, full_path: str, debug=True): full_path (str): The path to the file. debug (bool): Whether to print debug messages or not. """ - last_modified_cached = self.wildcard_files.get(full_path, None) - if debug and last_modified_cached is not None and self.debug_level != DEBUG_LEVEL.none: - self.logger.debug(f"Removing wildcards from file: {full_path}") - if full_path in self.wildcard_files.keys(): - del self.wildcard_files[full_path] + last_modified_cached = self.__wildcard_files.get(full_path, None) + if debug and last_modified_cached is not None and self.__debug_level != DEBUG_LEVEL.none: + self.__logger.debug(f"Removing wildcards from file: {full_path}") + if full_path in self.__wildcard_files.keys(): + del self.__wildcard_files[full_path] for key in list(self.wildcards.keys()): if self.wildcards[key].file == full_path: del self.wildcards[key] @@ -123,23 +175,60 @@ def __get_wildcards_in_file(self, base, full_path: str): full_path (str): The path to the file. """ last_modified = os.path.getmtime(full_path) - last_modified_cached = self.wildcard_files.get(full_path, None) - if last_modified_cached is not None and last_modified == self.wildcard_files[full_path]: + last_modified_cached = self.__wildcard_files.get(full_path, None) + if last_modified_cached is not None and last_modified == self.__wildcard_files[full_path]: return filename = os.path.basename(full_path) _, extension = os.path.splitext(filename) if extension not in (".txt", ".json", ".yaml", ".yml"): return self.__remove_wildcards_from_file(full_path, False) - if last_modified_cached is not None and self.debug_level != DEBUG_LEVEL.none: - self.logger.debug(f"Updating wildcards from file: {full_path}") + if last_modified_cached is not None and self.__debug_level != DEBUG_LEVEL.none: + self.__logger.debug(f"Updating wildcards from file: {full_path}") if extension == ".txt": self.__get_wildcards_in_text_file(full_path, base) elif extension in (".json", ".yaml", ".yml"): self.__get_wildcards_in_structured_file(full_path, base) - self.wildcard_files[full_path] = last_modified + self.__wildcard_files[full_path] = last_modified + + def is_dict_choices_options(self, d: dict) -> bool: + """ + Check if a dictionary is a valid choices options dictionary. + + Args: + d (dict): The dictionary to check. + + Returns: + bool: Whether the dictionary is a valid choices options dictionary or not. + """ + return all( + k in ["sampler", "repeating", "count", "from", "to", "prefix", "suffix", "separator"] for k in d.keys() + ) + + def is_dict_choice_options(self, d: dict) -> bool: + """ + Check if a dictionary is a valid choice options dictionary. + + Args: + d (dict): The dictionary to check. + + Returns: + bool: Whether the dictionary is a valid choice options dictionary or not. + """ + return all(k in ["labels", "weight", "if", "content", "text"] for k in d.keys()) - def __get_choices(self, obj: object, full_path: str, key_parts: list[str]) -> list[dict]: + def __get_choices(self, obj: object, full_path: str, key_parts: list[str]) -> list: + """ + We process the choices in the object and return them as a list. + + Args: + obj (object): the value of a wildcard + full_path (str): path to the file where the wildcard is defined + key_parts (list[str]): parts of the key for the wildcard + + Returns: + list: list of choices + """ choices = None if obj is not None: if isinstance(obj, (str, dict)): @@ -158,10 +247,7 @@ def __get_choices(self, obj: object, full_path: str, key_parts: list[str]) -> li # we create an anonymous wildcard choice = self.__create_anonymous_wildcard(full_path, key_parts, i, c) elif isinstance(c, dict): - if all( - k in ["sampler", "repeating", "count", "from", "to", "prefix", "suffix", "separator"] - for k in c.keys() - ) or all(k in ["labels", "weight", "if", "content", "text"] for k in c.keys()): + if self.is_dict_choices_options(c) or self.is_dict_choice_options(c): # we assume it is a choice or wildcard parameters in object format choice = c choice_content = choice.get("content", choice.get("text", None)) @@ -181,7 +267,7 @@ def __get_choices(self, obj: object, full_path: str, key_parts: list[str]) -> li else: invalid_choice = True if invalid_choice: - self.logger.warning( + self.__logger.warning( f"Invalid choice {i+1} in wildcard '{'/'.join(key_parts)}' in file '{full_path}'!" ) else: @@ -189,6 +275,19 @@ def __get_choices(self, obj: object, full_path: str, key_parts: list[str]) -> li return choices def __create_anonymous_wildcard(self, full_path, key_parts, i, content, options=None): + """ + Create an anonymous wildcard. + + Args: + full_path (str): The path to the file that contains it. + key_parts (list[str]): The parts of the key. + i (int): The index of the wildcard. + content (object): The content of the wildcard. + options (str): The options for the choice where the wildcard is defined. + + Returns: + str: The resulting value for the choice. + """ new_parts = key_parts + [f"#ANON_{i}"] self.__add_wildcard(content, full_path, new_parts) value = f"__{'/'.join(new_parts)}__" @@ -197,6 +296,14 @@ def __create_anonymous_wildcard(self, full_path, key_parts, i, content, options= return value def __add_wildcard(self, content: object, full_path: str, external_key_parts: list[str]): + """ + Add a wildcard to the wildcards dictionary. + + Args: + content (object): The content of the wildcard. + full_path (str): The path to the file that contains it. + external_key_parts (list[str]): The parts of the key. + """ key_parts = external_key_parts.copy() if isinstance(content, dict): key_parts.pop() @@ -206,14 +313,14 @@ def __add_wildcard(self, content: object, full_path: str, external_key_parts: li tmp_key_parts.extend(key.split("/")) fullkey = "/".join(tmp_key_parts) if self.wildcards.get(fullkey, None) is not None: - self.logger.warning( + self.__logger.warning( f"Duplicate wildcard '{fullkey}' in file '{full_path}' and '{self.wildcards[fullkey].file}'!" ) else: obj = self.__get_nested(content, key) choices = self.__get_choices(obj, full_path, tmp_key_parts) if choices is None: - self.logger.warning(f"Invalid wildcard '{fullkey}' in file '{full_path}'!") + self.__logger.warning(f"Invalid wildcard '{fullkey}' in file '{full_path}'!") else: self.wildcards[fullkey] = PPPWildcard(full_path, fullkey, choices) return @@ -222,21 +329,28 @@ def __add_wildcard(self, content: object, full_path: str, external_key_parts: li elif isinstance(content, (int, float, bool)): content = [str(content)] if not isinstance(content, list): - self.logger.warning(f"Invalid wildcard in file '{full_path}'!") + self.__logger.warning(f"Invalid wildcard in file '{full_path}'!") return fullkey = "/".join(key_parts) if self.wildcards.get(fullkey, None) is not None: - self.logger.warning( + self.__logger.warning( f"Duplicate wildcard '{fullkey}' in file '{full_path}' and '{self.wildcards[fullkey].file}'!" ) else: choices = self.__get_choices(content, full_path, key_parts) if choices is None: - self.logger.warning(f"Invalid wildcard '{fullkey}' in file '{full_path}'!") + self.__logger.warning(f"Invalid wildcard '{fullkey}' in file '{full_path}'!") else: self.wildcards[fullkey] = PPPWildcard(full_path, fullkey, choices) def __get_wildcards_in_structured_file(self, full_path, base): + """ + Get all wildcards in a structured file. + + Args: + full_path (str): The path to the file. + base (str): The base path for the wildcards. + """ external_key: str = os.path.relpath(os.path.splitext(full_path)[0], base) external_key_parts = external_key.split(os.sep) _, extension = os.path.splitext(full_path) @@ -248,6 +362,13 @@ def __get_wildcards_in_structured_file(self, full_path, base): self.__add_wildcard(content, full_path, external_key_parts) def __get_wildcards_in_text_file(self, full_path, base): + """ + Get all wildcards in a text file. + + Args: + full_path (str): The path to the file. + base (str): The base path for the wildcards. + """ external_key: str = os.path.relpath(os.path.splitext(full_path)[0], base) external_key_parts = external_key.split(os.sep) with open(full_path, "r", encoding="utf-8") as file: @@ -265,7 +386,7 @@ def __get_wildcards_in_directory(self, base: str, directory: str): directory (str): The path to the directory. """ if not os.path.exists(directory): - self.logger.warning(f"Wildcard directory '{directory}' does not exist!") + self.__logger.warning(f"Wildcard directory '{directory}' does not exist!") return for filename in os.listdir(directory): full_path = os.path.abspath(os.path.join(directory, filename)) diff --git a/scripts/ppp_script.py b/scripts/ppp_script.py index 9d459b9..e77c891 100644 --- a/scripts/ppp_script.py +++ b/scripts/ppp_script.py @@ -37,8 +37,6 @@ class PromptPostProcessorA1111Script(scripts.Script): __on_ui_settings(): Callback function for UI settings. """ - VERSION = PromptPostProcessor.VERSION - def __init__(self): """ Initializes the PromptPostProcessor object. @@ -60,6 +58,7 @@ def __init__(self): with open(grammar_filename, "r", encoding="utf-8") as file: self.grammar_content = file.read() self.wildcards_obj = PPPWildcards(lf.log) + self.ppp_logger.info(f"{PromptPostProcessor.NAME} {PromptPostProcessor.VERSION} initialized") def title(self): """ @@ -285,24 +284,24 @@ def process( for i, (prompttype, seed, prompt, negative_prompt) in enumerate(prompts_list): if self.ppp_debug_level != DEBUG_LEVEL.none: self.ppp_logger.info(f"processing prompts[{i+1}] ({prompttype})") - if self.lru_cache.get((seed, prompt, negative_prompt)) is None: + if self.lru_cache.get((seed, hash(self.wildcards_obj), prompt, negative_prompt)) is None: posp, negp = ppp.process_prompt(prompt, negative_prompt, seed) - self.lru_cache.put((seed, prompt, negative_prompt), (posp, negp)) + self.lru_cache.put((seed, hash(self.wildcards_obj), prompt, negative_prompt), (posp, negp)) # adds also the result so i2i doesn't process it unnecessarily - self.lru_cache.put((seed, posp, negp), (posp, negp)) + self.lru_cache.put((seed, hash(self.wildcards_obj), posp, negp), (posp, negp)) elif self.ppp_debug_level != DEBUG_LEVEL.none: self.ppp_logger.info("result already in cache") # updates the prompts if rpr is not None and rnr is not None: for i, (seed, prompt, negative_prompt) in enumerate(zip(calculated_seeds, rpr, rnr)): - found = self.lru_cache.get((seed, prompt, negative_prompt)) + found = self.lru_cache.get((seed, hash(self.wildcards_obj), prompt, negative_prompt)) if found is not None: rpr[i] = found[0] rnr[i] = found[1] if rph is not None and rnh is not None: for i, (seed, prompt, negative_prompt) in enumerate(zip(calculated_seeds, rph, rnh)): - found = self.lru_cache.get((seed, prompt, negative_prompt)) + found = self.lru_cache.get((seed, hash(self.wildcards_obj), prompt, negative_prompt)) if found is not None: rph[i] = found[0] rnh[i] = found[1] diff --git a/tests/tests.py b/tests/tests.py index b3e91c4..8245523 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -152,11 +152,6 @@ def __process( self.assertEqual(result_negative_prompt, eo.negative_prompt, "Incorrect negative prompt") seed += 1 - # Other tests - - def test_version(self): # reading ppp version - self.assertNotEqual(PromptPostProcessor.VERSION, (0, 0, 0), "Incorrect version") - # Send To Negative tests def test_stn_simple(self): # negtags with different parameters and separations @@ -824,7 +819,7 @@ def test_wc_wildcardPS_yaml(self): # yaml wildcard with object formatted choice def test_wc_anonymouswildcard_yaml(self): # yaml anonymous wildcard self.__process( - PromptPair("the choices are: __yaml/nesteddict__", ""), + PromptPair("the choices are: __yaml/anonwildcards__", ""), PromptPair("the choices are: six", ""), ppp=self.__nocupppp, ) diff --git a/tests/wildcards/test.yaml b/tests/wildcards/test.yaml index d1952dc..0062638 100644 --- a/tests/wildcards/test.yaml +++ b/tests/wildcards/test.yaml @@ -53,13 +53,17 @@ yaml: - one - two - nesteddict: + anonwildcards: - one - two - - # anonymous wildcard without options + - # choice without options and anonymous wildcard - three - four - - 3 if _is_sdxl: # anonymous wildcard with options - - five - - six + - 3 if _is_sdxl: # choice with options and anonymous wildcard + - five + - six - { weight: 1, text: [seven, eight] } # anonymous wildcard used in a choice in object format + - # choice without options and anonymous wildcard with parameters + - { count: 2, prefix: "#" } + - nine + - ten