Skip to content

Commit

Permalink
update wdtagger
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Mar 1, 2025
1 parent a119303 commit 137135a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ cython_debug/
.pypirc
datasets/
wd14_tagger_model/
data/

# windsurf rules
.windsurfrules
95 changes: 69 additions & 26 deletions utils/wdtagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import csv
import lance
import pyarrow as pa
import pyarrow.compute as pc
from rich.console import Console
from rich.progress import (
Progress,
Expand All @@ -23,7 +22,7 @@
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from module.lanceImport import transform2lance
import io
import concurrent.futures

console = Console()

Expand All @@ -40,29 +39,51 @@ def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR

# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(
image,
((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
mode="constant",
constant_values=255,
# 使用更高效的方式计算填充
h, w = image.shape[:2]
size = max(h, w)

# 使用单个 pad 操作替代多次计算
pad_y, pad_x = size - h, size - w
pad_t, pad_l = pad_y // 2, pad_x // 2
pad_b, pad_r = pad_y - pad_t, pad_x - pad_l

# 使用更高效的填充
image = cv2.copyMakeBorder(
image, pad_t, pad_b, pad_l, pad_r, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)

if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = Image.fromarray(image)
image = image.resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
image = np.array(image)

image = image.astype(np.float32)
return image


def load_and_preprocess_batch(uris):
"""并行加载和预处理一批图像"""

def load_single_image(uri):
try:
return preprocess_image(Image.open(uri).convert("RGB"))
except Exception as e:
console.print(f"[red]Error processing {uri}: {str(e)}[/red]")
return None

with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
batch_images = list(executor.map(load_single_image, uris))

# 过滤掉加载失败的图像
valid_images = [(i, img) for i, img in enumerate(batch_images) if img is not None]
images = [img for _, img in valid_images]

return images


def process_batch(images, session, input_name):
"""处理图像批次"""
try:
Expand Down Expand Up @@ -157,6 +178,7 @@ def load_model_and_tags(args):
provider_options = {
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": True,
"arena_extend_strategy": "kNextPowerOfTwo",
}
providers_with_options = [
("CUDAExecutionProvider", provider_options),
Expand All @@ -175,17 +197,23 @@ def load_model_and_tags(args):


def main(args):
global console

# 初始化 Lance 数据集
if not isinstance(args.train_data_dir, lance.LanceDataset):
if args.train_data_dir.endswith(".lance"):
dataset = lance.dataset(args.train_data_dir)
elif any(file.suffix == '.lance' for file in Path(args.train_data_dir).glob('*')):
lance_file = next(file for file in Path(args.train_data_dir).glob('*') if file.suffix == '.lance')
dataset = lance.dataset(str(lance_file))
else:
console.print("[yellow]Converting dataset to Lance format...[/yellow]")
dataset = transform2lance(
args.train_data_dir,
output_name="dataset",
save_binary=False,
not_save_disk=False,
tag="WDtagger",
)
console.print("[green]Dataset converted to Lance format[/green]")

Expand Down Expand Up @@ -279,17 +307,30 @@ def main(args):
tag_freq = {}

# 先计算图片总数
total_images = len(dataset.to_table(columns=["mime"], filter="mime LIKE 'image/%'"))
total_images = len(
dataset.to_table(
columns=["mime", "captions"],
filter=(
"mime LIKE 'image/%'"
if args.overwrite
else "mime LIKE 'image/%' and (captions IS NULL OR array_length(captions) = 0)"
),
)
)

# 然后创建带columns的scanner处理数据
scanner = dataset.scanner(
columns=["uris", "mime"],
filter="mime LIKE 'image/%'",
columns=["uris", "mime", "captions"],
filter=(
"mime LIKE 'image/%'"
if args.overwrite
else "mime LIKE 'image/%' and (captions IS NULL OR array_length(captions) = 0)"
),
scan_in_order=True,
batch_size=args.batch_size,
batch_readahead=8,
fragment_readahead=2,
io_buffer_size=8 * 1024 * 1024, # 8MB buffer
batch_readahead=16,
fragment_readahead=4,
io_buffer_size=32 * 1024 * 1024, # 32MB buffer
late_materialization=True,
with_row_id=True,
)
Expand All @@ -313,20 +354,22 @@ def main(args):
) as progress:
task = progress.add_task("[bold cyan]Processing images...", total=total_images)

global console
console = progress.console

for batch in scanner.to_batches():
uris = batch["uris"].to_pylist() # 获取文件路径

batch_images = [
preprocess_image(Image.open(uri).convert("RGB")) for uri in uris
]
# 使用并行处理加载和预处理图像
batch_images = load_and_preprocess_batch(uris)

if not batch_images:
progress.update(task, advance=len(uris))
continue

# 处理批次
probs = process_batch(batch_images, ort_sess, input_name)
if probs is not None:
for path, prob in zip(batch["uris"].to_pylist(), probs):
for path, prob in zip(uris, probs):
# 获取高置信度的标签
found_tags = []
general_confidence = args.general_threshold or args.thresh
Expand Down Expand Up @@ -467,9 +510,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="Threshold for character category tags (defaults to --thresh)",
)
parser.add_argument(
"--recursive",
"--overwrite",
action="store_true",
help="Search for images in subfolders recursively",
help="Skip processing images in subfolders",
)
parser.add_argument(
"--remove_underscore",
Expand Down

0 comments on commit 137135a

Please sign in to comment.