mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
finally
works
This commit is contained in:
parent
9aab678a9e
commit
b563994afc
@ -73,6 +73,7 @@ class CogVideoXAttnProcessor2_0:
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
@torch.compiler.disable()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
@ -115,14 +116,16 @@ class CogVideoXAttnProcessor2_0:
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
if SAGEATTN_IS_AVAILABLE:
|
||||
hidden_states = sageattn(query, key, value, is_causal=False)
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
#if SAGEATTN_IS_AVAILABLE:
|
||||
# hidden_states = sageattn(query, key, value, is_causal=False)
|
||||
#else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
if torch.isinf(hidden_states).any():
|
||||
raise ValueError(f"hidden_states after dot product has inf")
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
@ -136,81 +139,6 @@ class CogVideoXAttnProcessor2_0:
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
# class FusedCogVideoXAttnProcessor2_0:
|
||||
# r"""
|
||||
# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
# query and key vectors, but does not include spatial normalization.
|
||||
# """
|
||||
|
||||
# def __init__(self):
|
||||
# if not hasattr(F, "scaled_dot_product_attention"):
|
||||
# raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
# @torch.compiler.disable()
|
||||
# def __call__(
|
||||
# self,
|
||||
# attn: Attention,
|
||||
# hidden_states: torch.Tensor,
|
||||
# encoder_hidden_states: torch.Tensor,
|
||||
# attention_mask: Optional[torch.Tensor] = None,
|
||||
# image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
# ) -> torch.Tensor:
|
||||
# print("FusedCogVideoXAttnProcessor2_0")
|
||||
# text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# batch_size, sequence_length, _ = (
|
||||
# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
# )
|
||||
|
||||
# if attention_mask is not None:
|
||||
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
# qkv = attn.to_qkv(hidden_states)
|
||||
# split_size = qkv.shape[-1] // 3
|
||||
# query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
# inner_dim = key.shape[-1]
|
||||
# head_dim = inner_dim // attn.heads
|
||||
|
||||
# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# if attn.norm_q is not None:
|
||||
# query = attn.norm_q(query)
|
||||
# if attn.norm_k is not None:
|
||||
# key = attn.norm_k(key)
|
||||
|
||||
# # Apply RoPE if needed
|
||||
# if image_rotary_emb is not None:
|
||||
# from diffusers.models.embeddings import apply_rotary_emb
|
||||
|
||||
# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
# if not attn.is_cross_attention:
|
||||
# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
# if SAGEATTN_IS_AVAILABLE:
|
||||
# hidden_states = sageattn(query, key, value, is_causal=False)
|
||||
# else:
|
||||
# hidden_states = F.scaled_dot_product_attention(
|
||||
# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
# )
|
||||
|
||||
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# # linear proj
|
||||
# hidden_states = attn.to_out[0](hidden_states)
|
||||
# # dropout
|
||||
# hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
# encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
# )
|
||||
# return hidden_states, encoder_hidden_states
|
||||
|
||||
#region Blocks
|
||||
@maybe_allow_in_graph
|
||||
class CogVideoXBlock(nn.Module):
|
||||
@ -270,6 +198,7 @@ class CogVideoXBlock(nn.Module):
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
|
||||
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
@ -308,11 +237,14 @@ class CogVideoXBlock(nn.Module):
|
||||
fastercache_start_step=15,
|
||||
fastercache_device="cuda:0"
|
||||
) -> torch.Tensor:
|
||||
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# Tora Motion-guidance Fuser
|
||||
if video_flow_feature is not None:
|
||||
H, W = video_flow_feature.shape[-2:]
|
||||
@ -347,19 +279,12 @@ class CogVideoXBlock(nn.Module):
|
||||
elif fastercache_counter > fastercache_start_step:
|
||||
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
|
||||
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
|
||||
|
||||
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
|
||||
# has_nan = torch.isnan(hidden_states).any()
|
||||
# if has_nan:
|
||||
# raise ValueError(f"hs before norm2 has nan: {has_nan}")
|
||||
# has_inf = torch.isinf(hidden_states).any()
|
||||
# if has_inf:
|
||||
# raise ValueError(f"hs before norm2 has inf: {has_inf}")
|
||||
|
||||
# norm & modulate
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
@ -604,45 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
||||
# def fuse_qkv_projections(self):
|
||||
# """
|
||||
# Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
# are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
# <Tip warning={true}>
|
||||
|
||||
# This API is 🧪 experimental.
|
||||
|
||||
# </Tip>
|
||||
# """
|
||||
# self.original_attn_processors = None
|
||||
|
||||
# for _, attn_processor in self.attn_processors.items():
|
||||
# if "Added" in str(attn_processor.__class__.__name__):
|
||||
# raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
# self.original_attn_processors = self.attn_processors
|
||||
|
||||
# for module in self.modules():
|
||||
# if isinstance(module, Attention):
|
||||
# module.fuse_projections(fuse=True)
|
||||
|
||||
# self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
# def unfuse_qkv_projections(self):
|
||||
# """Disables the fused QKV projection if enabled.
|
||||
|
||||
# <Tip warning={true}>
|
||||
|
||||
# This API is 🧪 experimental.
|
||||
|
||||
# </Tip>
|
||||
|
||||
# """
|
||||
# if self.original_attn_processors is not None:
|
||||
# self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -679,8 +565,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
|
||||
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
|
||||
|
||||
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
@ -760,7 +648,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
output = torch.cat([output, recovered_uncond])
|
||||
else:
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
#print("block", i)
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
|
||||
@ -261,7 +261,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
if "CogVideoXBlock" in str(block):
|
||||
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
|
||||
|
||||
|
||||
|
||||
|
||||
elif compile == "onediff":
|
||||
from onediffx import compile_pipe
|
||||
@ -274,7 +274,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
ignores=["vae"],
|
||||
fuse_qkv_projections=True if pab_config is None else False,
|
||||
)
|
||||
|
||||
|
||||
pipeline = {
|
||||
"pipe": pipe,
|
||||
"dtype": dtype,
|
||||
@ -453,6 +453,8 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
if enable_sequential_cpu_offload:
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
|
||||
|
||||
pipeline = {
|
||||
"pipe": pipe,
|
||||
"dtype": vae_dtype,
|
||||
|
||||
1
nodes.py
1
nodes.py
@ -861,6 +861,7 @@ class CogVideoSampler:
|
||||
pipe.transformer.fastercache_counter = 0
|
||||
|
||||
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
|
||||
autocastcondition = False ##todo
|
||||
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
|
||||
with autocast_context:
|
||||
latents = pipeline["pipe"](
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user