import torch
from torch import nn
import math

import comfy.ldm.common_dit
import comfy.patcher_extension  # needed for the WrapperExecutor used in Kandinsky5.forward
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND

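# Thin wrapper around ComfyUI's optimized_attention: q/k/v arrive as
# (batch, tokens, heads, head_dim), so they are transposed to a heads-first
# layout and passed with skip_reshape=True.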
def attention(q, k, v, heads, transformer_options={}):
    return optimized_attention(
        q.transpose(1, 2),
        k.transpose(1, 2),
        v.transpose(1, 2),
        heads=heads,
        skip_reshape=True,
        transformer_options=transformer_options
    )

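# AdaLN-style modulation helpers: a normalized activation is scaled/shifted by
# timestep-conditioned parameters, and block outputs are folded back into the
# residual stream through a learned gate (both implemented with torch.addcmul).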
def apply_scale_shift_norm(norm, x, scale, shift):
    return torch.addcmul(shift, norm(x), scale + 1.0)

def apply_gate_sum(x, out, gate):
    return torch.addcmul(x, gate, out)

def get_shift_scale_gate(params):
    shift, scale, gate = torch.chunk(params, 3, dim=-1)
    return tuple(x.unsqueeze(1) for x in (shift, scale, gate))

def get_freqs(dim, max_period=10000.0):
    return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)

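# Sinusoidal timestep embedding (cos/sin of timestep * precomputed frequencies)
# followed by a two-layer SiLU MLP that projects it to the modulation time_dim.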
class TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
        operations = operation_settings.get("operations")
        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.activation = nn.SiLU()
        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, timestep, dtype):
        args = torch.outer(timestep, self.freqs.to(device=timestep.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class TextEmbeddings(nn.Module):
    def __init__(self, text_dim, model_dim, operation_settings=None):
        super().__init__()
        operations = operation_settings.get("operations")
        self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, text_embed):
        text_embed = self.in_layer(text_embed)
        return self.norm(text_embed).type_as(text_embed)

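# Patchifies the video latent: (B, C, T, H, W) is split into patches of size
# patch_size=(pt, ph, pw), each flattened patch is projected to model_dim,
# giving a (B, T/pt, H/ph, W/pw, model_dim) grid of visual tokens.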
class VisualEmbeddings(nn.Module):
    def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
        super().__init__()
        self.patch_size = patch_size
        operations = operation_settings.get("operations")
        self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, x):
        x = x.movedim(1, -1)  # B C T H W -> B T H W C
        B, T, H, W, dim = x.shape
        pt, ph, pw = self.patch_size

        x = x.view(
            B,
            T // pt, pt,
            H // ph, ph,
            W // pw, pw,
            dim,
        ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)

        return self.in_layer(x)


class Modulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, x):
        return self.out_layer(self.activation(x))

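# Self-attention with per-head RMSNorm on queries/keys and rotary position
# embeddings applied via apply_rope1. For sequences longer than 8192 tokens the
# q/k projections are computed in chunks along the token dimension to lower
# peak memory; the chunks are concatenated before attention, so the result is
# unchanged.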
class SelfAttention(nn.Module):
    def __init__(self, num_channels, head_dim, operation_settings=None):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim
        self.head_dim = head_dim

        operations = operation_settings.get("operations")
        self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.num_chunks = 2

    def _compute_qk(self, x, freqs, proj_fn, norm_fn):
        result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
        return apply_rope1(norm_fn(result), freqs)

    def _forward(self, x, freqs, transformer_options={}):
        q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
        k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
        return self.out_layer(out)

    def _forward_chunked(self, x, freqs, transformer_options={}):
        def process_chunks(proj_fn, norm_fn):
            x_chunks = torch.chunk(x, self.num_chunks, dim=1)
            freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
            chunks = []
            for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
                chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
            return torch.cat(chunks, dim=1)

        q = process_chunks(self.to_query, self.query_norm)
        k = process_chunks(self.to_key, self.key_norm)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
        return self.out_layer(out)

    def forward(self, x, freqs, transformer_options={}):
        if x.shape[1] > 8192:
            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
        else:
            return self._forward(x, freqs, transformer_options=transformer_options)


class CrossAttention(SelfAttention):
    def get_qkv(self, x, context):
        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
        k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
        v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
        return q, k, v

    def forward(self, x, context, transformer_options={}):
        q, k, v = self.get_qkv(x, context)
        out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
        return self.out_layer(out)

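# Standard Linear -> GELU -> Linear MLP. Like SelfAttention, it switches to a
# chunked pass over the token dimension for sequences longer than 8192 tokens to
# reduce peak memory; the output is identical since the MLP acts per token.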
class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim, operation_settings=None):
        super().__init__()
        operations = operation_settings.get("operations")
        self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.activation = nn.GELU()
        self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.num_chunks = 4

    def _forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))

    def _forward_chunked(self, x):
        chunks = torch.chunk(x, self.num_chunks, dim=1)
        output_chunks = []
        for chunk in chunks:
            output_chunks.append(self._forward(chunk))
        return torch.cat(output_chunks, dim=1)

    def forward(self, x):
        if x.shape[1] > 8192:
            return self._forward_chunked(x)
        else:
            return self._forward(x)

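# Final projection head: applies a timestep-modulated LayerNorm, projects each
# token to prod(patch_size) * out_visual_dim values, then un-patchifies the
# result back to a (B, C, T, H, W) latent.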
class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
        super().__init__()
        self.patch_size = patch_size
        self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
        operations = operation_settings.get("operations")
        self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, visual_embed, time_embed):
        B, T, H, W, _ = visual_embed.shape
        shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
        scale = scale[:, None, None, None, :]
        shift = shift[:, None, None, None, :]
        visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
        x = self.out_layer(visual_embed)

        out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
        x = x.view(
            B, T, H, W,
            out_dim,
            self.patch_size[0], self.patch_size[1], self.patch_size[2]
        )
        return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)

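# Text-only encoder block: self-attention + feed-forward, each wrapped in
# timestep-conditioned shift/scale/gate modulation (6 parameters from a single
# Modulation layer). Used to refine the text embeddings before they condition
# the visual blocks.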
class TransformerEncoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
        super().__init__()
        self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
        operations = operation_settings.get("operations")

        self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)

        self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

    def forward(self, x, time_embed, freqs, transformer_options={}):
        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
        shift, scale, gate = get_shift_scale_gate(self_attn_params)
        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
        out = self.self_attention(out, freqs, transformer_options=transformer_options)
        x = apply_gate_sum(x, out, gate)

        shift, scale, gate = get_shift_scale_gate(ff_params)
        out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
        out = self.feed_forward(out)
        x = apply_gate_sum(x, out, gate)
        return x

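# Visual block: modulated self-attention over the video tokens, cross-attention
# from video tokens to the refined text embeddings, then a feed-forward layer.
# A single Modulation layer supplies all 9 shift/scale/gate parameters
# (3 per sub-layer).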
class TransformerDecoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
        super().__init__()
        self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)

        operations = operation_settings.get("operations")
        self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)

        self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)

        self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
        self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
        # self attention
        shift, scale, gate = get_shift_scale_gate(self_attn_params)
        visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        # cross attention
        shift, scale, gate = get_shift_scale_gate(cross_attn_params)
        visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
        visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        # feed forward
        shift, scale, gate = get_shift_scale_gate(ff_params)
        visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
        visual_out = self.feed_forward(visual_out)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        return visual_embed

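# Kandinsky 5 diffusion transformer. The defaults appear to correspond to the
# "lite" configuration (model_dim=1792, 32 visual blocks, 2 text encoder
# blocks); the "pro" variant is detected below by model_dim == 4096. The
# in_text_dim (3584) context drives cross-attention, while the pooled
# in_text_dim2 (768) embedding is added to the timestep embedding.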
class Kandinsky5(nn.Module):
    def __init__(
        self,
        in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
        model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
        axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
        dtype=None, device=None, operations=None, **kwargs
    ):
        super().__init__()
        head_dim = sum(axes_dims)
        self.rope_scale_factor = rope_scale_factor
        self.in_visual_dim = in_visual_dim
        self.model_dim = model_dim
        self.patch_size = patch_size
        self.visual_embed_dim = visual_embed_dim
        self.dtype = dtype
        self.device = device
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
        self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
        self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
        self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)

        self.text_transformer_blocks = nn.ModuleList(
            [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
        )

        self.visual_transformer_blocks = nn.ModuleList(
            [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
        )

        self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)

        self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
        self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])

    def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
        steps = seq_len if steps is None else steps
        seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
        seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0)  # Shape: (1, steps, 1)
        freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
        return freqs

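    # Builds 3D rotary position frequencies for the (t, h, w) token grid.
    # Positions can be rescaled/shifted explicitly via
    # transformer_options["rope_options"], otherwise the model's built-in
    # rope_scale_factor is used (with a larger spatial factor for the pro
    # model at high resolutions).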
    def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])

        if steps_t is None:
            steps_t = t_len
        if steps_h is None:
            steps_h = h_len
        if steps_w is None:
            steps_w = w_len

        h_start = 0
        w_start = 0
        rope_options = transformer_options.get("rope_options", None)
        if rope_options is not None:
            t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0

            t_start += rope_options.get("shift_t", 0.0)
            h_start += rope_options.get("shift_y", 0.0)
            w_start += rope_options.get("shift_x", 0.0)
        else:
            rope_scale_factor = self.rope_scale_factor
            if self.model_dim == 4096:  # pro video model uses different rope scaling at higher resolutions
                if h * w >= 14080:
                    rope_scale_factor = (1.0, 3.16, 3.16)

            t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
            h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
            w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0

        img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
        img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])

        freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
        return freqs

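    # Core forward pass on already-padded inputs. Individual visual blocks can
    # be swapped out through transformer_options["patches_replace"]["dit"]
    # entries keyed by ("double_block", i); block_index/block_type/total_blocks
    # are exposed so patches know where they are in the stack.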
    def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})
        context = self.text_embeddings(context)
        time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)

        for block in self.text_transformer_blocks:
            context = block(context, time_embed, freqs_text, transformer_options=transformer_options)

        visual_embed = self.visual_embeddings(x)
        visual_shape = visual_embed.shape[:-1]
        visual_embed = visual_embed.flatten(1, -2)

        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.visual_transformer_blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    # call positionally so the arguments line up with TransformerDecoderBlock.forward(visual_embed, text_embed, time_embed, freqs)
                    return block(args["x"], args["context"], args["time_embed"], args["freqs"], transformer_options=args.get("transformer_options"))
                visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
            else:
                visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)

        visual_embed = visual_embed.reshape(*visual_shape, -1)
        return self.out_layer(visual_embed, time_embed)

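    # Pads the latent to the patch size and precomputes RoPE frequencies before
    # calling forward_orig. When time_dim_replace is provided, its content
    # overwrites the leading latent frames of x along the time dimension; this
    # appears to be how the image-to-video path injects conditioning frames.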
    def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
        bs, c, t_len, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

        if time_dim_replace is not None:
            time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
            x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace

        freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
        freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)

        return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)

    def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
        ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)