mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-17 01:54:36 +08:00
Support Z Image alibaba pai fun controlnets. (#11062)
These are not actual controlnets so put it in the models/model_patches folder and use the ModelPatchLoader + QwenImageDiffsynthControlnet node to use it.
This commit is contained in:
parent
277237ccc1
commit
b94d394a64
113
comfy/ldm/lumina/controlnet.py
Normal file
113
comfy/ldm/lumina/controlnet.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .model import JointTransformerBlock
|
||||||
|
|
||||||
|
class ZImageControlTransformerBlock(JointTransformerBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: float,
|
||||||
|
norm_eps: float,
|
||||||
|
qk_norm: bool,
|
||||||
|
modulation=True,
|
||||||
|
block_id=0,
|
||||||
|
operation_settings=None,
|
||||||
|
):
|
||||||
|
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
|
||||||
|
self.block_id = block_id
|
||||||
|
if block_id == 0:
|
||||||
|
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, c, x, **kwargs):
|
||||||
|
if self.block_id == 0:
|
||||||
|
c = self.before_proj(c) + x
|
||||||
|
c = super().forward(c, **kwargs)
|
||||||
|
c_skip = self.after_proj(c)
|
||||||
|
return c_skip, c
|
||||||
|
|
||||||
|
class ZImage_Control(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int = 3840,
|
||||||
|
n_heads: int = 30,
|
||||||
|
n_kv_heads: int = 30,
|
||||||
|
multiple_of: int = 256,
|
||||||
|
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
self.additional_in_dim = 0
|
||||||
|
self.control_in_dim = 16
|
||||||
|
n_refiner_layers = 2
|
||||||
|
self.n_control_layers = 6
|
||||||
|
self.control_layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ZImageControlTransformerBlock(
|
||||||
|
i,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
multiple_of,
|
||||||
|
ffn_dim_multiplier,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
block_id=i,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
for i in range(self.n_control_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
all_x_embedder = {}
|
||||||
|
patch_size = 2
|
||||||
|
f_patch_size = 1
|
||||||
|
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
|
||||||
|
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||||
|
|
||||||
|
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||||
|
self.control_noise_refiner = nn.ModuleList(
|
||||||
|
[
|
||||||
|
JointTransformerBlock(
|
||||||
|
layer_id,
|
||||||
|
dim,
|
||||||
|
n_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
multiple_of,
|
||||||
|
ffn_dim_multiplier,
|
||||||
|
norm_eps,
|
||||||
|
qk_norm,
|
||||||
|
modulation=True,
|
||||||
|
z_image_modulation=True,
|
||||||
|
operation_settings=operation_settings,
|
||||||
|
)
|
||||||
|
for layer_id in range(n_refiner_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||||
|
patch_size = 2
|
||||||
|
f_patch_size = 1
|
||||||
|
pH = pW = patch_size
|
||||||
|
B, C, H, W = control_context.shape
|
||||||
|
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||||
|
|
||||||
|
x_attn_mask = None
|
||||||
|
for layer in self.control_noise_refiner:
|
||||||
|
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||||
|
return control_context
|
||||||
|
|
||||||
|
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||||
|
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||||
@ -568,7 +568,7 @@ class NextDiT(nn.Module):
|
|||||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||||
|
|
||||||
# def forward(self, x, t, cap_feats, cap_mask):
|
# def forward(self, x, t, cap_feats, cap_mask):
|
||||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||||
t = 1.0 - timesteps
|
t = 1.0 - timesteps
|
||||||
cap_feats = context
|
cap_feats = context
|
||||||
cap_mask = attention_mask
|
cap_mask = attention_mask
|
||||||
@ -585,16 +585,24 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||||
|
|
||||||
|
patches = transformer_options.get("patches", {})
|
||||||
transformer_options = kwargs.get("transformer_options", {})
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
x_is_tensor = isinstance(x, torch.Tensor)
|
x_is_tensor = isinstance(x, torch.Tensor)
|
||||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||||
freqs_cis = freqs_cis.to(x.device)
|
freqs_cis = freqs_cis.to(img.device)
|
||||||
|
|
||||||
for layer in self.layers:
|
for i, layer in enumerate(self.layers):
|
||||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||||
|
if "double_block" in patches:
|
||||||
|
for p in patches["double_block"]:
|
||||||
|
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||||
|
if "img" in out:
|
||||||
|
img[:, cap_size[0]:] = out["img"]
|
||||||
|
if "txt" in out:
|
||||||
|
img[:, :cap_size[0]] = out["txt"]
|
||||||
|
|
||||||
x = self.final_layer(x, adaln_input)
|
img = self.final_layer(img, adaln_input)
|
||||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||||
|
|
||||||
return -x
|
return -img
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import comfy.ops
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
|
import comfy.ldm.lumina.controlnet
|
||||||
|
|
||||||
|
|
||||||
class BlockWiseControlBlock(torch.nn.Module):
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
|
|||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
def z_image_convert(sd):
|
||||||
|
replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
|
||||||
|
".attention.norm_k.weight": ".attention.k_norm.weight",
|
||||||
|
".attention.norm_q.weight": ".attention.q_norm.weight",
|
||||||
|
".attention.to_out.0.weight": ".attention.out.weight"
|
||||||
|
}
|
||||||
|
|
||||||
|
out_sd = {}
|
||||||
|
for k in sorted(sd.keys()):
|
||||||
|
w = sd[k]
|
||||||
|
|
||||||
|
k_out = k
|
||||||
|
if k_out.endswith(".attention.to_k.weight"):
|
||||||
|
cc = [w]
|
||||||
|
continue
|
||||||
|
if k_out.endswith(".attention.to_q.weight"):
|
||||||
|
cc = [w] + cc
|
||||||
|
continue
|
||||||
|
if k_out.endswith(".attention.to_v.weight"):
|
||||||
|
cc = cc + [w]
|
||||||
|
w = torch.cat(cc, dim=0)
|
||||||
|
k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
|
||||||
|
|
||||||
|
for r, rr in replace_keys.items():
|
||||||
|
k_out = k_out.replace(r, rr)
|
||||||
|
out_sd[k_out] = w
|
||||||
|
|
||||||
|
return out_sd
|
||||||
|
|
||||||
class ModelPatchLoader:
|
class ModelPatchLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -211,6 +241,9 @@ class ModelPatchLoader:
|
|||||||
elif 'feature_embedder.mid_layer_norm.bias' in sd:
|
elif 'feature_embedder.mid_layer_norm.bias' in sd:
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
|
||||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||||
|
sd = z_image_convert(sd)
|
||||||
|
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
|
||||||
model.load_state_dict(sd)
|
model.load_state_dict(sd)
|
||||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
|
|||||||
def models(self):
|
def models(self):
|
||||||
return [self.model_patch]
|
return [self.model_patch]
|
||||||
|
|
||||||
|
class ZImageControlPatch:
|
||||||
|
def __init__(self, model_patch, vae, image, strength):
|
||||||
|
self.model_patch = model_patch
|
||||||
|
self.vae = vae
|
||||||
|
self.image = image
|
||||||
|
self.strength = strength
|
||||||
|
self.encoded_image = self.encode_latent_cond(image)
|
||||||
|
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||||
|
self.temp_data = None
|
||||||
|
|
||||||
|
def encode_latent_cond(self, image):
|
||||||
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
|
||||||
|
return latent_image
|
||||||
|
|
||||||
|
def __call__(self, kwargs):
|
||||||
|
x = kwargs.get("x")
|
||||||
|
img = kwargs.get("img")
|
||||||
|
txt = kwargs.get("txt")
|
||||||
|
pe = kwargs.get("pe")
|
||||||
|
vec = kwargs.get("vec")
|
||||||
|
block_index = kwargs.get("block_index")
|
||||||
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
|
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||||
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||||
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
|
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
|
||||||
|
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||||
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
|
cnet_index = (block_index // 5)
|
||||||
|
cnet_index_float = (block_index / 5)
|
||||||
|
|
||||||
|
kwargs.pop("img") # we do ops in place
|
||||||
|
kwargs.pop("txt")
|
||||||
|
|
||||||
|
cnet_blocks = self.model_patch.model.n_control_layers
|
||||||
|
if cnet_index_float > (cnet_blocks - 1):
|
||||||
|
self.temp_data = None
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
||||||
|
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||||
|
|
||||||
|
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
||||||
|
next_layer = self.temp_data[0] + 1
|
||||||
|
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||||
|
|
||||||
|
if cnet_index_float == self.temp_data[0]:
|
||||||
|
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||||
|
if cnet_blocks == self.temp_data[0] + 1:
|
||||||
|
self.temp_data = None
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def to(self, device_or_dtype):
|
||||||
|
if isinstance(device_or_dtype, torch.device):
|
||||||
|
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||||
|
self.temp_data = None
|
||||||
|
return self
|
||||||
|
|
||||||
|
def models(self):
|
||||||
|
return [self.model_patch]
|
||||||
|
|
||||||
class QwenImageDiffsynthControlnet:
|
class QwenImageDiffsynthControlnet:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet:
|
|||||||
mask = mask.unsqueeze(2)
|
mask = mask.unsqueeze(2)
|
||||||
mask = 1.0 - mask
|
mask = 1.0 - mask
|
||||||
|
|
||||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||||
|
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
|
||||||
|
else:
|
||||||
|
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||||
return (model_patched,)
|
return (model_patched,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user