kijai 2024-11-09 17:05:55 +02:00
parent 9a797229f2
commit 634c22db50
3 changed files with 32 additions and 25 deletions

View File

@@ -32,6 +32,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.embeddings import apply_rotary_emb
 from .embeddings import CogVideoXPatchEmbed
@@ -40,9 +41,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 try:
     from sageattention import sageattn
     SAGEATTN_IS_AVAILABLE = True
-    logger.info("Using sageattn")
 except:
-    logger.info("sageattn not found, using sdpa")
     SAGEATTN_IS_AVAILABLE = False
 def fft(tensor):
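Note: the try/except above is the usual optional-dependency probe, i.e. import sageattention if it is installed and record the result in a flag; this commit drops the two logger.info calls and keeps the flag. A minimal self-contained sketch of the same probe (the logging line here is illustrative, not part of the commit):

```python
import logging
import torch.nn.functional as F  # SDPA fallback lives here

logger = logging.getLogger(__name__)

try:
    from sageattention import sageattn  # optional fused attention kernel
    SAGEATTN_IS_AVAILABLE = True
except ImportError:  # narrower than the bare `except:` in the file
    sageattn = None
    SAGEATTN_IS_AVAILABLE = False
    logger.debug("sageattention not installed, will use F.scaled_dot_product_attention")
```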
@@ -73,7 +72,6 @@ class CogVideoXAttnProcessor2_0:
         raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
     @torch.compiler.disable()
     def __call__(
         self,
         attn: Attention,
@@ -81,6 +79,7 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mode: Optional[str] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
@@ -112,20 +111,21 @@ class CogVideoXAttnProcessor2_0:
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            from diffusers.models.embeddings import apply_rotary_emb
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-        #if SAGEATTN_IS_AVAILABLE:
-        #    hidden_states = sageattn(query, key, value, is_causal=False)
-        #else:
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        if torch.isinf(hidden_states).any():
-            raise ValueError(f"hidden_states after dot product has inf")
+        if attention_mode == "sageattn":
+            if SAGEATTN_IS_AVAILABLE:
+                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+            else:
+                raise ImportError("sageattn not found")
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        #if torch.isinf(hidden_states).any():
+        #    raise ValueError(f"hidden_states after dot product has inf")
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
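Note: the branch above is the heart of the change. Instead of always calling F.scaled_dot_product_attention, the processor now dispatches on the new attention_mode argument and raises if "sageattn" is requested without sageattention installed; the old inf check is commented out rather than deleted. A stripped-down, self-contained sketch of that dispatch (the sageattn keyword arguments are copied from the hunk and may need adjusting to the installed sageattention version):

```python
import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False


def dispatch_attention(query, key, value, attention_mask=None, attention_mode=None):
    """Standalone version of the branch added above."""
    if attention_mode == "sageattn":
        if not SAGEATTN_IS_AVAILABLE:
            raise ImportError("sageattn not found")
        # kwargs mirror the diff; verify against your sageattention release
        return sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
    # default ("sdpa" or None): PyTorch 2.x fused attention
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )


# (batch, heads, seq_len, head_dim) tensors, as produced earlier in __call__
q = k = v = torch.randn(1, 8, 32, 64)
out = dispatch_attention(q, k, v, attention_mode="sdpa")
```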
@@ -193,6 +193,7 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
@@ -224,6 +225,7 @@ class CogVideoXBlock(nn.Module):
         )
         self.cached_hidden_states = []
         self.cached_encoder_hidden_states = []
+        self.attention_mode = attention_mode
     def forward(
         self,
@@ -235,7 +237,7 @@ class CogVideoXBlock(nn.Module):
         fuser=None,
         fastercache_counter=0,
         fastercache_start_step=15,
-        fastercache_device="cuda:0"
+        fastercache_device="cuda:0",
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
@@ -271,7 +273,8 @@ class CogVideoXBlock(nn.Module):
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb
+            image_rotary_emb=image_rotary_emb,
+            attention_mode=self.attention_mode,
         )
         if fastercache_counter == fastercache_start_step:
             self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
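Note: passing attention_mode=self.attention_mode into self.attn1(...) reaches the processor because diffusers' Attention module forwards extra keyword arguments (cross_attention_kwargs) on to its processor's __call__, which is where the attention_mode parameter was added in the first hunks. A toy illustration of that plumbing, independent of diffusers (all class names below are made up for the example):

```python
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyProcessor:
    def __call__(self, q, k, v, attention_mode: Optional[str] = None):
        # a real processor would branch to sageattn here when attention_mode == "sageattn"
        return F.scaled_dot_product_attention(q, k, v)


class ToyBlock(nn.Module):
    def __init__(self, attention_mode: Optional[str] = None):
        super().__init__()
        self.attention_mode = attention_mode  # stored once at construction, as in the diff
        self.processor = ToyProcessor()

    def forward(self, q, k, v):
        # forwarded on every call, mirroring attention_mode=self.attention_mode above
        return self.processor(q, k, v, attention_mode=self.attention_mode)


block = ToyBlock(attention_mode="sdpa")
q = k = v = torch.randn(1, 8, 16, 64)
print(block(q, k, v).shape)  # torch.Size([1, 8, 16, 64])
```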
@@ -386,6 +389,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -471,6 +475,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_lf_step = 40
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
+        self.attention_mode = attention_mode
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
@@ -667,9 +672,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 fastercache_counter = self.fastercache_counter,
                 fastercache_device = self.fastercache_device
             )
-            has_nan = torch.isnan(hidden_states).any()
-            if has_nan:
-                raise ValueError(f"block output hidden_states has nan: {has_nan}")
+            #has_nan = torch.isnan(hidden_states).any()
+            #if has_nan:
+            #    raise ValueError(f"block output hidden_states has nan: {has_nan}")
             if (controlnet_states is not None) and (i < len(controlnet_states)):
                 controlnet_states_block = controlnet_states[i]

View File

@@ -39,11 +39,9 @@ def fp8_linear_forward(cls, original_dtype, input):
 def convert_fp8_linear(module, original_dtype, params_to_keep={}):
     setattr(module, "fp8_matmul_enabled", True)
     for name, module in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             if isinstance(module, nn.Linear):
-                print(name)
                 original_forward = module.forward
                 setattr(module, "original_forward", original_forward)
                 setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@@ -60,6 +60,7 @@ class CogVideoLoraSelect:
         cog_loras_list.append(cog_lora)
         print(cog_loras_list)
         return (cog_loras_list,)
 #region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
@@ -98,6 +99,7 @@ class DownloadAndLoadCogVideoModel:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
+                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
@@ -108,9 +110,9 @@ class DownloadAndLoadCogVideoModel:
     CATEGORY = "CogVideoWrapper"
     DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
-    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None, load_device="main_device"):
-        check_diffusers_version()
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
+                  enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
+                  attention_mode="sdpa", load_device="main_device"):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -195,6 +197,8 @@ class DownloadAndLoadCogVideoModel:
         transformer = transformer.to(dtype).to(transformer_load_device)
+        transformer.attention_mode = attention_mode
         if block_edit is not None:
             transformer = remove_specific_blocks(transformer, block_edit)
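Note: on the node side the change is small: attention_mode becomes an optional dropdown defaulting to "sdpa", loadmodel accepts it, and the value is assigned to the transformer, matching the attention_mode plumbing added to the transformer, blocks, and processor in the first file. A heavily stripped-down sketch of that wiring in ComfyUI node style (everything here except attention_mode is a placeholder, not the wrapper's real loader):

```python
import types


def load_transformer_somehow(model_id):
    """Placeholder for the real download/instantiation path in the node."""
    return types.SimpleNamespace(model_id=model_id, attention_mode=None)


class MinimalCogVideoLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"model": ("STRING", {"default": "THUDM/CogVideoX-5b"})},
            "optional": {
                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
            },
        }

    RETURN_TYPES = ("COGVIDEOMODEL",)
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"

    def loadmodel(self, model, attention_mode="sdpa"):
        transformer = load_transformer_somehow(model)
        # same idea as `transformer.attention_mode = attention_mode` above
        transformer.attention_mode = attention_mode
        return ({"transformer": transformer},)
```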