works
This commit is contained in:
Jukka Seppänen 2024-11-09 04:02:36 +02:00
parent 9aab678a9e
commit b563994afc
3 changed files with 22 additions and 132 deletions

View File

@ -73,6 +73,7 @@ class CogVideoXAttnProcessor2_0:
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
@ -115,14 +116,16 @@ class CogVideoXAttnProcessor2_0:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, is_causal=False)
else:
hidden_states = F.scaled_dot_product_attention(
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
#if SAGEATTN_IS_AVAILABLE:
# hidden_states = sageattn(query, key, value, is_causal=False)
#else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
if torch.isinf(hidden_states).any():
raise ValueError(f"hidden_states after dot product has inf")
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@ -136,81 +139,6 @@ class CogVideoXAttnProcessor2_0:
)
return hidden_states, encoder_hidden_states
# class FusedCogVideoXAttnProcessor2_0:
# r"""
# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
# query and key vectors, but does not include spatial normalization.
# """
# def __init__(self):
# if not hasattr(F, "scaled_dot_product_attention"):
# raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
# @torch.compiler.disable()
# def __call__(
# self,
# attn: Attention,
# hidden_states: torch.Tensor,
# encoder_hidden_states: torch.Tensor,
# attention_mask: Optional[torch.Tensor] = None,
# image_rotary_emb: Optional[torch.Tensor] = None,
# ) -> torch.Tensor:
# print("FusedCogVideoXAttnProcessor2_0")
# text_seq_length = encoder_hidden_states.size(1)
# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# batch_size, sequence_length, _ = (
# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# )
# if attention_mask is not None:
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
# qkv = attn.to_qkv(hidden_states)
# split_size = qkv.shape[-1] // 3
# query, key, value = torch.split(qkv, split_size, dim=-1)
# inner_dim = key.shape[-1]
# head_dim = inner_dim // attn.heads
# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# if attn.norm_q is not None:
# query = attn.norm_q(query)
# if attn.norm_k is not None:
# key = attn.norm_k(key)
# # Apply RoPE if needed
# if image_rotary_emb is not None:
# from diffusers.models.embeddings import apply_rotary_emb
# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
# if not attn.is_cross_attention:
# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
# if SAGEATTN_IS_AVAILABLE:
# hidden_states = sageattn(query, key, value, is_causal=False)
# else:
# hidden_states = F.scaled_dot_product_attention(
# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
# )
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# # linear proj
# hidden_states = attn.to_out[0](hidden_states)
# # dropout
# hidden_states = attn.to_out[1](hidden_states)
# encoder_hidden_states, hidden_states = hidden_states.split(
# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
# )
# return hidden_states, encoder_hidden_states
#region Blocks
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
@ -270,6 +198,7 @@ class CogVideoXBlock(nn.Module):
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
@ -308,11 +237,14 @@ class CogVideoXBlock(nn.Module):
fastercache_start_step=15,
fastercache_device="cuda:0"
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# Tora Motion-guidance Fuser
if video_flow_feature is not None:
H, W = video_flow_feature.shape[-2:]
@ -347,19 +279,12 @@ class CogVideoXBlock(nn.Module):
elif fastercache_counter > fastercache_start_step:
self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# has_nan = torch.isnan(hidden_states).any()
# if has_nan:
# raise ValueError(f"hs before norm2 has nan: {has_nan}")
# has_inf = torch.isinf(hidden_states).any()
# if has_inf:
# raise ValueError(f"hs before norm2 has inf: {has_inf}")
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
@ -604,45 +529,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
# def fuse_qkv_projections(self):
# """
# Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
# are fused. For cross-attention modules, key and value projection matrices are fused.
# <Tip warning={true}>
# This API is 🧪 experimental.
# </Tip>
# """
# self.original_attn_processors = None
# for _, attn_processor in self.attn_processors.items():
# if "Added" in str(attn_processor.__class__.__name__):
# raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# self.original_attn_processors = self.attn_processors
# for module in self.modules():
# if isinstance(module, Attention):
# module.fuse_projections(fuse=True)
# self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
# def unfuse_qkv_projections(self):
# """Disables the fused QKV projection if enabled.
# <Tip warning={true}>
# This API is 🧪 experimental.
# </Tip>
# """
# if self.original_attn_processors is not None:
# self.set_attn_processor(self.original_attn_processors)
def forward(
self,
@ -679,8 +565,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
first_frame = hidden_states[:, :1].repeat(1, 1 + remaining_frames, 1, 1, 1)
hidden_states = torch.cat([first_frame, hidden_states[:, 1:]], dim=1)
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
@ -760,7 +648,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
output = torch.cat([output, recovered_uncond])
else:
for i, block in enumerate(self.transformer_blocks):
#print("block", i)
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,

View File

@ -261,7 +261,7 @@ class DownloadAndLoadCogVideoModel:
if "CogVideoXBlock" in str(block):
pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
elif compile == "onediff":
from onediffx import compile_pipe
@ -274,7 +274,7 @@ class DownloadAndLoadCogVideoModel:
ignores=["vae"],
fuse_qkv_projections=True if pab_config is None else False,
)
pipeline = {
"pipe": pipe,
"dtype": dtype,
@ -453,6 +453,8 @@ class DownloadAndLoadCogVideoGGUFModel:
if enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload()
pipeline = {
"pipe": pipe,
"dtype": vae_dtype,

View File

@ -861,6 +861,7 @@ class CogVideoSampler:
pipe.transformer.fastercache_counter = 0
autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
autocastcondition = False ##todo
autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
with autocast_context:
latents = pipeline["pipe"](