Skip to content

Commit

Permalink
Merge pull request #14528 from AUTOMATIC1111/mass-file-lister
Browse files Browse the repository at this point in the history
mass file lister as an attempt to tackle #14507
  • Loading branch information
AUTOMATIC1111 authored Jan 4, 2024
2 parents 04a005f + 320a217 commit 149c9d2
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 9 deletions.
5 changes: 3 additions & 2 deletions modules/extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def parse_prompts(prompts):
return res, extra_data


def get_user_metadata(filename):
def get_user_metadata(filename, lister=None):
if filename is None:
return {}

Expand All @@ -215,7 +215,8 @@ def get_user_metadata(filename):

metadata = {}
try:
if os.path.isfile(metadata_filename):
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
if exists:
with open(metadata_filename, "r", encoding="utf8") as file:
metadata = json.load(file)
except Exception as e:
Expand Down
20 changes: 13 additions & 7 deletions modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib.parse
from pathlib import Path

from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util
from modules.images import read_info_from_image, save_image_with_geninfo
import gradio as gr
import json
Expand Down Expand Up @@ -107,13 +107,14 @@ def __init__(self, title):
self.allow_negative_prompt = False
self.metadata = {}
self.items = {}
self.lister = util.MassFileLister()

def refresh(self):
pass

def read_user_metadata(self, item):
filename = item.get("filename", None)
metadata = extra_networks.get_user_metadata(filename)
metadata = extra_networks.get_user_metadata(filename, lister=self.lister)

desc = metadata.get("description", None)
if desc is not None:
Expand All @@ -123,7 +124,7 @@ def read_user_metadata(self, item):

def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
mtime, _ = self.lister.mctime(filename)
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"

def search_terms_from_path(self, filename, possible_directories=None):
Expand All @@ -137,6 +138,8 @@ def search_terms_from_path(self, filename, possible_directories=None):
return ""

def create_html(self, tabname):
self.lister.reset()

items_html = ''

self.metadata = {}
Expand Down Expand Up @@ -282,10 +285,10 @@ def get_sort_keys(self, path):
List of default keys used for sorting in the UI.
"""
pth = Path(path)
stat = pth.stat()
mtime, ctime = self.lister.mctime(path)
return {
"date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0),
"date_created": int(mtime),
"date_modified": int(ctime),
"name": pth.name.lower(),
"path": str(pth.parent).lower(),
}
Expand All @@ -298,7 +301,7 @@ def find_preview(self, path):
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])

for file in potential_files:
if os.path.isfile(file):
if self.lister.exists(file):
return self.link_preview(file)

return None
Expand All @@ -308,6 +311,9 @@ def find_description(self, path):
Find and read a description file for a given path (without extension).
"""
for file in [f"{path}.txt", f"{path}.description.txt"]:
if not self.lister.exists(file):
continue

try:
with open(file, "r", encoding="utf-8", errors="replace") as f:
return f.read()
Expand Down
70 changes: 70 additions & 0 deletions modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,73 @@ def truncate_path(target_path, base_path=cwd):
except ValueError:
pass
return abs_target


class MassFileListerCachedDir:
"""A class that caches file metadata for a specific directory."""

def __init__(self, dirname):
self.files = None
self.files_cased = None
self.dirname = dirname

stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname))
files = [(n, s.st_mtime, s.st_ctime) for n, s in stats]
self.files = {x[0].lower(): x for x in files}
self.files_cased = {x[0]: x for x in files}


class MassFileLister:
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""

def __init__(self):
self.cached_dirs = {}

def find(self, path):
"""
Find the metadata for a file at the given path.
Returns:
tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not.
"""

dirname, filename = os.path.split(path)

cached_dir = self.cached_dirs.get(dirname)
if cached_dir is None:
cached_dir = MassFileListerCachedDir(dirname)
self.cached_dirs[dirname] = cached_dir

stats = cached_dir.files_cased.get(filename)
if stats is not None:
return stats

stats = cached_dir.files.get(filename.lower())
if stats is None:
return None

try:
os_stats = os.stat(path, follow_symlinks=False)
return filename, os_stats.st_mtime, os_stats.st_ctime
except Exception:
return None

def exists(self, path):
"""Check if a file exists at the given path."""

return self.find(path) is not None

def mctime(self, path):
"""
Get the modification and creation times for a file at the given path.
Returns:
tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not.
"""

stats = self.find(path)
return (0, 0) if stats is None else stats[1:3]

def reset(self):
"""Clear the cache of all directories."""
self.cached_dirs.clear()

0 comments on commit 149c9d2

Please sign in to comment.