This commit is contained in:
kijai 2024-11-09 17:05:55 +02:00
parent 9a797229f2
commit 634c22db50
3 changed files with 32 additions and 25 deletions

View File

@ -32,6 +32,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.embeddings import apply_rotary_emb
from .embeddings import CogVideoXPatchEmbed
@ -40,9 +41,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from sageattention import sageattn
SAGEATTN_IS_AVAILABLE = True
logger.info("Using sageattn")
except:
logger.info("sageattn not found, using sdpa")
SAGEATTN_IS_AVAILABLE = False
def fft(tensor):
@ -73,7 +72,6 @@ class CogVideoXAttnProcessor2_0:
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@torch.compiler.disable()
def __call__(
self,
attn: Attention,
@ -81,6 +79,7 @@ class CogVideoXAttnProcessor2_0:
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mode: Optional[str] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@ -112,20 +111,21 @@ class CogVideoXAttnProcessor2_0:
# Apply RoPE if needed
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
#if SAGEATTN_IS_AVAILABLE:
# hidden_states = sageattn(query, key, value, is_causal=False)
#else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
if torch.isinf(hidden_states).any():
raise ValueError(f"hidden_states after dot product has inf")
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if attention_mode == "sageattn":
if SAGEATTN_IS_AVAILABLE:
hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
else:
raise ImportError("sageattn not found")
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
#if torch.isinf(hidden_states).any():
# raise ValueError(f"hidden_states after dot product has inf")
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@ -193,6 +193,7 @@ class CogVideoXBlock(nn.Module):
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
attention_mode: Optional[str] = None,
):
super().__init__()
@ -224,6 +225,7 @@ class CogVideoXBlock(nn.Module):
)
self.cached_hidden_states = []
self.cached_encoder_hidden_states = []
self.attention_mode = attention_mode
def forward(
self,
@ -235,7 +237,7 @@ class CogVideoXBlock(nn.Module):
fuser=None,
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0"
fastercache_device="cuda:0",
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@ -271,7 +273,8 @@ class CogVideoXBlock(nn.Module):
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb
image_rotary_emb=image_rotary_emb,
attention_mode=self.attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@ -386,6 +389,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
attention_mode: Optional[str] = None,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -471,6 +475,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
self.attention_mode = attention_mode
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@ -667,9 +672,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
fastercache_counter = self.fastercache_counter,
fastercache_device = self.fastercache_device
)
has_nan = torch.isnan(hidden_states).any()
if has_nan:
raise ValueError(f"block output hidden_states has nan: {has_nan}")
#has_nan = torch.isnan(hidden_states).any()
#if has_nan:
# raise ValueError(f"block output hidden_states has nan: {has_nan}")
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]

View File

@ -39,11 +39,9 @@ def fp8_linear_forward(cls, original_dtype, input):
def convert_fp8_linear(module, original_dtype, params_to_keep={}):
setattr(module, "fp8_matmul_enabled", True)
for name, module in module.named_modules():
if not any(keyword in name for keyword in params_to_keep):
if isinstance(module, nn.Linear):
print(name)
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -60,6 +60,7 @@ class CogVideoLoraSelect:
cog_loras_list.append(cog_lora)
print(cog_loras_list)
return (cog_loras_list,)
#region DownloadAndLoadCogVideoModel
class DownloadAndLoadCogVideoModel:
@classmethod
@ -98,6 +99,7 @@ class DownloadAndLoadCogVideoModel:
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}
}
@ -108,9 +110,9 @@ class DownloadAndLoadCogVideoModel:
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None, load_device="main_device"):
check_diffusers_version()
def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -195,6 +197,8 @@ class DownloadAndLoadCogVideoModel:
transformer = transformer.to(dtype).to(transformer_load_device)
transformer.attention_mode = attention_mode
if block_edit is not None:
transformer = remove_specific_blocks(transformer, block_edit)