Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-08 20:34:23 +08:00)
Commit 634c22db50 (parent 9a797229f2): sageattn

This commit adds optional SageAttention support. A new "attention_mode" input ("sdpa" or "sageattn", default "sdpa") is exposed on the DownloadAndLoadCogVideoModel node and threaded through CogVideoXTransformer3DModel and CogVideoXBlock down to CogVideoXAttnProcessor2_0, which calls sageattn when requested and falls back to F.scaled_dot_product_attention otherwise (raising ImportError if "sageattn" is selected but the package is missing). It also moves the apply_rotary_emb import to module level, drops the import-time log messages, comments out the inf/nan debug checks, and removes a stray debug print from the fp8 linear conversion.
@@ -32,6 +32,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.embeddings import apply_rotary_emb
 from .embeddings import CogVideoXPatchEmbed
 
 
@@ -40,9 +41,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 try:
     from sageattention import sageattn
     SAGEATTN_IS_AVAILABLE = True
-    logger.info("Using sageattn")
 except:
-    logger.info("sageattn not found, using sdpa")
     SAGEATTN_IS_AVAILABLE = False
 
 def fft(tensor):
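Because this hunk drops the import-time log messages, there is no longer a console hint about which backend was found; availability only surfaces later, when "sageattn" is actually requested. A quick standalone probe (illustrative only, not part of the commit) mirrors the guarded import above:

# Check whether the optional sageattention backend is importable in this environment.
try:
    from sageattention import sageattn  # noqa: F401
    print("sageattention is available; the 'sageattn' attention_mode can be used")
except ImportError:
    print("sageattention is not installed; only the default 'sdpa' mode will work")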
@@ -73,7 +72,6 @@ class CogVideoXAttnProcessor2_0:
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     @torch.compiler.disable()
-
     def __call__(
         self,
         attn: Attention,
@@ -81,6 +79,7 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mode: Optional[str] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -112,20 +111,21 @@ class CogVideoXAttnProcessor2_0:
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            from diffusers.models.embeddings import apply_rotary_emb
-
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
-        #if SAGEATTN_IS_AVAILABLE:
-        #    hidden_states = sageattn(query, key, value, is_causal=False)
-        #else:
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        if torch.isinf(hidden_states).any():
-            raise ValueError(f"hidden_states after dot product has inf")
+        if attention_mode == "sageattn":
+            if SAGEATTN_IS_AVAILABLE:
+                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
+            else:
+                raise ImportError("sageattn not found")
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        #if torch.isinf(hidden_states).any():
+        #    raise ValueError(f"hidden_states after dot product has inf")
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
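The hunk above is the core of the change: instead of the previous hard-coded F.scaled_dot_product_attention call, the processor now dispatches on attention_mode, passing the same query/key/value/attn_mask arguments to sageattn when it is requested and available. A minimal, self-contained sketch of that dispatch (the function name and tensor sizes below are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False

def dispatch_attention(query, key, value, attention_mask=None, attention_mode="sdpa"):
    # query/key/value are laid out as (batch, heads, seq_len, head_dim),
    # matching what the processor has built by this point in the diff.
    if attention_mode == "sageattn":
        if not SAGEATTN_IS_AVAILABLE:
            raise ImportError("sageattn not found")
        return sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

q = k = v = torch.randn(1, 8, 128, 64)    # illustrative sizes
print(dispatch_attention(q, k, v).shape)  # torch.Size([1, 8, 128, 64])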
@@ -193,6 +193,7 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
 
@@ -224,6 +225,7 @@ class CogVideoXBlock(nn.Module):
         )
         self.cached_hidden_states = []
         self.cached_encoder_hidden_states = []
+        self.attention_mode = attention_mode
 
     def forward(
         self,
@@ -235,7 +237,7 @@ class CogVideoXBlock(nn.Module):
         fuser=None,
         fastercache_counter=0,
         fastercache_start_step=15,
-        fastercache_device="cuda:0"
+        fastercache_device="cuda:0",
     ) -> torch.Tensor:
 
         text_seq_length = encoder_hidden_states.size(1)
@@ -271,7 +273,8 @@ class CogVideoXBlock(nn.Module):
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb
+            image_rotary_emb=image_rotary_emb,
+            attention_mode=self.attention_mode,
         )
         if fastercache_counter == fastercache_start_step:
             self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
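Taken together, the CogVideoXBlock hunks only store the new option at construction time and forward it into the attention call; no attention logic lives in the block itself. A toy analogue of that plumbing (the class names below are stand-ins, not the wrapper's real classes):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAttnProcessor:
    def __call__(self, q, k, v, attention_mode="sdpa"):
        # The real processor dispatches to sageattn here when attention_mode == "sageattn".
        return F.scaled_dot_product_attention(q, k, v)

class ToyBlock(nn.Module):
    def __init__(self, attention_mode="sdpa"):
        super().__init__()
        # mirrors: self.attention_mode = attention_mode in CogVideoXBlock.__init__
        self.attention_mode = attention_mode
        self.processor = ToyAttnProcessor()

    def forward(self, q, k, v):
        # mirrors: self.attn1(..., attention_mode=self.attention_mode)
        return self.processor(q, k, v, attention_mode=self.attention_mode)

block = ToyBlock(attention_mode="sdpa")
x = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
print(block(x, x, x).shape)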
@@ -386,6 +389,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -471,6 +475,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_lf_step = 40
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
+        self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
@@ -667,9 +672,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     fastercache_counter = self.fastercache_counter,
                     fastercache_device = self.fastercache_device
                 )
-                has_nan = torch.isnan(hidden_states).any()
-                if has_nan:
-                    raise ValueError(f"block output hidden_states has nan: {has_nan}")
+                #has_nan = torch.isnan(hidden_states).any()
+                #if has_nan:
+                #    raise ValueError(f"block output hidden_states has nan: {has_nan}")
 
                 if (controlnet_states is not None) and (i < len(controlnet_states)):
                     controlnet_states_block = controlnet_states[i]
[next file in this commit]
@@ -39,11 +39,9 @@ def fp8_linear_forward(cls, original_dtype, input):
 def convert_fp8_linear(module, original_dtype, params_to_keep={}):
     setattr(module, "fp8_matmul_enabled", True)
-
 
     for name, module in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             if isinstance(module, nn.Linear):
-                print(name)
                 original_forward = module.forward
                 setattr(module, "original_forward", original_forward)
                 setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
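The only functional change in this hunk is dropping a leftover debug print (plus a blank line). The surrounding pattern is worth a note: forward is rebound with lambda input, m=module: ..., where the m=module default argument freezes the module that is current in that loop iteration; a plain closure over module would bind late and leave every patched layer calling fp8_linear_forward on the last module visited. A small demonstration of that pitfall, independent of the wrapper:

import torch.nn as nn

layers = [nn.Linear(2, 2) for _ in range(3)]

# Late binding: every lambda reads `m` when called, after the loop has finished.
late = [lambda: m for m in layers]
# Default-argument binding (the pattern used in convert_fp8_linear): each lambda
# captures the layer that was current when it was defined.
bound = [lambda m=m: m for m in layers]

print(all(f() is layers[-1] for f in late))                # True: all point at the last layer
print(all(g() is layers[i] for i, g in enumerate(bound)))  # True: each points at its own layer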
[next file in this commit]
@@ -60,6 +60,7 @@ class CogVideoLoraSelect:
         cog_loras_list.append(cog_lora)
         print(cog_loras_list)
         return (cog_loras_list,)
+
 #region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
@@ -98,6 +99,7 @@ class DownloadAndLoadCogVideoModel:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
+                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
@@ -108,9 +110,9 @@ class DownloadAndLoadCogVideoModel:
     CATEGORY = "CogVideoWrapper"
     DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
 
-    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None, load_device="main_device"):
-        check_diffusers_version()
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
+                  enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
+                  attention_mode="sdpa", load_device="main_device"):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -195,6 +197,8 @@ class DownloadAndLoadCogVideoModel:
 
         transformer = transformer.to(dtype).to(transformer_load_device)
 
+        transformer.attention_mode = attention_mode
+
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
 