Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
expose sageattn 2.0.0 functions
The _cuda versions seem to be required on RTX 30xx-series GPUs for sageattn + CogVideoX 1.5.
This commit is contained in: parent 411791c748 · commit 729a6485ea
@@ -46,9 +46,40 @@ except:
 from comfy.ldm.modules.attention import optimized_attention
 
-@torch.compiler.disable()
-def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
-    return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
+def set_attention_func(attention_mode, heads):
+    if attention_mode == "sdpa" or attention_mode == "fused_sdpa":
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
+        return func
+    elif attention_mode == "comfy":
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True)
+        return func
+
+    elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+        return func
+    elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
+        from sageattention import sageattn_qk_int8_pv_fp16_cuda
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+        return func
+    elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
+        from sageattention import sageattn_qk_int8_pv_fp16_triton
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+        return func
+    elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
+        from sageattention import sageattn_qk_int8_pv_fp8_cuda
+        @torch.compiler.disable()
+        def func(q, k, v, is_causal=False, attn_mask=None):
+            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
+        return func
 
 def fft(tensor):
     tensor_fft = torch.fft.fft2(tensor)
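Note on the new helper above: instead of a single sageattn_func plus per-call branching inside the processor, set_attention_func resolves the requested backend once and returns a closure with a uniform (q, k, v, attn_mask, is_causal) signature. A minimal sketch of the same dispatch pattern, using only PyTorch's built-in SDPA so it runs without SageAttention installed (pick_attention is an illustrative name, not part of the repo):

import torch
import torch.nn.functional as F

def pick_attention(mode: str):
    # Resolve the backend once; callers never branch on the mode again.
    if mode in ("sdpa", "fused_sdpa"):
        def func(q, k, v, is_causal=False, attn_mask=None):
            return F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal
            )
        return func
    raise ValueError(f"unsupported attention mode: {mode}")

attn = pick_attention("sdpa")              # resolved once, e.g. at block construction time
q = k = v = torch.randn(1, 8, 16, 64)      # (batch, heads, seq_len, head_dim)
print(attn(q, k, v).shape)                 # torch.Size([1, 8, 16, 64])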
@@ -67,16 +98,18 @@ def fft(tensor):
     return low_freq_fft, high_freq_fft
 
+#region Attention
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
     """
 
-    def __init__(self):
+    def __init__(self, attn_func, attention_mode: Optional[str] = None):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        self.attention_mode = attention_mode
+        self.attn_func = attn_func
 
     def __call__(
         self,
         attn: Attention,
@@ -84,7 +117,6 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        attention_mode: Optional[str] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -101,7 +133,7 @@ class CogVideoXAttnProcessor2_0:
         if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
             hidden_states = hidden_states.to(attn.to_q.weight.dtype)
 
-        if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
+        if not "fused" in self.attention_mode:
             query = attn.to_q(hidden_states)
             key = attn.to_k(hidden_states)
             value = attn.to_v(hidden_states)
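Worth spelling out why the guard was rewritten: the removed condition, attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn", is always true because no string can equal both literals at once, so the separate q/k/v projections ran even in fused modes. The new substring test actually distinguishes the two paths; a two-line check makes the slip visible:

mode = "fused_sdpa"
print(mode != "fused_sdpa" or mode != "fused_sageattn")  # True: the old guard could never be False
print("fused" not in mode)                               # False: fused modes now skip the separate projections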
@@ -128,16 +160,10 @@ class CogVideoXAttnProcessor2_0:
         if not attn.is_cross_attention:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
-        if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
-            hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
-            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        elif attention_mode == "sdpa" or attention_mode == "fused_sdpa":
-            hidden_states = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            )
-            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        elif attention_mode == "comfy":
-            hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
+        hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
+        if self.attention_mode != "comfy":
+            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
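With the change above, the processor always calls the stored backend and only merges heads itself when the result is in (batch, heads, seq_len, head_dim) layout; ComfyUI's optimized_attention with skip_reshape=True already returns (batch, seq_len, heads * head_dim), which is why the "comfy" mode skips the transpose/reshape. A stripped-down stand-in for that shape handling (MiniProcessor is hypothetical, not the repo class):

import torch
import torch.nn.functional as F

class MiniProcessor:
    # Mirrors the new pattern: the resolved attention callable and the mode
    # string travel with the processor instead of being passed to every call.
    def __init__(self, attn_func, attention_mode=None):
        self.attn_func = attn_func
        self.attention_mode = attention_mode

    def __call__(self, q, k, v, attn_mask=None):
        out = self.attn_func(q, k, v, attn_mask=attn_mask, is_causal=False)
        if self.attention_mode != "comfy":
            # SDPA/SageAttention return (batch, heads, seq_len, head_dim); merge the heads here.
            b, h, s, d = out.shape
            out = out.transpose(1, 2).reshape(b, s, h * d)
        return out  # (batch, seq_len, heads * head_dim), the layout attn.to_out expects

def sdpa(q, k, v, attn_mask=None, is_causal=False):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)

proc = MiniProcessor(sdpa, attention_mode="sdpa")
q = k = v = torch.randn(2, 48, 226, 64)    # CogVideoX-like: 48 heads x 64 dims = 3072 channels
print(proc(q, k, v).shape)                 # torch.Size([2, 226, 3072])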
@@ -203,12 +229,14 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        attention_mode: Optional[str] = "sdpa",
     ):
         super().__init__()
 
         # 1. Self Attention
         self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
 
+        attn_func = set_attention_func(attention_mode, num_attention_heads)
 
         self.attn1 = Attention(
             query_dim=dim,
@@ -218,7 +246,7 @@ class CogVideoXBlock(nn.Module):
             eps=1e-6,
             bias=attention_bias,
             out_bias=attention_out_bias,
-            processor=CogVideoXAttnProcessor2_0(),
+            processor=CogVideoXAttnProcessor2_0(attn_func, attention_mode=attention_mode),
         )
 
         # 2. Feed Forward
@@ -247,7 +275,6 @@ class CogVideoXBlock(nn.Module):
         fastercache_counter=0,
         fastercache_start_step=15,
         fastercache_device="cuda:0",
-        attention_mode="sdpa",
     ) -> torch.Tensor:
         #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
@@ -286,7 +313,6 @@ class CogVideoXBlock(nn.Module):
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
                 image_rotary_emb=image_rotary_emb,
-                attention_mode=attention_mode,
             )
             if fastercache_counter == fastercache_start_step:
                 self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -298,8 +324,7 @@ class CogVideoXBlock(nn.Module):
             attn_hidden_states, attn_encoder_hidden_states = self.attn1(
                 hidden_states=norm_hidden_states,
                 encoder_hidden_states=norm_encoder_hidden_states,
-                image_rotary_emb=image_rotary_emb,
-                attention_mode=attention_mode,
+                image_rotary_emb=image_rotary_emb
             )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -408,6 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
+        attention_mode: Optional[str] = "sdpa",
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -461,6 +487,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 dropout=dropout,
                 activation_fn=activation_fn,
                 attention_bias=attention_bias,
+                attention_mode=attention_mode,
                 norm_elementwise_affine=norm_elementwise_affine,
                 norm_eps=norm_eps,
             )
@@ -496,73 +523,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
         self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
-        self.attention_mode = "sdpa"
+        self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
+    #region forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -624,8 +590,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device,
-                    attention_mode = self.attention_mode
+                    fastercache_device = self.fastercache_device
                 )
 
             if (controlnet_states is not None) and (i < len(controlnet_states)):
@@ -695,8 +660,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
                     fastercache_counter = self.fastercache_counter,
                     fastercache_start_step = self.fastercache_start_step,
-                    fastercache_device = self.fastercache_device,
-                    attention_mode = self.attention_mode
+                    fastercache_device = self.fastercache_device
                 )
                 #has_nan = torch.isnan(hidden_states).any()
                 #if has_nan:
@@ -70,6 +70,7 @@ class CogVideoLoraSelect:
     RETURN_NAMES = ("lora", )
     FUNCTION = "getlorapath"
    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Select a LoRA model from ComfyUI/models/CogVideo/loras"
 
     def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
         cog_loras_list = []
@@ -93,7 +94,7 @@ class CogVideoLoraSelectComfy:
         return {
             "required": {
                 "lora": (folder_paths.get_filename_list("loras"),
-                {"tooltip": "LORA models are expected to be in ComfyUI/models/CogVideo/loras with .safetensors extension"}),
+                {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
                 "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
             },
             "optional": {
@@ -106,6 +107,7 @@ class CogVideoLoraSelectComfy:
     RETURN_NAMES = ("lora", )
     FUNCTION = "getlorapath"
     CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras"
 
     def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=False):
         cog_loras_list = []
@@ -160,7 +162,19 @@ class DownloadAndLoadCogVideoModel:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
-                "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
+                "attention_mode": ([
+                    "sdpa",
+                    "fused_sdpa",
+                    "sageattn",
+                    "fused_sageattn",
+                    "sageattn_qk_int8_pv_fp8_cuda",
+                    "sageattn_qk_int8_pv_fp16_cuda",
+                    "sageattn_qk_int8_pv_fp16_triton",
+                    "fused_sageattn_qk_int8_pv_fp8_cuda",
+                    "fused_sageattn_qk_int8_pv_fp16_cuda",
+                    "fused_sageattn_qk_int8_pv_fp16_triton",
+                    "comfy"
+                    ], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
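The dropdown now exposes each SageAttention 2.0.0 kernel by name rather than hiding the choice behind a single "sageattn" entry; per the commit message, the _cuda variants are what worked on RTX 30xx cards with CogVideoX 1.5. A rough, illustrative default-picker by GPU generation (the node leaves the choice to the user; the capability thresholds below are assumptions, not from the repo):

import torch

def suggest_sage_mode() -> str:
    # Compute capability (8, 0)+ is Ampere (RTX 30xx / A100); (8, 9)+ is Ada (RTX 40xx).
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (8, 9):
        return "sageattn_qk_int8_pv_fp8_cuda"    # fp8 PV kernels want Ada or newer
    if (major, minor) >= (8, 0):
        return "sageattn_qk_int8_pv_fp16_cuda"   # the variant reported to work on 30xx cards
    return "sdpa"                                # older GPUs: fall back to plain SDPA

if torch.cuda.is_available():
    print(suggest_sage_mode())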
@@ -175,11 +189,18 @@ class DownloadAndLoadCogVideoModel:
                   enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
                   attention_mode="sdpa", load_device="main_device"):
 
+        transformer = None
+
         if "sage" in attention_mode:
             try:
                 from sageattention import sageattn
             except Exception as e:
                 raise ValueError(f"Can't import SageAttention: {str(e)}")
+        if "qk_int8" in attention_mode:
+            try:
+                from sageattention import sageattn_qk_int8_pv_fp16_cuda
+            except Exception as e:
+                raise ValueError(f"Can't import SageAttention 2.0.0: {str(e)}")
 
         if precision == "fp16" and "1.5" in model:
             raise ValueError("1.5 models do not currently work in fp16")
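The second try/except above gives an early, readable error when a qk_int8 mode is selected but only SageAttention 1.x is installed: importing one of the per-kernel entry points is a cheap way to detect 2.0.0, since 1.x only exported the top-level sageattn. A self-contained sketch of the same probe (the function name and wording are illustrative):

def check_sageattention(attention_mode: str) -> None:
    # Mirrors the loader's import checks; fail early instead of mid-inference.
    if "sage" not in attention_mode:
        return
    try:
        from sageattention import sageattn                           # present in 1.x and 2.x
    except Exception as e:
        raise ValueError(f"Can't import SageAttention: {e}")
    if "qk_int8" in attention_mode:
        try:
            from sageattention import sageattn_qk_int8_pv_fp16_cuda  # exposed from 2.0.0 onwards
        except Exception as e:
            raise ValueError(f"Can't import SageAttention 2.0.0: {e}")

check_sageattention("sdpa")   # no-op when SageAttention is not requested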
@@ -254,7 +275,7 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )
 
-        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
+        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
         transformer = transformer.to(dtype).to(transformer_load_device)
 
         if "1.5" in model:
@@ -327,7 +348,6 @@ class DownloadAndLoadCogVideoModel:
             for module in pipe.transformer.modules():
                 if isinstance(module, Attention):
                     module.fuse_projections(fuse=True)
-        pipe.transformer.attention_mode = attention_mode
 
         if compile_args is not None:
             pipe.transformer.to(memory_format=torch.channels_last)
@@ -551,7 +571,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         else:
             transformer_config["in_channels"] = 16
 
-        transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+        transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
         cast_dtype = vae_dtype
         params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
         if "2b" in model:
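Passing attention_mode straight through from_pretrained / from_config works because the transformer is a diffusers ConfigMixin: keyword arguments that match registered __init__ parameters override the stored config. A toy model showing that flow (ToyModel is hypothetical; this assumes the wrapper keeps the usual @register_to_config decorator on __init__ and that diffusers is installed):

import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class ToyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_size: int = 8, attention_mode: str = "sdpa"):
        super().__init__()
        self.attention_mode = attention_mode
        self.proj = nn.Linear(hidden_size, hidden_size)

# the kwarg overrides the config default, just as the loaders do for the real transformer
model = ToyModel.from_config({"hidden_size": 8}, attention_mode="sageattn")
print(model.attention_mode)   # sageattn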
@@ -655,7 +675,19 @@ class CogVideoXModelLoader:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
-                "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
+                "attention_mode": ([
+                    "sdpa",
+                    "fused_sdpa",
+                    "sageattn",
+                    "fused_sageattn",
+                    "sageattn_qk_int8_pv_fp8_cuda",
+                    "sageattn_qk_int8_pv_fp16_cuda",
+                    "sageattn_qk_int8_pv_fp16_triton",
+                    "fused_sageattn_qk_int8_pv_fp8_cuda",
+                    "fused_sageattn_qk_int8_pv_fp16_cuda",
+                    "fused_sageattn_qk_int8_pv_fp16_triton",
+                    "comfy"
+                    ], {"default": "sdpa"}),
             }
         }
 
@@ -666,7 +698,7 @@ class CogVideoXModelLoader:
 
     def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
                   block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
-
+        transformer = None
         if "sage" in attention_mode:
             try:
                 from sageattention import sageattn
@@ -732,7 +764,7 @@ class CogVideoXModelLoader:
         transformer_config["sample_width"] = 300
 
         with init_empty_weights():
-            transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+            transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
 
         #load weights
         #params_to_keep = {}
@@ -1084,7 +1116,7 @@ NODE_CLASS_MAPPINGS = {
     "CogVideoLoraSelect": CogVideoLoraSelect,
     "CogVideoXVAELoader": CogVideoXVAELoader,
     "CogVideoXModelLoader": CogVideoXModelLoader,
-    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy,
+    "CogVideoLoraSelectComfy": CogVideoLoraSelectComfy
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
@@ -1094,5 +1126,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoLoraSelect": "CogVideo LoraSelect",
     "CogVideoXVAELoader": "CogVideoX VAE Loader",
     "CogVideoXModelLoader": "CogVideoX Model Loader",
-    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy",
+    "CogVideoLoraSelectComfy": "CogVideo LoraSelect Comfy"
 }