mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
Kandinsky5 model support (#10988)
* Add Kandinsky5 model support lite and pro T2V tested to work * Update kandinsky5.py * Fix fp8 * Fix fp8_scaled text encoder * Add transformer_options for attention * Code cleanup, optimizations, use fp32 for all layers originally at fp32 * ImageToVideo -node * Fix I2V, add necessary latent post process nodes * Support text to image model * Support block replace patches (SLG mostly) * Support official LoRAs * Don't scale RoPE for lite model as that just doesn't work... * Update supported_models.py * Rever RoPE scaling to simpler one * Fix typo * Handle latent dim difference for image model in the VAE instead * Add node to use different prompts for clip_l and qwen25_7b * Reduce peak VRAM usage a bit * Further reduce peak VRAM consumption by chunking ffn * Update chunking * Update memory_usage_factor * Code cleanup, don't force the fp32 layers as it has minimal effect * Allow for stronger changes with first frames normalization Default values are too weak for any meaningful changes, these should probably be exposed as advanced node options when that's available. * Add image model's own chat template, remove unused image2video template * Remove hard error in ReplaceVideoLatentFrames -node * Update kandinsky5.py * Update supported_models.py * Fix typos in prompt template They were now fixed in the original repository as well * Update ReplaceVideoLatentFrames Add tooltips Make source optional Better handle negative index * Rename NormalizeVideoLatentFrames -node For bit better clarity what it does * Fix NormalizeVideoLatentStart node out on non-op
This commit is contained in:
parent
bed12674a1
commit
fd109325db
407
comfy/ldm/kandinsky5/model.py
Normal file
407
comfy/ldm/kandinsky5/model.py
Normal file
@ -0,0 +1,407 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1630,3 +1631,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
class Kandinsky5(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||
if time_dim_replace is not None:
|
||||
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||
|
||||
return out
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
@ -611,6 +611,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["model_dim"] = model_dim
|
||||
if model_dim in [4096, 2560]: # pro video and lite image
|
||||
dit_config["axes_dims"] = (32, 48, 48)
|
||||
if model_dim == 2560: # lite image
|
||||
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||
elif model_dim == 1792: # lite video
|
||||
dit_config["axes_dims"] = (16, 24, 24)
|
||||
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["image_model"] = "kandinsky5"
|
||||
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
||||
11
comfy/sd.py
11
comfy/sd.py
@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -766,6 +767,8 @@ class VAE:
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -983,6 +986,8 @@ class CLIPType(Enum):
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
OVIS = 21
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -1231,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
@ -21,6 +21,7 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
from . import supported_models_base
|
||||
@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
|
||||
|
||||
class Kandinsky5(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 10.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 1.1 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
"model_dim": 2560,
|
||||
"visual_embed_dim": 64,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Flux
|
||||
memory_usage_factor = 1.1 #TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5Image(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
68
comfy/text_encoders/kandinsky5.py
Normal file
68
comfy/text_encoders/kandinsky5.py
Normal file
@ -0,0 +1,68 @@
|
||||
from comfy import sd1_clip
|
||||
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||
from .llama import Qwen25_7BVLI
|
||||
|
||||
|
||||
class Kandinsky5Tokenizer(QwenImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Kandinsky5TEModel(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
|
||||
|
||||
return cond, l_pooled, extra
|
||||
|
||||
def set_clip_options(self, options):
|
||||
super().set_clip_options(options)
|
||||
self.clip_l.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
super().reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
class Kandinsky5TEModel_(Kandinsky5TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Kandinsky5TEModel_
|
||||
@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
|
||||
'''Used by WAN Camera.'''
|
||||
time_dim_concat: NotRequired[torch.Tensor]
|
||||
'''Used by WAN Phantom Subject.'''
|
||||
time_dim_replace: NotRequired[torch.Tensor]
|
||||
'''Used by Kandinsky5 I2V.'''
|
||||
|
||||
CondList = list[tuple[torch.Tensor, PooledDict]]
|
||||
Type = CondList
|
||||
|
||||
136
comfy_extras/nodes_kandinsky5.py
Normal file
136
comfy_extras/nodes_kandinsky5.py
Normal file
@ -0,0 +1,136 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class Kandinsky5ImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Kandinsky5ImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
|
||||
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
cond_latent_out = {}
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
encoded = vae.encode(start_image[:, :, :, :3])
|
||||
cond_latent_out["samples"] = encoded
|
||||
|
||||
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
|
||||
|
||||
|
||||
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
|
||||
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
|
||||
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
|
||||
|
||||
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
|
||||
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
|
||||
|
||||
# normalization
|
||||
normalized = (source - source_mean) / (source_std + 1e-8)
|
||||
normalized = normalized * reference_std + reference_mean
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
class NormalizeVideoLatentStart(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="NormalizeVideoLatentStart",
|
||||
category="conditioning/video_models",
|
||||
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
|
||||
inputs=[
|
||||
io.Latent.Input("latent"),
|
||||
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
|
||||
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
|
||||
if latent["samples"].shape[2] <= 1:
|
||||
return io.NodeOutput(latent)
|
||||
s = latent.copy()
|
||||
samples = latent["samples"].clone()
|
||||
|
||||
first_frames = samples[:, :, :start_frame_count]
|
||||
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
|
||||
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
|
||||
|
||||
samples[:, :, :start_frame_count] = normalized_first_frames
|
||||
s["samples"] = samples
|
||||
return io.NodeOutput(s)
|
||||
|
||||
|
||||
class CLIPTextEncodeKandinsky5(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeKandinsky5",
|
||||
category="advanced/conditioning/kandinsky5",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(clip_l)
|
||||
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
|
||||
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||
|
||||
|
||||
class Kandinsky5Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
Kandinsky5ImageToVideo,
|
||||
NormalizeVideoLatentStart,
|
||||
CLIPTextEncodeKandinsky5,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> Kandinsky5Extension:
|
||||
return Kandinsky5Extension()
|
||||
@ -4,7 +4,7 @@ import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import logging
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
|
||||
return luminance * sharpened
|
||||
return io.NodeOutput(sharpen)
|
||||
|
||||
class ReplaceVideoLatentFrames(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReplaceVideoLatentFrames",
|
||||
category="latent/batch",
|
||||
inputs=[
|
||||
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
|
||||
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
|
||||
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, index, source=None) -> io.NodeOutput:
|
||||
if source is None:
|
||||
return io.NodeOutput(destination)
|
||||
dest_frames = destination["samples"].shape[2]
|
||||
source_frames = source["samples"].shape[2]
|
||||
if index < 0:
|
||||
index = dest_frames + index
|
||||
if index > dest_frames:
|
||||
logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
|
||||
return io.NodeOutput(destination)
|
||||
if index + source_frames > dest_frames:
|
||||
logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
|
||||
return io.NodeOutput(destination)
|
||||
s = source.copy()
|
||||
s_source = source["samples"]
|
||||
s_destination = destination["samples"].clone()
|
||||
s_destination[:, :, index:index + s_source.shape[2]] = s_source
|
||||
s["samples"] = s_destination
|
||||
return io.NodeOutput(s)
|
||||
|
||||
class LatentExtension(ComfyExtension):
|
||||
@override
|
||||
@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
|
||||
LatentApplyOperationCFG,
|
||||
LatentOperationTonemapReinhard,
|
||||
LatentOperationSharpen,
|
||||
ReplaceVideoLatentFrames
|
||||
]
|
||||
|
||||
|
||||
|
||||
3
nodes.py
3
nodes.py
@ -970,7 +970,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_rope.py",
|
||||
"nodes_logic.py",
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user