Skip to content

Commit

Permalink
Merge pull request #15205 from AUTOMATIC1111/callback_order
Browse files Browse the repository at this point in the history
Callback order
  • Loading branch information
AUTOMATIC1111 authored Mar 16, 2024
2 parents 9fd6932 + 1bbc8a1 commit 5bd2724
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 121 deletions.
41 changes: 41 additions & 0 deletions modules/extensions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import configparser
import dataclasses
import os
import threading
import re
Expand All @@ -22,6 +23,13 @@ def active():
return [x for x in extensions if x.enabled]


@dataclasses.dataclass
class CallbackOrderInfo:
name: str
before: list
after: list


class ExtensionMetadata:
filename = "metadata.ini"
config: configparser.ConfigParser
Expand Down Expand Up @@ -65,6 +73,22 @@ def parse_list(self, text):
# both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x]

def list_callback_order_instructions(self):
for section in self.config.sections():
if not section.startswith("callbacks/"):
continue

callback_name = section[10:]

if not callback_name.startswith(self.canonical_name):
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
continue

before = self.parse_list(self.config.get(section, 'Before', fallback=''))
after = self.parse_list(self.config.get(section, 'After', fallback=''))

yield CallbackOrderInfo(callback_name, before, after)


class Extension:
lock = threading.Lock()
Expand Down Expand Up @@ -188,6 +212,7 @@ def fetch_and_reset_hard(self, commit='origin'):

def list_extensions():
extensions.clear()
extension_paths.clear()

if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
Expand Down Expand Up @@ -222,6 +247,7 @@ def list_extensions():
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension)
extension_paths[extension.path] = extension
loaded_extensions[canonical_name] = extension

# check for requirements
Expand All @@ -240,4 +266,19 @@ def list_extensions():
continue


def find_extension(filename):
parentdir = os.path.dirname(os.path.realpath(filename))

while parentdir != filename:
extension = extension_paths.get(parentdir)
if extension is not None:
return extension

filename = parentdir
parentdir = os.path.dirname(filename)

return None


extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}
Loading

0 comments on commit 5bd2724

Please sign in to comment.