mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 21:04:23 +08:00
fun fixes
This commit is contained in:
parent 7ac2224ec2
commit 34b650c785
@@ -37,11 +37,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from einops import rearrange
try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAVILABLE = True
    logger.info("Using sageattn")
    SAGEATTN_IS_AVAILABLE = True
except:
    logger.info("sageattn not found, using sdpa")
    SAGEATTN_IS_AVAVILABLE = False
    SAGEATTN_IS_AVAILABLE = False

def fft(tensor):
    tensor_fft = torch.fft.fft2(tensor)
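The hunk above fixes the misspelled availability flag (SAGEATTN_IS_AVAVILABLE becomes SAGEATTN_IS_AVAILABLE), so the value set at import time matches the name checked later. Below is a minimal, self-contained sketch of this optional-backend pattern, assuming only that sageattention may or may not be installed; it is not the repo's exact file, and it narrows the bare except to ImportError for clarity.

import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

try:
    # optional fast attention kernel; fall back to PyTorch SDPA if missing
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
    logger.info("Using sageattn")
except ImportError:
    SAGEATTN_IS_AVAILABLE = False
    logger.info("sageattn not found, using sdpa")


def attention(query, key, value):
    # query/key/value: (batch, heads, seq_len, head_dim)
    if SAGEATTN_IS_AVAILABLE:
        return sageattn(query, key, value, is_causal=False)
    return F.scaled_dot_product_attention(query, key, value, is_causal=False)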
@@ -77,6 +75,7 @@ class CogVideoXAttnProcessor2_0:
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        attention_mode: Optional[str] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

@@ -114,82 +113,11 @@ class CogVideoXAttnProcessor2_0:
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        if SAGEATTN_IS_AVAVILABLE:
            hidden_states = sageattn(query, key, value, is_causal=False)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states


class FusedCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        if SAGEATTN_IS_AVAVILABLE:
            hidden_states = sageattn(query, key, value, is_causal=False)
        if attention_mode == "sageattn":
            if SAGEATTN_IS_AVAILABLE:
                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
            else:
                raise ImportError("sageattn not found")
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
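With this hunk, the processor's __call__ takes an explicit attention_mode and raises if "sageattn" is requested while the library is missing, instead of silently depending on the old (misspelled) global flag. A rough sketch of that dispatch, reusing the flag and imports from the sketch above; the real code also prepares the attention mask and applies rotary embeddings first.

def run_attention(query, key, value, attention_mask=None, attention_mode=None):
    # query/key/value: (batch, heads, seq_len, head_dim)
    if attention_mode == "sageattn":
        if not SAGEATTN_IS_AVAILABLE:
            raise ImportError("sageattn not found")
        return sageattn(query, key, value, is_causal=False)
    # default path: PyTorch scaled dot-product attention
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )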
@@ -298,6 +226,7 @@ class CogVideoXBlock(nn.Module):
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        attention_mode: Optional[str] = None,
    ):
        super().__init__()

@@ -326,6 +255,9 @@ class CogVideoXBlock(nn.Module):
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
        self.cached_hidden_states = []
        self.cached_encoder_hidden_states = []
        self.attention_mode = attention_mode

    def forward(
        self,
@@ -335,6 +267,7 @@ class CogVideoXBlock(nn.Module):
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        video_flow_feature: Optional[torch.Tensor] = None,
        fuser=None,
        block_use_fastercache=False,
        fastercache_counter=0,
        fastercache_start_step=15,
        fastercache_device="cuda:0",
@@ -353,31 +286,41 @@ class CogVideoXBlock(nn.Module):
            h = fuser(h, video_flow_feature.to(h), T=T)
            norm_hidden_states = rearrange(h, "(B T) C H W -> B (T H W) C", T=T)
            del h, fuser
        #fastercache
        B = norm_hidden_states.shape[0]
        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
            attn_hidden_states = (
                self.cached_hidden_states[1][:B] +
                (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
                * 0.3
                ).to(norm_hidden_states.device, non_blocking=True)
            attn_encoder_hidden_states = (
                self.cached_encoder_hidden_states[1][:B] +
                (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
                * 0.3
                ).to(norm_hidden_states.device, non_blocking=True)

        #region fastercache
        if block_use_fastercache:
            B = norm_hidden_states.shape[0]
            if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_hidden_states[-1].shape[0] >= B:
                attn_hidden_states = (
                    self.cached_hidden_states[1][:B] +
                    (self.cached_hidden_states[1][:B] - self.cached_hidden_states[0][:B])
                    * 0.3
                    ).to(norm_hidden_states.device, non_blocking=True)
                attn_encoder_hidden_states = (
                    self.cached_encoder_hidden_states[1][:B] +
                    (self.cached_encoder_hidden_states[1][:B] - self.cached_encoder_hidden_states[0][:B])
                    * 0.3
                    ).to(norm_hidden_states.device, non_blocking=True)
            else:
                attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    attention_mode=self.attention_mode,
                )
                if fastercache_counter == fastercache_start_step:
                    self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
                    self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
                elif fastercache_counter > fastercache_start_step:
                    self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
                    self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))
        else:
            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_encoder_hidden_states,
                image_rotary_emb=image_rotary_emb,
                attention_mode=self.attention_mode,
            )
            if fastercache_counter == fastercache_start_step:
                self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
                self.cached_encoder_hidden_states = [attn_encoder_hidden_states.to(fastercache_device), attn_encoder_hidden_states.to(fastercache_device)]
            elif fastercache_counter > fastercache_start_step:
                self.cached_hidden_states[-1].copy_(attn_hidden_states.to(fastercache_device))
                self.cached_encoder_hidden_states[-1].copy_(attn_encoder_hidden_states.to(fastercache_device))

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
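The reorganized block only touches the attention cache when block_use_fastercache is set. On reuse steps it skips attention entirely and extrapolates from the two most recent cached outputs as cached[1] + (cached[1] - cached[0]) * 0.3. A toy illustration of that reuse rule; the 0.3 gain and the batch-sliced, non-blocking transfer come from the diff, while the shapes and helper name are illustrative.

import torch

def reuse_from_cache(cached, batch_size, device):
    # cached = [second_newest, newest] attention outputs, each shaped (B, seq, dim)
    prev, last = cached[0][:batch_size], cached[1][:batch_size]
    # linear extrapolation with a 0.3 gain, moved to the compute device
    return (last + (last - prev) * 0.3).to(device, non_blocking=True)

cache = [torch.randn(2, 16, 8), torch.randn(2, 16, 8)]
approx_attn_out = reuse_from_cache(cache, batch_size=2, device="cpu")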
@@ -481,6 +424,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        add_noise_in_inpaint_model: bool = False,
        attention_mode: Optional[str] = None,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
@@ -554,6 +498,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
        self.fastercache_lf_step = 40
        self.fastercache_hf_step = 30
        self.fastercache_device = "cuda"
        self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
        self.attention_mode = attention_mode

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value
@@ -720,6 +666,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                    fastercache_start_step = self.fastercache_start_step,
                    fastercache_counter = self.fastercache_counter,
                    fastercache_device = self.fastercache_device
                )
@@ -770,7 +718,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser = self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter = self.fastercache_counter,
                    fastercache_start_step = self.fastercache_start_step,
                    fastercache_device = self.fastercache_device
                )

nodes.py (4 changed lines)
@@ -1180,6 +1180,8 @@ class CogVideoXFunSampler:
            pipe.transformer.fastercache_lf_step = fastercache["lf_step"]
            pipe.transformer.fastercache_hf_step = fastercache["hf_step"]
            pipe.transformer.fastercache_device = fastercache["cache_device"]
            pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]
            log.info(f"FasterCache enabled for {pipe.transformer.fastercache_num_blocks_to_cache} blocks out of {len(pipe.transformer.transformer_blocks)}")
        else:
            pipe.transformer.use_fastercache = False
            pipe.transformer.fastercache_counter = 0
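For context, the fastercache argument consumed above is a plain dict; the keys read in this hunk are "lf_step", "hf_step", "cache_device", and "num_blocks_to_cache", and the commit adds the last of these plus a log line. A hypothetical example of its shape, with placeholder values chosen to mirror the transformer defaults shown earlier in the diff:

fastercache = {
    "lf_step": 40,
    "hf_step": 30,
    "cache_device": "cuda",
    "num_blocks_to_cache": 30,
}
# pipe.transformer.fastercache_num_blocks_to_cache = fastercache["num_blocks_to_cache"]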
@@ -1187,7 +1189,7 @@ class CogVideoXFunSampler:
        generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)

        autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
        with autocast_context:
            video_length = int((video_length - 1) // pipe.vae.config.temporal_compression_ratio * pipe.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
            if vid2vid_images is not None:

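The one-line change above passes the sampling dtype to torch.autocast explicitly, so the autocast region runs in the model's dtype (for example fp16 or bf16) rather than the device's default autocast dtype. A simplified sketch of the pattern; the real condition also checks the onediff flag, and the device/dtype values here are illustrative.

from contextlib import nullcontext

import torch

dtype = torch.bfloat16  # illustrative; the sampler derives this from the loaded model
use_autocast = dtype != torch.float32

autocast_context = torch.autocast("cuda", dtype=dtype) if use_autocast else nullcontext()
with autocast_context:
    pass  # run the pipeline call here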
@@ -472,8 +472,15 @@ class CogVideoXPipeline(VideoSysPipeline, CogVideoXLoraLoaderMixin):
        # 5. Prepare latents.
        latent_channels = self.vae.config.latent_channels
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        patch_size_t = getattr(self.transformer.config, "patch_size_t", None)
        if patch_size_t is None:
            self.transformer.config.patch_size_t = None
        ofs_embed_dim = getattr(self.transformer.config, "ofs_embed_dim", None)
        if ofs_embed_dim is None:
            self.transformer.config.ofs_embed_dim = None

        self.additional_frames = 0
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            self.additional_frames = patch_size_t - latent_frames % patch_size_t

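The new code backfills patch_size_t and ofs_embed_dim with getattr defaults so configs from older CogVideoX checkpoints, which lack these fields, keep working, and then pads the latent frame count up to a multiple of patch_size_t for models that patch over time. A small worked example of the padding rule, with the helper name and the example values chosen here for illustration:

def compute_additional_frames(latent_frames, patch_size_t):
    # number of extra latent frames needed so temporal patching divides evenly
    if patch_size_t is None or latent_frames % patch_size_t == 0:
        return 0
    return patch_size_t - latent_frames % patch_size_t

assert compute_additional_frames(13, 2) == 1    # 13 latent frames, patch_size_t = 2 -> pad by 1
assert compute_additional_frames(13, None) == 0  # older checkpoints: no temporal patching, no padding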