mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
Merge 65935d512f980695306c82b75a6b067bdb77dd85 into 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf
This commit is contained in:
commit
b2ec09319a
@ -2,8 +2,9 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from einops import repeat
|
||||
from einops import repeat, rearrange
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
@ -54,7 +55,6 @@ class FeedForward(nn.Module):
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
@ -229,6 +229,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
@ -245,6 +246,7 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
@ -288,8 +290,12 @@ class LastLayer(nn.Module):
|
||||
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
|
||||
return x
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(nn.Module):
|
||||
# Constants for EliGen processing
|
||||
LATENT_TO_PIXEL_RATIO = 8 # Latents are 8x downsampled from pixel space
|
||||
PATCH_TO_LATENT_RATIO = 2 # 2x2 patches in latent space
|
||||
PATCH_TO_PIXEL_RATIO = 16 # Combined: 2x2 patches on 8x downsampled latents = 16x in pixel space
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
@ -316,7 +322,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
@ -365,6 +370,214 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||
|
||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
|
||||
entity_prompt_emb_mask, entity_masks, height, width, image,
|
||||
cond_or_uncond=None, batch_size=None):
|
||||
"""
|
||||
Process entity masks and build spatial attention mask for EliGen.
|
||||
|
||||
Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask
|
||||
enforcing spatial restrictions, and handles CFG batching with separate masks.
|
||||
|
||||
Based on: https://github.com/modelscope/DiffSynth-Studio
|
||||
"""
|
||||
num_entities = len(entity_prompt_emb)
|
||||
actual_batch_size = latents.shape[0]
|
||||
|
||||
has_positive = cond_or_uncond and 0 in cond_or_uncond
|
||||
has_negative = cond_or_uncond and 1 in cond_or_uncond
|
||||
is_cfg_batched = has_positive and has_negative
|
||||
|
||||
logging.debug(
|
||||
f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, "
|
||||
f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}"
|
||||
)
|
||||
|
||||
# Concatenate entity + global prompts
|
||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||
|
||||
# Build RoPE embeddings
|
||||
patch_h = height // self.PATCH_TO_PIXEL_RATIO
|
||||
patch_w = width // self.PATCH_TO_PIXEL_RATIO
|
||||
|
||||
entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
|
||||
|
||||
if prompt_emb_mask is not None:
|
||||
global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
|
||||
else:
|
||||
global_seq_len = int(prompt_emb.shape[1])
|
||||
|
||||
max_vid_index = max(patch_h // 2, patch_w // 2)
|
||||
|
||||
# Generate per-entity text RoPE (each entity starts from same offset)
|
||||
entity_txt_embs = []
|
||||
for entity_seq_len in entity_seq_lens:
|
||||
entity_ids = torch.arange(
|
||||
max_vid_index,
|
||||
max_vid_index + entity_seq_len,
|
||||
device=latents.device
|
||||
).reshape(1, -1, 1).repeat(1, 1, 3)
|
||||
|
||||
entity_rope = self.pe_embedder(entity_ids) # Keep shape [1, 1, seq, dim, 2, 2]
|
||||
entity_txt_embs.append(entity_rope)
|
||||
|
||||
# Generate global text RoPE
|
||||
global_ids = torch.arange(
|
||||
max_vid_index,
|
||||
max_vid_index + global_seq_len,
|
||||
device=latents.device
|
||||
).reshape(1, -1, 1).repeat(1, 1, 3)
|
||||
global_rope = self.pe_embedder(global_ids) # Keep shape [1, 1, seq, dim, 2, 2]
|
||||
|
||||
txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=2) # Concatenate on sequence dimension
|
||||
|
||||
h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
|
||||
w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)
|
||||
|
||||
img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device)
|
||||
img_ids[:, :, 0] = 0
|
||||
img_ids[:, :, 1] = h_coords.unsqueeze(1)
|
||||
img_ids[:, :, 2] = w_coords.unsqueeze(0)
|
||||
img_ids = img_ids.reshape(1, -1, 3)
|
||||
|
||||
img_rope = self.pe_embedder(img_ids) # Keep shape [1, 1, seq, dim, 2, 2]
|
||||
|
||||
logging.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
|
||||
|
||||
# Concatenate text and image RoPE embeddings on sequence dimension
|
||||
# Shape will be [1, 1, total_seq, dim, 2, 2] where total_seq = txt_seq + img_seq
|
||||
image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=2).to(dtype=latents.dtype)
|
||||
|
||||
# Prepare spatial masks
|
||||
repeat_dim = latents.shape[1]
|
||||
max_masks = entity_masks.shape[1]
|
||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||
|
||||
padded_h = height // self.LATENT_TO_PIXEL_RATIO
|
||||
padded_w = width // self.LATENT_TO_PIXEL_RATIO
|
||||
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
|
||||
pad_h = padded_h - entity_masks.shape[3]
|
||||
pad_w = padded_w - entity_masks.shape[4]
|
||||
logging.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})")
|
||||
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
|
||||
|
||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||
|
||||
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
|
||||
device=latents.device, dtype=latents.dtype)
|
||||
entity_masks = entity_masks + [global_mask]
|
||||
|
||||
# Patchify masks
|
||||
N = len(entity_masks)
|
||||
batch_size = int(entity_masks[0].shape[0])
|
||||
seq_lens = entity_seq_lens + [global_seq_len]
|
||||
total_seq_len = int(sum(seq_lens) + image.shape[1])
|
||||
|
||||
logging.debug(f"[EliGen Model] total_seq={total_seq_len}")
|
||||
|
||||
patched_masks = []
|
||||
for i in range(N):
|
||||
patched_mask = rearrange(
|
||||
entity_masks[i],
|
||||
"B C (H P) (W Q) -> B (H W) (C P Q)",
|
||||
H=height // self.PATCH_TO_PIXEL_RATIO,
|
||||
W=width // self.PATCH_TO_PIXEL_RATIO,
|
||||
P=self.PATCH_TO_LATENT_RATIO,
|
||||
Q=self.PATCH_TO_LATENT_RATIO
|
||||
)
|
||||
patched_masks.append(patched_mask)
|
||||
|
||||
# Build attention mask matrix
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, total_seq_len, total_seq_len),
|
||||
dtype=torch.bool
|
||||
).to(device=entity_masks[0].device)
|
||||
|
||||
# Calculate positions
|
||||
image_start = int(sum(seq_lens))
|
||||
image_end = int(total_seq_len)
|
||||
cumsum = [0]
|
||||
single_image_seq = int(image_end - image_start)
|
||||
|
||||
for length in seq_lens:
|
||||
cumsum.append(cumsum[-1] + length)
|
||||
|
||||
# Spatial restriction (prompt <-> image)
|
||||
for i in range(N):
|
||||
prompt_start = cumsum[i]
|
||||
prompt_end = cumsum[i+1]
|
||||
|
||||
# Create binary mask for which image patches this entity can attend to
|
||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||
|
||||
# Always repeat mask to match image sequence length
|
||||
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||
|
||||
# Bidirectional restriction:
|
||||
# - Entity prompt can only attend to its masked image regions
|
||||
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||
# - Image patches can only be updated by prompts that own them
|
||||
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||
|
||||
# Entity isolation
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if i == j:
|
||||
continue
|
||||
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||
|
||||
# Convert to additive bias and handle CFG batching
|
||||
attention_mask = attention_mask.float()
|
||||
num_valid_connections = (attention_mask == 1).sum().item()
|
||||
attention_mask[attention_mask == 0] = float('-inf')
|
||||
attention_mask[attention_mask == 1] = 0
|
||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# Handle CFG batching: Create separate masks for positive and negative
|
||||
if is_cfg_batched and actual_batch_size > 1:
|
||||
# CFG batch: [positive, negative] - need different masks for each
|
||||
# Positive gets entity constraints, negative gets standard attention (all zeros)
|
||||
|
||||
logging.debug(
|
||||
"[EliGen Model] CFG batched detected - creating separate masks. "
|
||||
"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask"
|
||||
)
|
||||
|
||||
# Create standard attention mask (all zeros = no constraints)
|
||||
standard_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# Stack masks according to cond_or_uncond order
|
||||
mask_list = []
|
||||
for cond_type in cond_or_uncond:
|
||||
if cond_type == 0: # Positive - use entity mask
|
||||
mask_list.append(attention_mask[0:1]) # Take first (and only) entity mask
|
||||
else: # Negative - use standard mask
|
||||
mask_list.append(standard_mask[0:1])
|
||||
|
||||
# Concatenate masks to match batch
|
||||
attention_mask = torch.cat(mask_list, dim=0)
|
||||
|
||||
logging.debug(
|
||||
f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. "
|
||||
f"Final shape: {attention_mask.shape}"
|
||||
)
|
||||
|
||||
# Add head dimension: [B, 1, seq, seq]
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
logging.debug(
|
||||
f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
|
||||
f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}"
|
||||
)
|
||||
|
||||
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||
|
||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
@ -416,15 +629,60 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
# Initialize attention mask (None for standard generation)
|
||||
eligen_attention_mask = None
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
# Extract EliGen entity data
|
||||
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
||||
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
||||
entity_masks = kwargs.get("entity_masks", None)
|
||||
|
||||
# Detect batch composition for CFG handling
|
||||
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
|
||||
is_positive_cond = 0 in cond_or_uncond
|
||||
is_negative_cond = 1 in cond_or_uncond
|
||||
batch_size = x.shape[0]
|
||||
|
||||
if entity_prompt_emb is not None:
|
||||
logging.debug(
|
||||
f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, "
|
||||
f"has_positive={is_positive_cond}, has_negative={is_negative_cond}"
|
||||
)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
|
||||
# EliGen path
|
||||
height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO)
|
||||
width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO)
|
||||
|
||||
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
|
||||
latents=x,
|
||||
prompt_emb=encoder_hidden_states,
|
||||
prompt_emb_mask=encoder_hidden_states_mask,
|
||||
entity_prompt_emb=entity_prompt_emb,
|
||||
entity_prompt_emb_mask=entity_prompt_emb_mask,
|
||||
entity_masks=entity_masks,
|
||||
height=height,
|
||||
width=width,
|
||||
image=hidden_states,
|
||||
cond_or_uncond=cond_or_uncond,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
del img_ids
|
||||
|
||||
else:
|
||||
# Standard path
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
@ -446,9 +704,25 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"),
|
||||
temb=args["vec"],
|
||||
image_rotary_emb=args["pe"],
|
||||
attention_mask=args.get("attention_mask"),
|
||||
transformer_options=args["transformer_options"]
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"attention_mask": eligen_attention_mask,
|
||||
"vec": temb,
|
||||
"pe": image_rotary_emb,
|
||||
"transformer_options": transformer_options
|
||||
}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
else:
|
||||
@ -458,6 +732,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=eligen_attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
|
||||
@ -1484,6 +1484,19 @@ class QwenImage(BaseModel):
|
||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||
if ref_latents_method is not None:
|
||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||
|
||||
# Handle EliGen entity data
|
||||
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
||||
if entity_prompt_emb is not None:
|
||||
out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)
|
||||
|
||||
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
||||
if entity_prompt_emb_mask is not None:
|
||||
out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)
|
||||
|
||||
entity_masks = kwargs.get("entity_masks", None)
|
||||
if entity_masks is not None:
|
||||
out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
import math
|
||||
import torch
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
@ -103,6 +107,281 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
class TextEncodeQwenImageEliGen(io.ComfyNode):
|
||||
"""
|
||||
Entity-Level Image Generation (EliGen) conditioning node for Qwen Image model.
|
||||
|
||||
Allows specifying different prompts for different spatial regions using masks.
|
||||
Each entity (mask + prompt pair) will only influence its masked region through
|
||||
spatial attention masking.
|
||||
|
||||
Features:
|
||||
- Supports up to 8 entities per generation
|
||||
- Spatial attention masks prevent cross-entity contamination
|
||||
- Separate RoPE embeddings per entity (research-accurate)
|
||||
- Falls back to standard generation if no entities provided
|
||||
|
||||
Usage:
|
||||
1. Create spatial masks using LoadImageMask (white=entity, black=background)
|
||||
2. Use 'red', 'green', or 'blue' channel (NOT 'alpha' - it gets inverted)
|
||||
3. Provide entity-specific prompts for each masked region
|
||||
|
||||
Based on DiffSynth Studio: https://github.com/modelscope/DiffSynth-Studio
|
||||
"""
|
||||
|
||||
# Qwen Image model uses 2x2 patches on latents (which are 8x downsampled from pixels)
|
||||
PATCH_SIZE = 2
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeQwenImageEliGen",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("global_conditioning"),
|
||||
io.Latent.Input("latent"),
|
||||
io.Mask.Input("entity_mask_1", optional=True),
|
||||
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_2", optional=True),
|
||||
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_3", optional=True),
|
||||
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_4", optional=True),
|
||||
io.String.Input("entity_prompt_4", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_5", optional=True),
|
||||
io.String.Input("entity_prompt_5", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_6", optional=True),
|
||||
io.String.Input("entity_prompt_6", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_7", optional=True),
|
||||
io.String.Input("entity_prompt_7", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Mask.Input("entity_mask_8", optional=True),
|
||||
io.String.Input("entity_prompt_8", multiline=True, dynamic_prompts=True, default=""),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
clip,
|
||||
global_conditioning,
|
||||
latent,
|
||||
entity_prompt_1: str = "",
|
||||
entity_mask_1: Optional[torch.Tensor] = None,
|
||||
entity_prompt_2: str = "",
|
||||
entity_mask_2: Optional[torch.Tensor] = None,
|
||||
entity_prompt_3: str = "",
|
||||
entity_mask_3: Optional[torch.Tensor] = None,
|
||||
entity_prompt_4: str = "",
|
||||
entity_mask_4: Optional[torch.Tensor] = None,
|
||||
entity_prompt_5: str = "",
|
||||
entity_mask_5: Optional[torch.Tensor] = None,
|
||||
entity_prompt_6: str = "",
|
||||
entity_mask_6: Optional[torch.Tensor] = None,
|
||||
entity_prompt_7: str = "",
|
||||
entity_mask_7: Optional[torch.Tensor] = None,
|
||||
entity_prompt_8: str = "",
|
||||
entity_mask_8: Optional[torch.Tensor] = None
|
||||
) -> io.NodeOutput:
|
||||
|
||||
# Extract dimensions from latent tensor
|
||||
# latent["samples"] shape: [batch, channels, latent_h, latent_w]
|
||||
latent_samples = latent["samples"]
|
||||
unpadded_latent_height = latent_samples.shape[2] # Unpadded latent space
|
||||
unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space
|
||||
|
||||
# Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2)
|
||||
# The model pads latents to be multiples of PATCH_SIZE
|
||||
pad_h = (cls.PATCH_SIZE - unpadded_latent_height % cls.PATCH_SIZE) % cls.PATCH_SIZE
|
||||
pad_w = (cls.PATCH_SIZE - unpadded_latent_width % cls.PATCH_SIZE) % cls.PATCH_SIZE
|
||||
latent_height = unpadded_latent_height + pad_h # Padded latent dimensions
|
||||
latent_width = unpadded_latent_width + pad_w # Padded latent dimensions
|
||||
|
||||
height = latent_height * 8 # Convert to pixel space for logging
|
||||
width = latent_width * 8
|
||||
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
logging.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}")
|
||||
logging.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
|
||||
|
||||
# Collect entity prompts and masks
|
||||
entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3, entity_prompt_4, entity_prompt_5, entity_prompt_6, entity_prompt_7, entity_prompt_8]
|
||||
entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3, entity_mask_4, entity_mask_5, entity_mask_6, entity_mask_7, entity_mask_8]
|
||||
|
||||
# Filter out entities with empty prompts or missing masks
|
||||
valid_entities = []
|
||||
for prompt, mask in zip(entity_prompts, entity_masks_raw):
|
||||
if prompt.strip() and mask is not None:
|
||||
valid_entities.append((prompt, mask))
|
||||
|
||||
# Log warning if some entities were skipped
|
||||
total_prompts_provided = len([p for p in entity_prompts if p.strip()])
|
||||
if len(valid_entities) < total_prompts_provided:
|
||||
logging.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
|
||||
|
||||
# If no valid entities, return standard conditioning
|
||||
if len(valid_entities) == 0:
|
||||
return io.NodeOutput(global_conditioning)
|
||||
|
||||
# Encode each entity prompt separately
|
||||
entity_prompt_emb_list = []
|
||||
entity_prompt_emb_mask_list = []
|
||||
|
||||
for entity_prompt, _ in valid_entities: # mask not used at this point
|
||||
entity_tokens = clip.tokenize(entity_prompt)
|
||||
entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)
|
||||
entity_prompt_emb = entity_cond_dict["cond"]
|
||||
entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)
|
||||
|
||||
# If no attention mask in extra_dict, create one (all True)
|
||||
if entity_prompt_emb_mask is None:
|
||||
seq_len = entity_prompt_emb.shape[1]
|
||||
entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
|
||||
dtype=torch.bool, device=entity_prompt_emb.device)
|
||||
|
||||
|
||||
entity_prompt_emb_list.append(entity_prompt_emb)
|
||||
entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
|
||||
|
||||
# Process spatial masks to latent space
|
||||
processed_entity_masks = []
|
||||
for i, (_, mask) in enumerate(valid_entities):
|
||||
# MASK type format: [batch, height, width] (no channel dimension)
|
||||
# This is different from IMAGE type which is [batch, height, width, channels]
|
||||
mask_tensor = mask
|
||||
|
||||
# Validate mask dtype
|
||||
if mask_tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
raise TypeError(
|
||||
f"Entity {i+1} mask has invalid dtype {mask_tensor.dtype}. "
|
||||
f"Expected float32, float16, or bfloat16. "
|
||||
f"Ensure you're using LoadImageMask node, not LoadImage."
|
||||
)
|
||||
|
||||
# Log original mask statistics
|
||||
logging.debug(
|
||||
f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, "
|
||||
f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}"
|
||||
)
|
||||
|
||||
# Check for all-zero masks (common error when wrong channel selected)
|
||||
if mask_tensor.max() == 0.0:
|
||||
raise ValueError(
|
||||
f"Entity {i+1} mask is all zeros! This usually means:\n"
|
||||
f" 1. Wrong channel selected in LoadImageMask (use 'red', 'green', or 'blue', NOT 'alpha')\n"
|
||||
f" 2. Your mask image is completely black\n"
|
||||
f" 3. The mask file failed to load"
|
||||
)
|
||||
|
||||
# Check for constant masks (no variation)
|
||||
if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0:
|
||||
logging.warning(
|
||||
f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). "
|
||||
f"This entity will affect the entire image."
|
||||
)
|
||||
|
||||
# Extract original dimensions
|
||||
original_shape = mask_tensor.shape
|
||||
if len(original_shape) == 2:
|
||||
# [height, width] - single mask without batch
|
||||
orig_h, orig_w = original_shape[0], original_shape[1]
|
||||
# Add batch dimension: [1, height, width]
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
elif len(original_shape) == 3:
|
||||
# [batch, height, width] - standard MASK format
|
||||
orig_h, orig_w = original_shape[1], original_shape[2]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Entity {i+1} has unexpected mask shape: {original_shape}. "
|
||||
f"Expected [H, W] or [B, H, W]. Got {len(original_shape)} dimensions."
|
||||
)
|
||||
|
||||
# Log size mismatch if mask doesn't match expected latent dimensions
|
||||
expected_h, expected_w = latent_height * 8, latent_width * 8
|
||||
if orig_h != expected_h or orig_w != expected_w:
|
||||
logging.info(
|
||||
f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. "
|
||||
f"Will resize to {latent_height}x{latent_width} latent space."
|
||||
)
|
||||
else:
|
||||
logging.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
|
||||
|
||||
# Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
|
||||
# common_upscale expects [batch, channels, height, width]
|
||||
mask_tensor = mask_tensor.unsqueeze(1) # Add channel dimension: [batch, 1, height, width]
|
||||
|
||||
# Resize to latent space dimensions using nearest neighbor
|
||||
resized_mask = comfy.utils.common_upscale(
|
||||
mask_tensor,
|
||||
latent_width,
|
||||
latent_height,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled"
|
||||
)
|
||||
|
||||
# Threshold to binary (0 or 1)
|
||||
# Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling
|
||||
resized_mask = (resized_mask > 0).float()
|
||||
|
||||
# Log how many pixels are active in the mask
|
||||
active_pixels = (resized_mask > 0).sum().item()
|
||||
total_pixels = resized_mask.numel()
|
||||
coverage_pct = 100 * active_pixels / total_pixels if total_pixels > 0 else 0
|
||||
|
||||
if active_pixels == 0:
|
||||
raise ValueError(
|
||||
f"Entity {i+1} mask has no active pixels after resizing to latent space! "
|
||||
f"Original mask may have been too small or all black."
|
||||
)
|
||||
|
||||
logging.debug(
|
||||
f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)"
|
||||
)
|
||||
|
||||
processed_entity_masks.append(resized_mask)
|
||||
|
||||
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
|
||||
# Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
|
||||
# We need to remove batch dim, stack, then add it back
|
||||
processed_entity_masks_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W]
|
||||
entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0) # [num_entities, 1, H, W]
|
||||
entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W]
|
||||
|
||||
logging.debug(
|
||||
f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: "
|
||||
f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])"
|
||||
)
|
||||
|
||||
# Extract global prompt embedding and mask from conditioning
|
||||
# Conditioning format: [[cond_tensor, extra_dict]]
|
||||
global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly
|
||||
global_extra_dict = global_conditioning[0][1] # Metadata dict
|
||||
|
||||
global_prompt_emb_mask = global_extra_dict.get("attention_mask", None)
|
||||
|
||||
# If no attention mask, create one (all True)
|
||||
if global_prompt_emb_mask is None:
|
||||
global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]),
|
||||
dtype=torch.bool, device=global_prompt_emb.device)
|
||||
|
||||
# Attach entity data to conditioning using conditioning_set_values
|
||||
entity_data = {
|
||||
"entity_prompt_emb": entity_prompt_emb_list,
|
||||
"entity_prompt_emb_mask": entity_prompt_emb_mask_list,
|
||||
"entity_masks": entity_masks_tensor,
|
||||
}
|
||||
|
||||
conditioning_with_entities = node_helpers.conditioning_set_values(
|
||||
global_conditioning,
|
||||
entity_data,
|
||||
append=True
|
||||
)
|
||||
|
||||
return io.NodeOutput(conditioning_with_entities)
|
||||
|
||||
|
||||
class QwenExtension(ComfyExtension):
|
||||
@override
|
||||
@ -110,6 +389,7 @@ class QwenExtension(ComfyExtension):
|
||||
return [
|
||||
TextEncodeQwenImageEdit,
|
||||
TextEncodeQwenImageEditPlus,
|
||||
TextEncodeQwenImageEliGen,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user