mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
sageattn
This commit is contained in:
parent
9a797229f2
commit
634c22db50
@ -32,6 +32,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
|
||||
from diffusers.loaders import PeftAdapterMixin
|
||||
from diffusers.models.embeddings import apply_rotary_emb
|
||||
from .embeddings import CogVideoXPatchEmbed
|
||||
|
||||
|
||||
@ -40,9 +41,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGEATTN_IS_AVAILABLE = True
|
||||
logger.info("Using sageattn")
|
||||
except:
|
||||
logger.info("sageattn not found, using sdpa")
|
||||
SAGEATTN_IS_AVAILABLE = False
|
||||
|
||||
def fft(tensor):
|
||||
@ -73,7 +72,6 @@ class CogVideoXAttnProcessor2_0:
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
@torch.compiler.disable()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
@ -81,6 +79,7 @@ class CogVideoXAttnProcessor2_0:
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
attention_mode: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
@ -112,20 +111,21 @@ class CogVideoXAttnProcessor2_0:
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from diffusers.models.embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
#if SAGEATTN_IS_AVAILABLE:
|
||||
# hidden_states = sageattn(query, key, value, is_causal=False)
|
||||
#else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
if torch.isinf(hidden_states).any():
|
||||
raise ValueError(f"hidden_states after dot product has inf")
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
if attention_mode == "sageattn":
|
||||
if SAGEATTN_IS_AVAILABLE:
|
||||
hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||
else:
|
||||
raise ImportError("sageattn not found")
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
#if torch.isinf(hidden_states).any():
|
||||
# raise ValueError(f"hidden_states after dot product has inf")
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
@ -193,6 +193,7 @@ class CogVideoXBlock(nn.Module):
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
attention_mode: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -224,6 +225,7 @@ class CogVideoXBlock(nn.Module):
|
||||
)
|
||||
self.cached_hidden_states = []
|
||||
self.cached_encoder_hidden_states = []
|
||||
self.attention_mode = attention_mode
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -235,7 +237,7 @@ class CogVideoXBlock(nn.Module):
|
||||
fuser=None,
|
||||
fastercache_counter=0,
|
||||
fastercache_start_step=15,
|
||||
fastercache_device="cuda:0"
|
||||
fastercache_device="cuda:0",
|
||||
) -> torch.Tensor:
|
||||
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
@ -271,7 +273,8 @@ class CogVideoXBlock(nn.Module):
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mode=self.attention_mode,
|
||||
)
|
||||
if fastercache_counter == fastercache_start_step:
|
||||
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
|
||||
@ -386,6 +389,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
use_learned_positional_embeddings: bool = False,
|
||||
patch_bias: bool = True,
|
||||
attention_mode: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
@ -471,6 +475,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self.fastercache_lf_step = 40
|
||||
self.fastercache_hf_step = 30
|
||||
self.fastercache_device = "cuda"
|
||||
self.attention_mode = attention_mode
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
@ -667,9 +672,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
fastercache_counter = self.fastercache_counter,
|
||||
fastercache_device = self.fastercache_device
|
||||
)
|
||||
has_nan = torch.isnan(hidden_states).any()
|
||||
if has_nan:
|
||||
raise ValueError(f"block output hidden_states has nan: {has_nan}")
|
||||
#has_nan = torch.isnan(hidden_states).any()
|
||||
#if has_nan:
|
||||
# raise ValueError(f"block output hidden_states has nan: {has_nan}")
|
||||
|
||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||
controlnet_states_block = controlnet_states[i]
|
||||
|
||||
@ -39,11 +39,9 @@ def fp8_linear_forward(cls, original_dtype, input):
|
||||
def convert_fp8_linear(module, original_dtype, params_to_keep={}):
|
||||
setattr(module, "fp8_matmul_enabled", True)
|
||||
|
||||
|
||||
for name, module in module.named_modules():
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
if isinstance(module, nn.Linear):
|
||||
print(name)
|
||||
original_forward = module.forward
|
||||
setattr(module, "original_forward", original_forward)
|
||||
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
||||
|
||||
@ -60,6 +60,7 @@ class CogVideoLoraSelect:
|
||||
cog_loras_list.append(cog_lora)
|
||||
print(cog_loras_list)
|
||||
return (cog_loras_list,)
|
||||
|
||||
#region DownloadAndLoadCogVideoModel
|
||||
class DownloadAndLoadCogVideoModel:
|
||||
@classmethod
|
||||
@ -98,6 +99,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
||||
"lora": ("COGLORA", {"default": None}),
|
||||
"compile_args":("COMPILEARGS", ),
|
||||
"attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
|
||||
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
||||
}
|
||||
}
|
||||
@ -108,9 +110,9 @@ class DownloadAndLoadCogVideoModel:
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
|
||||
|
||||
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None, load_device="main_device"):
|
||||
|
||||
check_diffusers_version()
|
||||
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
|
||||
enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
|
||||
attention_mode="sdpa", load_device="main_device"):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -195,6 +197,8 @@ class DownloadAndLoadCogVideoModel:
|
||||
|
||||
transformer = transformer.to(dtype).to(transformer_load_device)
|
||||
|
||||
transformer.attention_mode = attention_mode
|
||||
|
||||
if block_edit is not None:
|
||||
transformer = remove_specific_blocks(transformer, block_edit)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user