Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
finally works

This commit is contained in:
parent 9aab678a9e
commit b563994afc
@@ -73,6 +73,7 @@ class CogVideoXAttnProcessor2_0:
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
+
    @torch.compiler.disable()
    def __call__(
        self,
        attn: Attention,
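The `@torch.compiler.disable()` decorator on `__call__` keeps the attention processor out of any torch.compile graph, so it always runs eagerly even when the surrounding blocks are compiled (see the per-block `torch.compile` hunk further down). A minimal sketch of the pattern, independent of this repo:

    import torch
    import torch.nn.functional as F

    # torch.compiler.disable() forces an eager region inside compiled code;
    # useful when a function calls custom kernels or does data-dependent
    # checks that Dynamo cannot trace.
    @torch.compiler.disable()
    def attention_eager(q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

    @torch.compile
    def block(q, k, v):
        q = q * 2.0                      # traced and compiled
        return attention_eager(q, k, v)  # graph break: runs eagerly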
@@ -115,14 +116,16 @@ class CogVideoXAttnProcessor2_0:
 
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
-       if SAGEATTN_IS_AVAILABLE:
-           hidden_states = sageattn(query, key, value, is_causal=False)
-       else:
-           hidden_states = F.scaled_dot_product_attention(
-               query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-           )
+       #if SAGEATTN_IS_AVAILABLE:
+       #    hidden_states = sageattn(query, key, value, is_causal=False)
+       #else:
+       hidden_states = F.scaled_dot_product_attention(
+           query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+       )
+       if torch.isinf(hidden_states).any():
+           raise ValueError(f"hidden_states after dot product has inf")
 
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
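The change above comments out the SageAttention path, makes the `F.scaled_dot_product_attention` call unconditional, and guards its output: in fp16 the attention matmul can overflow to inf, which otherwise only surfaces later as NaN in a norm layer far from the real cause. A minimal sketch of the guard (shapes and the CUDA device are illustrative):

    import torch
    import torch.nn.functional as F

    # fp16 has a max finite value of ~65504, so a hot attention output can
    # overflow to inf; checking right after SDPA localizes the failure.
    q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    if torch.isinf(out).any():
        raise ValueError("hidden_states after dot product has inf")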
@@ -136,81 +139,6 @@ class CogVideoXAttnProcessor2_0:
        )
        return hidden_states, encoder_hidden_states
 
-
-# class FusedCogVideoXAttnProcessor2_0:
-#     r"""
-#     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-#     query and key vectors, but does not include spatial normalization.
-#     """
-
-#     def __init__(self):
-#         if not hasattr(F, "scaled_dot_product_attention"):
-#             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-#     @torch.compiler.disable()
-#     def __call__(
-#         self,
-#         attn: Attention,
-#         hidden_states: torch.Tensor,
-#         encoder_hidden_states: torch.Tensor,
-#         attention_mask: Optional[torch.Tensor] = None,
-#         image_rotary_emb: Optional[torch.Tensor] = None,
-#     ) -> torch.Tensor:
-#         print("FusedCogVideoXAttnProcessor2_0")
-#         text_seq_length = encoder_hidden_states.size(1)
-
-#         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-#         batch_size, sequence_length, _ = (
-#             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-#         )
-
-#         if attention_mask is not None:
-#             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-#             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-#         qkv = attn.to_qkv(hidden_states)
-#         split_size = qkv.shape[-1] // 3
-#         query, key, value = torch.split(qkv, split_size, dim=-1)
-
-#         inner_dim = key.shape[-1]
-#         head_dim = inner_dim // attn.heads
-
-#         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-#         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-#         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-#         if attn.norm_q is not None:
-#             query = attn.norm_q(query)
-#         if attn.norm_k is not None:
-#             key = attn.norm_k(key)
-
-#         # Apply RoPE if needed
-#         if image_rotary_emb is not None:
-#             from diffusers.models.embeddings import apply_rotary_emb
-
-#             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-#             if not attn.is_cross_attention:
-#                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-#         if SAGEATTN_IS_AVAILABLE:
-#             hidden_states = sageattn(query, key, value, is_causal=False)
-#         else:
-#             hidden_states = F.scaled_dot_product_attention(
-#                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-#             )
-
-#         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-
-#         # linear proj
-#         hidden_states = attn.to_out[0](hidden_states)
-#         # dropout
-#         hidden_states = attn.to_out[1](hidden_states)
-
-#         encoder_hidden_states, hidden_states = hidden_states.split(
-#             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-#         )
-#         return hidden_states, encoder_hidden_states
-
 #region Blocks
 @maybe_allow_in_graph
 class CogVideoXBlock(nn.Module):
@@ -270,6 +198,7 @@ class CogVideoXBlock(nn.Module):
 
        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
 
+
        self.attn1 = Attention(
            query_dim=dim,
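Context for the hunk above: `attn1` is a stock diffusers `Attention` module, and the processor class this commit patches is what it dispatches to. A hedged sketch of that wiring (the numeric arguments are illustrative, not the repo's exact configuration, and `CogVideoXAttnProcessor2_0` refers to the class defined in the patched file):

    from diffusers.models.attention_processor import Attention

    # Illustrative values only; the real block derives these from its config.
    attn1 = Attention(
        query_dim=3072,
        dim_head=64,
        heads=48,
        qk_norm="layer_norm",
        eps=1e-6,
        bias=True,
        out_bias=True,
        processor=CogVideoXAttnProcessor2_0(),
    )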
@@ -308,11 +237,14 @@ class CogVideoXBlock(nn.Module):
        fastercache_start_step=15,
        fastercache_device="cuda:0"
    ) -> torch.Tensor:
+
        text_seq_length = encoder_hidden_states.size(1)
+
        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )
+
        # Tora Motion-guidance Fuser
        if video_flow_feature is not None:
            H, W = video_flow_feature.shape[-2:]
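The `norm1` call above is adaLN-zero style modulation: the timestep embedding is projected into shift/scale/gate terms, and the returned gates (`gate_msa`, `enc_gate_msa`) later scale the attention residual. A simplified single-stream sketch (the real CogVideoXLayerNormZero modulates the text and video streams together; this minimal version is an assumption-laden illustration):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LayerNormZero(nn.Module):
        def __init__(self, time_embed_dim, dim):
            super().__init__()
            # one projection yields shift, scale, and gate for the stream
            self.linear = nn.Linear(time_embed_dim, 3 * dim)
            self.norm = nn.LayerNorm(dim, elementwise_affine=False)

        def forward(self, x, temb):
            shift, scale, gate = self.linear(F.silu(temb)).chunk(3, dim=1)
            x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
            return x, gate[:, None]

    norm = LayerNormZero(time_embed_dim=512, dim=64)
    x, gate = norm(torch.randn(2, 16, 64), torch.randn(2, 512))
    # later, as in the block: hidden_states = hidden_states + gate * attn_out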
@@ -347,19 +279,12 @@ class CogVideoXBlock(nn.Module):
        elif fastercache_counter > fastercache_start_step:
            self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
            self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
 
-
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
 
-       # has_nan = torch.isnan(hidden_states).any()
-       # if has_nan:
-       #     raise ValueError(f"hs before norm2 has nan: {has_nan}")
-       # has_inf = torch.isinf(hidden_states).any()
-       # if has_inf:
-       #     raise ValueError(f"hs before norm2 has inf: {has_inf}")
-
        # norm & modulate
+
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )
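For context on the surviving lines: once `fastercache_counter` passes `fastercache_start_step`, the block keeps a copy of the latest attention outputs on `fastercache_device` so a later step can reuse them instead of recomputing attention. A hedged sketch of just the caching mechanics (the two-slot layout and all names here are assumptions, not the repo's exact implementation):

    import torch

    class AttnCache:
        def __init__(self, device="cuda:0"):
            self.device = device
            self.slots = None  # holds [previous, latest]

        def update(self, attn_out, counter, start_step=15):
            if counter == start_step:
                # first step past the threshold: materialize both slots
                self.slots = [attn_out.to(self.device), attn_out.to(self.device)]
            elif counter > start_step:
                # in-place copy_ avoids reallocating the cached tensor,
                # mirroring cached_hidden_states[-1].copy_(...) above
                self.slots[-1].copy_(attn_out.to(self.device))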
@@ -604,45 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
-    # def fuse_qkv_projections(self):
-    #     """
-    #     Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-    #     are fused. For cross-attention modules, key and value projection matrices are fused.
-
-    #     <Tip warning={true}>
-
-    #     This API is 🧪 experimental.
-
-    #     </Tip>
-    #     """
-    #     self.original_attn_processors = None
-
-    #     for _, attn_processor in self.attn_processors.items():
-    #         if "Added" in str(attn_processor.__class__.__name__):
-    #             raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-    #     self.original_attn_processors = self.attn_processors
-
-    #     for module in self.modules():
-    #         if isinstance(module, Attention):
-    #             module.fuse_projections(fuse=True)
-
-    #     self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    # def unfuse_qkv_projections(self):
-    #     """Disables the fused QKV projection if enabled.
-
-    #     <Tip warning={true}>
-
-    #     This API is 🧪 experimental.
-
-    #     </Tip>
-
-    #     """
-    #     if self.original_attn_processors is not None:
-    #         self.set_attn_processor(self.original_attn_processors)
 
    def forward(
        self,
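The two surviving context lines above are the tail of diffusers' standard `set_attn_processor` recursion; the deleted `fuse_qkv_projections`/`unfuse_qkv_projections` comments were dead code referring to the `FusedCogVideoXAttnProcessor2_0` class removed earlier in this diff. A simplified sketch of that recursion pattern (following diffusers' convention, not this repo's exact code):

    import torch.nn as nn

    def set_attn_processor(root: nn.Module, processor) -> None:
        # walk the module tree; hand the new processor to every module
        # that exposes set_processor (i.e., every Attention layer)
        def fn_recursive(name: str, module: nn.Module, processor) -> None:
            if hasattr(module, "set_processor"):
                module.set_processor(processor)
            for sub_name, child in module.named_children():
                fn_recursive(f"{name}.{sub_name}", child, processor)

        for name, module in root.named_children():
            fn_recursive(name, module, processor)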
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -679,8 +565,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
|
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
|
||||||
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
|
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
|
||||||
|
|
||||||
|
|
||||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||||
hidden_states = self.embedding_dropout(hidden_states)
|
hidden_states = self.embedding_dropout(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
text_seq_length = encoder_hidden_states.shape[1]
|
text_seq_length = encoder_hidden_states.shape[1]
|
||||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||||
@ -760,7 +648,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|||||||
output = torch.cat([output, recovered_uncond])
|
output = torch.cat([output, recovered_uncond])
|
||||||
else:
|
else:
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
#print("block", i)
|
|
||||||
hidden_states, encoder_hidden_states = block(
|
hidden_states, encoder_hidden_states = block(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ class DownloadAndLoadCogVideoModel:
                if "CogVideoXBlock" in str(block):
                    pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
 
-
+
 
        elif compile == "onediff":
            from onediffx import compile_pipe
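Compiling block-by-block, as above, keeps each compile unit small and limits recompilation when one block's guards fail, compared with compiling the whole transformer at once. A self-contained sketch of the same pattern (the toy module stands in for pipe.transformer):

    import torch
    import torch.nn as nn

    class ToyTransformer(nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))

    model = ToyTransformer()
    for i, block in enumerate(model.transformer_blocks):
        # fullgraph=False tolerates graph breaks inside a block;
        # dynamic=False assumes fixed shapes across calls
        model.transformer_blocks[i] = torch.compile(
            block, fullgraph=False, dynamic=False, backend="inductor"
        )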
@@ -274,7 +274,7 @@ class DownloadAndLoadCogVideoModel:
                ignores=["vae"],
                fuse_qkv_projections=True if pab_config is None else False,
            )
-
+
        pipeline = {
            "pipe": pipe,
            "dtype": dtype,
@@ -453,6 +453,8 @@ class DownloadAndLoadCogVideoGGUFModel:
        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()
 
+
+
        pipeline = {
            "pipe": pipe,
            "dtype": vae_dtype,
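`enable_sequential_cpu_offload()` is the standard diffusers API: weights stay in system RAM and each submodule is streamed to the GPU only while it executes, trading speed for minimal VRAM. A usage sketch (the checkpoint id is illustrative, not what this loader uses):

    import torch
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-2b", torch_dtype=torch.float16
    )
    pipe.enable_sequential_cpu_offload()  # use instead of pipe.to("cuda")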
nodes.py
@@ -861,6 +861,7 @@ class CogVideoSampler:
        pipe.transformer.fastercache_counter = 0
 
        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
+       autocastcondition = False ##todo
        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
        with autocast_context:
            latents = pipeline["pipe"](
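The added line pins `autocastcondition` to False (flagged `##todo`), so the pipeline always runs under `nullcontext()` rather than autocast while the fp16 inf issue is chased down. The conditional-context pattern itself, as a minimal sketch:

    import torch
    from contextlib import nullcontext

    autocastcondition = False  # pinned off, as in the diff
    autocast_context = torch.autocast("cuda") if autocastcondition else nullcontext()
    with autocast_context:
        out = torch.randn(8, 8) @ torch.randn(8, 8)  # stands in for pipeline["pipe"](...)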