Merge 65935d512f980695306c82b75a6b067bdb77dd85 into 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf

nolan4 2025-12-07 17:01:52 +01:00 committed by GitHub
commit b2ec09319a
3 changed files with 582 additions and 14 deletions

comfy/ldm/qwen_image/model.py

@@ -2,8 +2,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from typing import Optional, Tuple
from einops import repeat, rearrange
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
@@ -54,7 +55,6 @@ class FeedForward(nn.Module):
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
@@ -229,6 +229,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
@@ -245,6 +246,7 @@ class QwenImageTransformerBlock(nn.Module):
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
@@ -288,8 +290,12 @@ class LastLayer(nn.Module):
x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
return x

class QwenImageTransformer2DModel(nn.Module):
# Constants for EliGen processing
LATENT_TO_PIXEL_RATIO = 8 # Latents are 8x downsampled from pixel space
PATCH_TO_LATENT_RATIO = 2 # 2x2 patches in latent space
PATCH_TO_PIXEL_RATIO = 16 # Combined: 2x2 patches on 8x downsampled latents = 16x in pixel space
def __init__(
self,
patch_size: int = 2,
@@ -316,7 +322,6 @@ class QwenImageTransformer2DModel(nn.Module):
self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
@@ -365,6 +370,214 @@ class QwenImageTransformer2DModel(nn.Module):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
entity_prompt_emb_mask, entity_masks, height, width, image,
cond_or_uncond=None, batch_size=None):
"""
Process entity masks and build spatial attention mask for EliGen.
Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask
enforcing spatial restrictions, and handles CFG batching with separate masks.
Based on: https://github.com/modelscope/DiffSynth-Studio
"""
num_entities = len(entity_prompt_emb)
actual_batch_size = latents.shape[0]
has_positive = cond_or_uncond and 0 in cond_or_uncond
has_negative = cond_or_uncond and 1 in cond_or_uncond
is_cfg_batched = has_positive and has_negative
logging.debug(
f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, "
f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}"
)
# Concatenate entity + global prompts
all_prompt_emb = entity_prompt_emb + [prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
# Build RoPE embeddings
patch_h = height // self.PATCH_TO_PIXEL_RATIO
patch_w = width // self.PATCH_TO_PIXEL_RATIO
entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
if prompt_emb_mask is not None:
global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
else:
global_seq_len = int(prompt_emb.shape[1])
max_vid_index = max(patch_h // 2, patch_w // 2)
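# Text tokens get RoPE positions starting just past the image grid's half-extent, mirroring
# txt_start in the standard (non-entity) forward path.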
# Generate per-entity text RoPE (each entity starts from same offset)
entity_txt_embs = []
for entity_seq_len in entity_seq_lens:
entity_ids = torch.arange(
max_vid_index,
max_vid_index + entity_seq_len,
device=latents.device
).reshape(1, -1, 1).repeat(1, 1, 3)
entity_rope = self.pe_embedder(entity_ids) # Keep shape [1, 1, seq, dim, 2, 2]
entity_txt_embs.append(entity_rope)
# Generate global text RoPE
global_ids = torch.arange(
max_vid_index,
max_vid_index + global_seq_len,
device=latents.device
).reshape(1, -1, 1).repeat(1, 1, 3)
global_rope = self.pe_embedder(global_ids) # Keep shape [1, 1, seq, dim, 2, 2]
txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=2) # Concatenate on sequence dimension
h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)
img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device)
img_ids[:, :, 0] = 0
img_ids[:, :, 1] = h_coords.unsqueeze(1)
img_ids[:, :, 2] = w_coords.unsqueeze(0)
img_ids = img_ids.reshape(1, -1, 3)
img_rope = self.pe_embedder(img_ids) # Keep shape [1, 1, seq, dim, 2, 2]
logging.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
# Concatenate text and image RoPE embeddings on sequence dimension
# Shape will be [1, 1, total_seq, dim, 2, 2] where total_seq = txt_seq + img_seq
image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=2).to(dtype=latents.dtype)
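# Sequence layout from here on: [entity_1 tokens, ..., entity_N tokens, global prompt tokens, image tokens];
# the attention mask below is indexed with the same ordering.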
# Prepare spatial masks
repeat_dim = latents.shape[1]
max_masks = entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
padded_h = height // self.LATENT_TO_PIXEL_RATIO
padded_w = width // self.LATENT_TO_PIXEL_RATIO
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
pad_h = padded_h - entity_masks.shape[3]
pad_w = padded_w - entity_masks.shape[4]
logging.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})")
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
device=latents.device, dtype=latents.dtype)
entity_masks = entity_masks + [global_mask]
# Patchify masks
N = len(entity_masks)
batch_size = int(entity_masks[0].shape[0])
seq_lens = entity_seq_lens + [global_seq_len]
total_seq_len = int(sum(seq_lens) + image.shape[1])
logging.debug(f"[EliGen Model] total_seq={total_seq_len}")
patched_masks = []
for i in range(N):
patched_mask = rearrange(
entity_masks[i],
"B C (H P) (W Q) -> B (H W) (C P Q)",
H=height // self.PATCH_TO_PIXEL_RATIO,
W=width // self.PATCH_TO_PIXEL_RATIO,
P=self.PATCH_TO_LATENT_RATIO,
Q=self.PATCH_TO_LATENT_RATIO
)
patched_masks.append(patched_mask)
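# After the rearrange, entry j of patched_masks[i] corresponds to one 2x2 latent patch (image token j),
# so summing over the last axis tells whether entity i covers that token.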
# Build attention mask matrix
attention_mask = torch.ones(
(batch_size, total_seq_len, total_seq_len),
dtype=torch.bool
).to(device=entity_masks[0].device)
# Calculate positions
image_start = int(sum(seq_lens))
image_end = int(total_seq_len)
cumsum = [0]
single_image_seq = int(image_end - image_start)
for length in seq_lens:
cumsum.append(cumsum[-1] + length)
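# Example: two entities with 12 and 9 text tokens plus a 20-token global prompt give
# seq_lens = [12, 9, 20] and cumsum = [0, 12, 21, 41]; image tokens then occupy [41, 41 + image.shape[1]).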
# Spatial restriction (prompt <-> image)
for i in range(N):
prompt_start = cumsum[i]
prompt_end = cumsum[i+1]
# Create binary mask for which image patches this entity can attend to
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# Always repeat mask to match image sequence length
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)
# Bidirectional restriction:
# - Entity prompt can only attend to its masked image regions
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# - Image patches can only be updated by prompts that own them
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# Entity isolation
for i in range(N):
for j in range(N):
if i == j:
continue
start_i, end_i = cumsum[i], cumsum[i+1]
start_j, end_j = cumsum[j], cumsum[j+1]
attention_mask[:, start_i:end_i, start_j:end_j] = False
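# The global prompt (last segment) is isolated the same way: no prompt segment attends to another,
# each interacts only with itself and with the image.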
# Convert to additive bias and handle CFG batching
attention_mask = attention_mask.float()
num_valid_connections = (attention_mask == 1).sum().item()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
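# The mask is consumed as an additive bias on the attention logits: -inf entries are removed by the
# softmax, 0 entries leave attention unchanged.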
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype)
# Handle CFG batching: Create separate masks for positive and negative
if is_cfg_batched and actual_batch_size > 1:
# CFG batch: [positive, negative] - need different masks for each
# Positive gets entity constraints, negative gets standard attention (all zeros)
logging.debug(
"[EliGen Model] CFG batched detected - creating separate masks. "
"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask"
)
# Create standard attention mask (all zeros = no constraints)
standard_mask = torch.zeros_like(attention_mask)
# Stack masks according to cond_or_uncond order
mask_list = []
for cond_type in cond_or_uncond:
if cond_type == 0: # Positive - use entity mask
mask_list.append(attention_mask[0:1]) # Take first (and only) entity mask
else: # Negative - use standard mask
mask_list.append(standard_mask[0:1])
# Concatenate masks to match batch
attention_mask = torch.cat(mask_list, dim=0)
logging.debug(
f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. "
f"Final shape: {attention_mask.shape}"
)
# Add head dimension: [B, 1, seq, seq]
attention_mask = attention_mask.unsqueeze(1)
logging.debug(
f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}"
)
return all_prompt_emb, image_rotary_emb, attention_mask
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
@@ -416,15 +629,60 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)

# Initialize attention mask (None for standard generation)
eligen_attention_mask = None

# Extract EliGen entity data
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
entity_masks = kwargs.get("entity_masks", None)
# Detect batch composition for CFG handling
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
is_positive_cond = 0 in cond_or_uncond
is_negative_cond = 1 in cond_or_uncond
batch_size = x.shape[0]
if entity_prompt_emb is not None:
logging.debug(
f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, "
f"has_positive={is_positive_cond}, has_negative={is_negative_cond}"
)
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
# EliGen path
height = int(orig_shape[-2] * self.LATENT_TO_PIXEL_RATIO)
width = int(orig_shape[-1] * self.LATENT_TO_PIXEL_RATIO)
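# orig_shape holds the padded latent size from process_img; scaling by LATENT_TO_PIXEL_RATIO (8)
# converts it to the pixel-space size that the entity masks are built against.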
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
latents=x,
prompt_emb=encoder_hidden_states,
prompt_emb_mask=encoder_hidden_states_mask,
entity_prompt_emb=entity_prompt_emb,
entity_prompt_emb_mask=entity_prompt_emb_mask,
entity_masks=entity_masks,
height=height,
width=width,
image=hidden_states,
cond_or_uncond=cond_or_uncond,
batch_size=batch_size
)
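# txt_norm/txt_in were already applied to all prompts inside process_entity_masks, so only img_in runs here.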
hidden_states = self.img_in(hidden_states)
del img_ids
else:
# Standard path
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
@@ -446,9 +704,25 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
encoder_hidden_states_mask=args.get("encoder_hidden_states_mask"),
temb=args["vec"],
image_rotary_emb=args["pe"],
attention_mask=args.get("attention_mask"),
transformer_options=args["transformer_options"]
)
return out
out = blocks_replace[("double_block", i)]({
"img": hidden_states,
"txt": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"attention_mask": eligen_attention_mask,
"vec": temb,
"pe": image_rotary_emb,
"transformer_options": transformer_options
}, {"original_block": block_wrap})
hidden_states = out["img"] hidden_states = out["img"]
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]
else: else:
@@ -458,6 +732,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=eligen_attention_mask,
transformer_options=transformer_options,
)

comfy/model_base.py

@@ -1484,6 +1484,19 @@ class QwenImage(BaseModel):
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
# Handle EliGen entity data
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
if entity_prompt_emb is not None:
out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
if entity_prompt_emb_mask is not None:
out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)
entity_masks = kwargs.get("entity_masks", None)
if entity_masks is not None:
out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
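# The per-entity embeddings and masks arrive as Python lists (one tensor per entity), hence CONDList;
# the stacked spatial masks are a single tensor, hence CONDRegular.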
return out
def extra_conds_shapes(self, **kwargs):

comfy_extras/nodes_qwen.py

@@ -1,6 +1,10 @@
import node_helpers
import comfy.utils
import comfy.conds
import math
import torch
import logging
from typing import Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
@@ -103,6 +107,281 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
return io.NodeOutput(conditioning)
class TextEncodeQwenImageEliGen(io.ComfyNode):
"""
Entity-Level Image Generation (EliGen) conditioning node for Qwen Image model.
Allows specifying different prompts for different spatial regions using masks.
Each entity (mask + prompt pair) will only influence its masked region through
spatial attention masking.
Features:
- Supports up to 8 entities per generation
- Spatial attention masks prevent cross-entity contamination
- Separate RoPE embeddings per entity (research-accurate)
- Falls back to standard generation if no entities provided
Usage:
1. Create spatial masks using LoadImageMask (white=entity, black=background)
2. Use 'red', 'green', or 'blue' channel (NOT 'alpha' - it gets inverted)
3. Provide entity-specific prompts for each masked region
Based on DiffSynth Studio: https://github.com/modelscope/DiffSynth-Studio
"""
# Qwen Image model uses 2x2 patches on latents (which are 8x downsampled from pixels)
PATCH_SIZE = 2
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEliGen",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("global_conditioning"),
io.Latent.Input("latent"),
io.Mask.Input("entity_mask_1", optional=True),
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_2", optional=True),
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_3", optional=True),
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_4", optional=True),
io.String.Input("entity_prompt_4", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_5", optional=True),
io.String.Input("entity_prompt_5", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_6", optional=True),
io.String.Input("entity_prompt_6", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_7", optional=True),
io.String.Input("entity_prompt_7", multiline=True, dynamic_prompts=True, default=""),
io.Mask.Input("entity_mask_8", optional=True),
io.String.Input("entity_prompt_8", multiline=True, dynamic_prompts=True, default=""),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(
cls,
clip,
global_conditioning,
latent,
entity_prompt_1: str = "",
entity_mask_1: Optional[torch.Tensor] = None,
entity_prompt_2: str = "",
entity_mask_2: Optional[torch.Tensor] = None,
entity_prompt_3: str = "",
entity_mask_3: Optional[torch.Tensor] = None,
entity_prompt_4: str = "",
entity_mask_4: Optional[torch.Tensor] = None,
entity_prompt_5: str = "",
entity_mask_5: Optional[torch.Tensor] = None,
entity_prompt_6: str = "",
entity_mask_6: Optional[torch.Tensor] = None,
entity_prompt_7: str = "",
entity_mask_7: Optional[torch.Tensor] = None,
entity_prompt_8: str = "",
entity_mask_8: Optional[torch.Tensor] = None
) -> io.NodeOutput:
# Extract dimensions from latent tensor
# latent["samples"] shape: [batch, channels, latent_h, latent_w]
latent_samples = latent["samples"]
unpadded_latent_height = latent_samples.shape[2] # Unpadded latent space
unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space
# Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2)
# The model pads latents to be multiples of PATCH_SIZE
pad_h = (cls.PATCH_SIZE - unpadded_latent_height % cls.PATCH_SIZE) % cls.PATCH_SIZE
pad_w = (cls.PATCH_SIZE - unpadded_latent_width % cls.PATCH_SIZE) % cls.PATCH_SIZE
latent_height = unpadded_latent_height + pad_h # Padded latent dimensions
latent_width = unpadded_latent_width + pad_w # Padded latent dimensions
height = latent_height * 8 # Convert to pixel space for logging
width = latent_width * 8
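# Example: a 1024x1024 image gives a 128x128 latent (already a multiple of 2, no padding);
# a 1000x1000 image gives a 125x125 latent, padded to 126x126, i.e. 1008x1008 in pixel space.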
if pad_h > 0 or pad_w > 0:
logging.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width}{latent_height}x{latent_width}")
logging.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
# Collect entity prompts and masks
entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3, entity_prompt_4, entity_prompt_5, entity_prompt_6, entity_prompt_7, entity_prompt_8]
entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3, entity_mask_4, entity_mask_5, entity_mask_6, entity_mask_7, entity_mask_8]
# Filter out entities with empty prompts or missing masks
valid_entities = []
for prompt, mask in zip(entity_prompts, entity_masks_raw):
if prompt.strip() and mask is not None:
valid_entities.append((prompt, mask))
# Log warning if some entities were skipped
total_prompts_provided = len([p for p in entity_prompts if p.strip()])
if len(valid_entities) < total_prompts_provided:
logging.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
# If no valid entities, return standard conditioning
if len(valid_entities) == 0:
return io.NodeOutput(global_conditioning)
# Encode each entity prompt separately
entity_prompt_emb_list = []
entity_prompt_emb_mask_list = []
for entity_prompt, _ in valid_entities: # mask not used at this point
entity_tokens = clip.tokenize(entity_prompt)
entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)
entity_prompt_emb = entity_cond_dict["cond"]
entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)
# If the encoder did not return an attention mask, create one (all True)
if entity_prompt_emb_mask is None:
seq_len = entity_prompt_emb.shape[1]
entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
dtype=torch.bool, device=entity_prompt_emb.device)
entity_prompt_emb_list.append(entity_prompt_emb)
entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
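# Keep per-entity embeddings and masks as lists; their sequence lengths differ,
# so they cannot be stacked into one tensor here.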
# Process spatial masks to latent space
processed_entity_masks = []
for i, (_, mask) in enumerate(valid_entities):
# MASK type format: [batch, height, width] (no channel dimension)
# This is different from IMAGE type which is [batch, height, width, channels]
mask_tensor = mask
# Validate mask dtype
if mask_tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
raise TypeError(
f"Entity {i+1} mask has invalid dtype {mask_tensor.dtype}. "
f"Expected float32, float16, or bfloat16. "
f"Ensure you're using LoadImageMask node, not LoadImage."
)
# Log original mask statistics
logging.debug(
f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, "
f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}"
)
# Check for all-zero masks (common error when wrong channel selected)
if mask_tensor.max() == 0.0:
raise ValueError(
f"Entity {i+1} mask is all zeros! This usually means:\n"
f" 1. Wrong channel selected in LoadImageMask (use 'red', 'green', or 'blue', NOT 'alpha')\n"
f" 2. Your mask image is completely black\n"
f" 3. The mask file failed to load"
)
# Check for constant masks (no variation)
if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0:
logging.warning(
f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). "
f"This entity will affect the entire image."
)
# Extract original dimensions
original_shape = mask_tensor.shape
if len(original_shape) == 2:
# [height, width] - single mask without batch
orig_h, orig_w = original_shape[0], original_shape[1]
# Add batch dimension: [1, height, width]
mask_tensor = mask_tensor.unsqueeze(0)
elif len(original_shape) == 3:
# [batch, height, width] - standard MASK format
orig_h, orig_w = original_shape[1], original_shape[2]
else:
raise ValueError(
f"Entity {i+1} has unexpected mask shape: {original_shape}. "
f"Expected [H, W] or [B, H, W]. Got {len(original_shape)} dimensions."
)
# Log size mismatch if mask doesn't match expected latent dimensions
expected_h, expected_w = latent_height * 8, latent_width * 8
if orig_h != expected_h or orig_w != expected_w:
logging.info(
f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. "
f"Will resize to {latent_height}x{latent_width} latent space."
)
else:
logging.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
# Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
# common_upscale expects [batch, channels, height, width]
mask_tensor = mask_tensor.unsqueeze(1) # Add channel dimension: [batch, 1, height, width]
# Resize to latent space dimensions using nearest neighbor
resized_mask = comfy.utils.common_upscale(
mask_tensor,
latent_width,
latent_height,
upscale_method="nearest-exact",
crop="disabled"
)
# Threshold to binary (0 or 1)
# Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling
resized_mask = (resized_mask > 0).float()
# Log how many pixels are active in the mask
active_pixels = (resized_mask > 0).sum().item()
total_pixels = resized_mask.numel()
coverage_pct = 100 * active_pixels / total_pixels if total_pixels > 0 else 0
if active_pixels == 0:
raise ValueError(
f"Entity {i+1} mask has no active pixels after resizing to latent space! "
f"Original mask may have been too small or all black."
)
logging.debug(
f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)"
)
processed_entity_masks.append(resized_mask)
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
# Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
# We need to remove batch dim, stack, then add it back
processed_entity_masks_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W]
entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0) # [num_entities, 1, H, W]
entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W]
logging.debug(
f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: "
f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])"
)
# Extract global prompt embedding and mask from conditioning
# Conditioning format: [[cond_tensor, extra_dict]]
global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly
global_extra_dict = global_conditioning[0][1] # Metadata dict
global_prompt_emb_mask = global_extra_dict.get("attention_mask", None)
# If no attention mask, create one (all True)
if global_prompt_emb_mask is None:
global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]),
dtype=torch.bool, device=global_prompt_emb.device)
# Attach entity data to conditioning using conditioning_set_values
entity_data = {
"entity_prompt_emb": entity_prompt_emb_list,
"entity_prompt_emb_mask": entity_prompt_emb_mask_list,
"entity_masks": entity_masks_tensor,
}
conditioning_with_entities = node_helpers.conditioning_set_values(
global_conditioning,
entity_data,
append=True
)
return io.NodeOutput(conditioning_with_entities)
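# Typical wiring (illustrative): encode the full-scene prompt with CLIPTextEncode, pass that conditioning,
# a latent of the target size, and per-entity LoadImageMask outputs into this node, then use the output as
# the positive conditioning for sampling. The negative conditioning needs no change: process_entity_masks
# gives the negative half of a CFG batch an all-zero (unconstrained) mask.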
class QwenExtension(ComfyExtension):
@override
@@ -110,6 +389,7 @@ class QwenExtension(ComfyExtension):
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
TextEncodeQwenImageEliGen,
]