Commit e9521cf (1 parent: 0acf4cb)

Feature/qwen edit plus (#180)

* support qwen edit 2509
* fix attn type
* revert loader && ruff
* apply suggestions

Co-authored-by: zhuguoxuan.zgx <zhuguoxuan.zgx@alibaba-inc.com>

7 files changed: +94 additions, -20 deletions

diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 7 additions & 5 deletions
@@ -449,7 +449,7 @@ def forward(
         cfg_parallel(
             (
                 image,
-                edit,
+                *(edit if edit is not None else ()),
                 timestep,
                 text,
                 text_seq_lens,
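
The star-unpack above lets the now-optional, possibly multi-element edit list be spliced into the argument tuple without a separate branch. A tiny illustration with placeholder strings (names are illustrative only):

edit = None
args = ("image", *(edit if edit is not None else ()), "timestep")
# -> ("image", "timestep")

edit = ["edit_0", "edit_1"]
args = ("image", *(edit if edit is not None else ()), "timestep")
# -> ("image", "edit_0", "edit_1", "timestep")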
@@ -472,10 +472,12 @@ def forward(
             image = torch.cat([image, context_latents], dim=1)
             video_fhw += [(1, h // 2, w // 2)]
         if edit is not None:
-            edit = edit.to(dtype=image.dtype)
-            edit = self.patchify(edit)
-            image = torch.cat([image, edit], dim=1)
-            video_fhw += [(1, h // 2, w // 2)]
+            for img in edit:
+                img = img.to(dtype=image.dtype)
+                edit_h, edit_w = img.shape[-2:]
+                img = self.patchify(img)
+                image = torch.cat([image, img], dim=1)
+                video_fhw += [(1, edit_h // 2, edit_w // 2)]
 
         rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)

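For context, a self-contained sketch of what the new loop does: each edit latent is patchified independently and appends its own (1, edit_h // 2, edit_w // 2) entry to video_fhw, so reference images no longer need to match the target resolution. The 2x2 patchify below is an assumption for illustration; the module's real patchify may differ in detail.

import torch

def patchify(x: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: flatten 2x2 patches, [B, C, H, W] -> [B, (H//2)*(W//2), C*4].
    b, c, h, w = x.shape
    x = x.view(b, c, h // 2, 2, w // 2, 2)
    return x.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)

image = torch.randn(1, 16, 64, 64)  # target latent
tokens = patchify(image)
video_fhw = [(1, 32, 32)]

edit = [torch.randn(1, 16, 48, 64), torch.randn(1, 16, 64, 32)]  # two differently sized edit latents
for img in edit:
    img = img.to(dtype=tokens.dtype)
    edit_h, edit_w = img.shape[-2:]
    tokens = torch.cat([tokens, patchify(img)], dim=1)
    video_fhw += [(1, edit_h // 2, edit_w // 2)]

print(tokens.shape)  # torch.Size([1, 2304, 64]) == 1024 + 768 + 512 tokens
print(video_fhw)     # [(1, 32, 32), (1, 24, 32), (1, 16, 16)]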
diffsynth_engine/pipelines/qwen_image.py

Lines changed: 41 additions & 13 deletions
@@ -107,10 +107,14 @@ def __init__(
             dtype=config.model_dtype,
         )
         self.config = config
+        # qwen image
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
-
+        # qwen image edit
         self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        # qwen image edit plus
+        self.edit_plus_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+
         self.edit_prompt_template_encode_start_idx = 64
 
         # sampler
@@ -282,7 +286,7 @@ def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, sav
 
     def unload_loras(self):
         self.dit.unload_loras()
-        self.noise_scheduler.restore_scheduler_config()
+        self.noise_scheduler.restore_config()
 
     def apply_scheduler_config(self, scheduler_config: Dict):
         self.noise_scheduler.update_config(scheduler_config)
@@ -339,16 +343,27 @@ def encode_prompt(
     def encode_prompt_with_image(
         self,
         prompt: Union[str, List[str]],
-        image: torch.Tensor,
+        vae_image: List[torch.Tensor],
+        condition_image: List[torch.Tensor],  # edit plus
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 1024,
+        is_edit_plus: bool = True,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
 
         batch_size = len(prompt)
         template = self.edit_prompt_template_encode
         drop_idx = self.edit_prompt_template_encode_start_idx
-        texts = [template.format(txt) for txt in prompt]
+        if not is_edit_plus:
+            template = self.edit_prompt_template_encode
+            texts = [template.format(txt) for txt in prompt]
+            image = vae_image
+        else:
+            template = self.edit_plus_prompt_template_encode
+            img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+            img_prompt = "".join([img_prompt_template.format(i + 1) for i in range(len(condition_image))])
+            texts = [template.format(img_prompt + e) for e in prompt]
+            image = condition_image
 
         model_inputs = self.processor(text=texts, images=image, max_length=max_sequence_length + drop_idx)
         input_ids, attention_mask, pixel_values, image_grid_thw = (
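
In the edit-plus branch above, each reference image receives a numbered "Picture N" vision placeholder in front of the user instruction. A minimal sketch of the string assembly (template abbreviated here; the full system prompt appears in the __init__ diff above, and the example prompt is invented):

template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"  # abbreviated
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"

condition_image = ["img_a", "img_b"]  # stand-ins for two resized PIL images
prompt = ["Put the person from Picture 1 into the scene from Picture 2."]

img_prompt = "".join(img_prompt_template.format(i + 1) for i in range(len(condition_image)))
texts = [template.format(img_prompt + p) for p in prompt]
# texts[0] starts with:
#   "<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>"
#   "Picture 2: <|vision_start|><|image_pad|><|vision_end|>Put the person from ..."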
@@ -454,7 +469,7 @@ def predict_noise_with_cfg(
             entity_masks = [torch.cat([mask, mask], dim=0) for mask in entity_masks]
         latents = torch.cat([latents, latents], dim=0)
         if image_latents is not None:
-            image_latents = torch.cat([image_latents, image_latents], dim=0)
+            image_latents = [torch.cat([image_latent, image_latent], dim=0) for image_latent in image_latents]
         if context_latents is not None:
             context_latents = torch.cat([context_latents, context_latents], dim=0)
         timestep = torch.cat([timestep, timestep], dim=0)
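
Since image_latents is now a list with one tensor per reference image, the true-CFG batch duplication is applied element-wise rather than to a single tensor. A minimal sketch with made-up shapes:

import torch

image_latents = [torch.randn(1, 16, 64, 64), torch.randn(1, 16, 48, 64)]
# Duplicate along the batch dim so the positive and negative prompts
# share one forward pass.
image_latents = [torch.cat([x, x], dim=0) for x in image_latents]
assert [x.shape[0] for x in image_latents] == [2, 2]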
@@ -543,7 +558,8 @@ def __call__(
         self,
         prompt: str,
         negative_prompt: str = "",
-        input_image: Image.Image | None = None,  # use for img2img
+        # single image for edit, list for edit plus (QwenImageEdit2509)
+        input_image: List[Image.Image] | Image.Image | None = None,
         cfg_scale: float = 4.0,  # true cfg
         height: int = 1328,
         width: int = 1328,
@@ -555,10 +571,20 @@ def __call__(
         entity_prompts: Optional[List[str]] = None,
         entity_masks: Optional[List[Image.Image]] = None,
     ):
+        is_edit_plus = isinstance(input_image, list)
         if input_image is not None:
-            width, height = input_image.size
-            width, height = self.calculate_dimensions(1024 * 1024, width / height)
-            input_image = input_image.resize((width, height), Image.LANCZOS)
+            if not isinstance(input_image, list):
+                input_image = [input_image]
+            condition_images = []
+            vae_images = []
+            for img in input_image:
+                img_width, img_height = img.size
+                condition_width, condition_height = self.calculate_dimensions(384 * 384, img_width / img_height)
+                vae_width, vae_height = self.calculate_dimensions(1024 * 1024, img_width / img_height)
+                condition_images.append(img.resize((condition_width, condition_height), Image.LANCZOS))
+                vae_images.append(img.resize((vae_width, vae_height), Image.LANCZOS))
+
+            width, height = vae_images[-1].size
 
         self.validate_image_size(height, width, minimum=64, multiple_of=16)

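Each input image is now prepared at two resolutions: a small copy with roughly 384*384 pixels of area for the vision-language encoder (condition_images) and a larger copy with roughly 1024*1024 pixels for VAE encoding (vae_images), both preserving aspect ratio. The sketch below uses a hypothetical calculate_dimensions; the pipeline's actual helper may round sides differently:

import math
from PIL import Image

def calculate_dimensions(target_area: int, ratio: float, multiple_of: int = 32):
    # Hypothetical stand-in: fit target_area at the given aspect ratio,
    # snapping both sides to a multiple of 32.
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return int(round(width / multiple_of) * multiple_of), int(round(height / multiple_of) * multiple_of)

img = Image.new("RGB", (1920, 1080))
ratio = img.width / img.height
print(calculate_dimensions(384 * 384, ratio))    # (512, 288)   -> encoder copy
print(calculate_dimensions(1024 * 1024, ratio))  # (1376, 768)  -> VAE copy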
@@ -567,7 +593,7 @@ def __call__(
 
         context_latents = None
         for param in controlnet_params:
-            self.load_lora(param.model, param.scale, fused=True, save_original_weight=False)
+            self.load_lora(param.model, param.scale, fused=False, save_original_weight=False)
             if param.control_type == QwenImageControlType.in_context:
                 width, height = param.image.size
                 self.validate_image_size(height, width, minimum=64, multiple_of=16)
@@ -585,16 +611,18 @@ def __call__(
 
         self.load_models_to_device(["vae"])
         if input_image:
-            image_latents = self.prepare_image_latents(input_image)
+            image_latents = [self.prepare_image_latents(img) for img in vae_images]
         else:
             image_latents = None
 
         self.load_models_to_device(["encoder"])
         if image_latents is not None:
-            prompt_emb, prompt_emb_mask = self.encode_prompt_with_image(prompt, input_image, 1, 4096)
+            prompt_emb, prompt_emb_mask = self.encode_prompt_with_image(
+                prompt, vae_images, condition_images, 1, 4096, is_edit_plus
+            )
             if cfg_scale > 1.0 and negative_prompt != "":
                 negative_prompt_emb, negative_prompt_emb_mask = self.encode_prompt_with_image(
-                    negative_prompt, input_image, 1, 4096
+                    negative_prompt, vae_images, condition_images, 1, 4096, is_edit_plus
                 )
             else:
                 negative_prompt_emb, negative_prompt_emb_mask = None, None
(binary image, 1.09 MB)

tests/data/input/qwen_1.png (binary image, 421 KB)

tests/data/input/qwen_2.png (binary image, 466 KB)

tests/test_pipelines/test_qwen_image_controlnet.py

Lines changed: 4 additions & 2 deletions
@@ -38,7 +38,8 @@ def test_incontext_canny(self):
             seed=42,
             controlnet_params=param,
         )
-        self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_canny.png", threshold=0.99)
+        self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_canny.png", threshold=0.95)
+        self.pipe.unload_loras()
 
     def test_incontext_depth(self):
         param = QwenImageControlNetParams(
@@ -54,7 +55,8 @@ def test_incontext_depth(self):
             seed=42,
             controlnet_params=param,
         )
-        self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_depth.png", threshold=0.99)
+        self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_depth.png", threshold=0.95)
+        self.pipe.unload_loras()
 
 
 if __name__ == "__main__":
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import unittest
+import torch
+
+from diffsynth_engine import QwenImagePipelineConfig
+from diffsynth_engine.pipelines import QwenImagePipeline
+from diffsynth_engine.utils.download import fetch_model
+from tests.common.test_case import ImageTestCase
+
+
+class TestQwenImageEditPlusPipeline(ImageTestCase):
+    @classmethod
+    def setUpClass(cls):
+        config = QwenImagePipelineConfig(
+            model_path=fetch_model("Qwen/Qwen-Image-Edit-2509", path="transformer/*.safetensors"),
+            encoder_path=fetch_model("Qwen/Qwen-Image-Edit-2509", path="text_encoder/*.safetensors"),
+            vae_path=fetch_model("Qwen/Qwen-Image-Edit-2509", path="vae/*.safetensors"),
+            model_dtype=torch.bfloat16,
+            encoder_dtype=torch.bfloat16,
+            vae_dtype=torch.float32,
+        )
+        cls.pipe = QwenImagePipeline.from_pretrained(config)
+
+    @classmethod
+    def tearDownClass(cls):
+        del cls.pipe
+
+    def test_txt2img(self):
+        image = self.pipe(
+            prompt="根据这图1中女性和图2中的男性,生成一组结婚照,并遵循以下描述:新郎穿着红色的中式马褂,新娘穿着精致的秀禾服,头戴金色凤冠。他们并肩站立在古老的朱红色宫墙前,背景是雕花的木窗。光线明亮柔和,构图对称,氛围喜庆而庄重。",
+            input_image=[self.get_input_image("qwen_1.png"), self.get_input_image("qwen_2.png")],
+            negative_prompt=" ",
+            cfg_scale=4.0,
+            width=1328,
+            height=1328,
+            num_inference_steps=40,
+            seed=42,
+        )
+        self.assertImageEqualAndSaveFailed(image, "qwen_image/qwen_image_edit_plus.png", threshold=0.95)
+
+
+if __name__ == "__main__":
+    unittest.main()
