sageattn 2.0.0 options

kijai 2024-11-27 01:16:22 +02:00
parent 1ade29084e
commit 2b211b9d1b
3 changed files with 113 additions and 97 deletions

View File

@@ -46,9 +46,40 @@ except:
from comfy.ldm.modules.attention import optimized_attention
@torch.compiler.disable()
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)

def set_attention_func(attention_mode, heads):
    if attention_mode == "sdpa" or attention_mode == "fused_sdpa":
        def func(q, k, v, is_causal=False, attn_mask=None):
            return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)
        return func
    elif attention_mode == "comfy":
        def func(q, k, v, is_causal=False, attn_mask=None):
            return optimized_attention(q, k, v, mask=attn_mask, heads=heads, skip_reshape=True)
        return func
    elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
        return func
    elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
        from sageattention import sageattn_qk_int8_pv_fp16_cuda
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
        return func
    elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
        from sageattention import sageattn_qk_int8_pv_fp16_triton
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
        return func
    elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
        from sageattention import sageattn_qk_int8_pv_fp8_cuda
        @torch.compiler.disable()
        def func(q, k, v, is_causal=False, attn_mask=None):
            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
        return func
def fft(tensor):
tensor_fft = torch.fft.fft2(tensor)
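Editorial note: every backend returned by set_attention_func is called the same way, func(q, k, v, is_causal=False, attn_mask=None). A minimal sketch of that contract using only the PyTorch sdpa path; the tensor sizes and the (batch, heads, seq_len, head_dim) layout are illustrative assumptions based on how the processor below reshapes its outputs.

import torch
import torch.nn.functional as F

# stand-in for the "sdpa" branch of set_attention_func
def sdpa_func(q, k, v, is_causal=False, attn_mask=None):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)

q = torch.randn(2, 48, 226, 64)   # (batch, heads, seq_len, head_dim) -- made-up sizes
k = torch.randn(2, 48, 226, 64)
v = torch.randn(2, 48, 226, 64)

out = sdpa_func(q, k, v)
print(out.shape)                  # torch.Size([2, 48, 226, 64]); sdpa/sageattn keep the head axis,
                                  # so the caller reshapes afterwards, while the comfy path reshapes internally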
@@ -67,16 +98,18 @@ def fft(tensor):
return low_freq_fft, high_freq_fft
#region Attention
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""
def __init__(self):
def __init__(self, attn_func, attention_mode: Optional[str] = None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.attention_mode = attention_mode
self.attn_func = attn_func
def __call__(
self,
attn: Attention,
@@ -84,7 +117,6 @@ class CogVideoXAttnProcessor2_0:
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mode: Optional[str] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
@@ -101,7 +133,7 @@
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
if "fused" not in self.attention_mode:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
@@ -128,16 +160,10 @@ class CogVideoXAttnProcessor2_0:
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
hidden_states = self.attn_func(query, key, value, attn_mask=attention_mask, is_causal=False)
if self.attention_mode != "comfy":
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif attention_mode == "sdpa" or attention_mode == "fused_sdpa":
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif attention_mode == "comfy":
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
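Editorial note: the new check on "fused" in self.attention_mode skips the separate to_q/to_k/to_v projections when the loader has fused them (see the fuse_projections call in the loader diff below); the fused branch, which is not visible in this hunk, presumably splits a single to_qkv output instead. A rough, self-contained sketch of that idea on made-up dimensions, not the wrapper's actual code:

import torch
import torch.nn as nn

dim = 3072                                    # illustrative hidden size
to_q, to_k, to_v = (nn.Linear(dim, dim) for _ in range(3))

# roughly what Attention.fuse_projections(fuse=True) does: one concatenated projection
to_qkv = nn.Linear(dim, 3 * dim)
to_qkv.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data], dim=0)
to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data], dim=0)

x = torch.randn(2, 226, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)          # fused path: one matmul, then split
assert torch.allclose(q, to_q(x), atol=1e-5)  # matches the unfused projection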
@@ -203,13 +229,15 @@ class CogVideoXBlock(nn.Module):
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
attention_mode: Optional[str] = "sdpa",
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
attn_func = set_attention_func(attention_mode, num_attention_heads)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
@@ -218,7 +246,7 @@
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
processor=CogVideoXAttnProcessor2_0(),
processor=CogVideoXAttnProcessor2_0(attn_func, attention_mode=attention_mode),
)
# 2. Feed Forward
@@ -247,7 +275,6 @@
fastercache_counter=0,
fastercache_start_step=15,
fastercache_device="cuda:0",
attention_mode="sdpa",
) -> torch.Tensor:
#print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) 10.: torch.Size([2, 6400, 3072])
text_seq_length = encoder_hidden_states.size(1)
@@ -286,7 +313,6 @@
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=attention_mode,
)
if fastercache_counter == fastercache_start_step:
self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -298,8 +324,7 @@
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mode=attention_mode,
image_rotary_emb=image_rotary_emb
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -408,6 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
attention_mode: Optional[str] = "sdpa",
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -461,6 +487,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
attention_mode=attention_mode,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
@@ -496,73 +523,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.fastercache_hf_step = 30
self.fastercache_device = "cuda"
self.fastercache_num_blocks_to_cache = len(self.transformer_blocks)
self.attention_mode = "sdpa"
self.attention_mode = attention_mode
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
#region forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -624,8 +590,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device,
attention_mode = self.attention_mode
fastercache_device = self.fastercache_device
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
@@ -695,8 +660,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
block_use_fastercache = i <= self.fastercache_num_blocks_to_cache,
fastercache_counter = self.fastercache_counter,
fastercache_start_step = self.fastercache_start_step,
fastercache_device = self.fastercache_device,
attention_mode = self.attention_mode
fastercache_device = self.fastercache_device
)
#has_nan = torch.isnan(hidden_states).any()
#if has_nan:

View File

@@ -124,7 +124,19 @@ class DownloadAndLoadCogVideoModel:
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
"attention_mode": ([
"sdpa",
"fused_sdpa",
"sageattn",
"fused_sageattn",
"sageattn_qk_int8_pv_fp8_cuda",
"sageattn_qk_int8_pv_fp16_cuda",
"sageattn_qk_int8_pv_fp16_triton",
"fused_sageattn_qk_int8_pv_fp8_cuda",
"fused_sageattn_qk_int8_pv_fp16_cuda",
"fused_sageattn_qk_int8_pv_fp16_triton",
"comfy"
], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}
}
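Editorial note: the new dropdown entries map onto SageAttention 2.x kernel names. Roughly, qk_int8 means Q/K are quantized to INT8, pv_fp16/pv_fp8 is the precision used for the P·V product, the cuda/triton suffix picks the backend, and the fused_ prefix additionally triggers QKV projection fusing in the loader. A hedged sketch of probing which of those kernels the installed sageattention package actually exposes; the helper is illustrative, not part of the node:

import importlib

SAGE_KERNELS = [
    "sageattn",                         # auto-dispatching entry point (SageAttention 1.x and 2.x)
    "sageattn_qk_int8_pv_fp16_cuda",
    "sageattn_qk_int8_pv_fp16_triton",
    "sageattn_qk_int8_pv_fp8_cuda",
]

def available_sage_kernels():
    try:
        mod = importlib.import_module("sageattention")
    except ImportError:
        return []
    return [name for name in SAGE_KERNELS if hasattr(mod, name)]

print(available_sage_kernels())         # [] when sageattention is not installed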
@@ -139,11 +151,18 @@ class DownloadAndLoadCogVideoModel:
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
transformer = None
if "sage" in attention_mode:
try:
from sageattention import sageattn
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")
if "qk_int8" in attention_mode:
try:
from sageattention import sageattn_qk_int8_pv_fp16_cuda
except Exception as e:
raise ValueError(f"Can't import SageAttention 2.0.0: {str(e)}")
if precision == "fp16" and "1.5" in model:
raise ValueError("1.5 models do not currently work in fp16")
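Editorial note: the second try/except above doubles as a version gate, since plain sageattn exists in SageAttention 1.x while the per-kernel qk_int8 entry points only appear in 2.0.0, so a failed import suggests the installed package predates 2.0.0. An equivalent check via package metadata could look like the sketch below, assuming the distribution is named sageattention:

from importlib.metadata import PackageNotFoundError, version

def has_sageattention_2() -> bool:
    try:
        # major version >= 2 implies the qk_int8 kernels are available
        return int(version("sageattention").split(".")[0]) >= 2
    except PackageNotFoundError:
        return False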
@@ -218,7 +237,7 @@ class DownloadAndLoadCogVideoModel:
local_dir_use_symlinks=False,
)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder, attention_mode=attention_mode)
transformer = transformer.to(dtype).to(transformer_load_device)
if "1.5" in model:
@@ -291,7 +310,6 @@ class DownloadAndLoadCogVideoModel:
for module in pipe.transformer.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
pipe.transformer.attention_mode = attention_mode
if compile_args is not None:
pipe.transformer.to(memory_format=torch.channels_last)
@@ -515,7 +533,7 @@ class DownloadAndLoadCogVideoGGUFModel:
else:
transformer_config["in_channels"] = 16
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
cast_dtype = vae_dtype
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
if "2b" in model:
@@ -696,7 +714,7 @@ class CogVideoXModelLoader:
transformer_config["sample_width"] = 300
with init_empty_weights():
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
#load weights
#params_to_keep = {}

View File

@@ -471,7 +471,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
# 5.5.
if image_cond_latents is not None:
if image_cond_latents.shape[1] == 2:
if image_cond_latents.shape[1] == 3:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
batch_size,
(latents.shape[1] - 3),
self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
middle_frame = image_cond_latents[:, 2, :, :, :]
image_cond_latents = torch.cat([image_cond_latents[:, 0, :, :, :].unsqueeze(1), latent_padding, image_cond_latents[:, -1, :, :, :].unsqueeze(1)], dim=1)
middle_frame_idx = image_cond_latents.shape[1] // 2
image_cond_latents[:, middle_frame_idx, :, :, :] = middle_frame
if self.transformer.config.patch_size_t is not None:
first_frame = image_cond_latents[:, : image_cond_latents.size(1) % self.transformer.config.patch_size_t, ...]
image_cond_latents = torch.cat([first_frame, image_cond_latents], dim=1)
elif image_cond_latents.shape[1] == 2:
logger.info("More than one image conditioning frame received, interpolating")
padding_shape = (
batch_size,
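Editorial note: for the new three-frame case the hunk above keeps the first and last conditioning latents at the ends of a zero-padded sequence and then writes the frame it labels middle_frame into the centre slot. A toy reproduction of that layout; frame counts and channel sizes are made up:

import torch

batch, latent_frames, ch, h, w = 1, 13, 16, 30, 45   # illustrative sizes
cond = torch.randn(batch, 3, ch, h, w)               # three conditioning latent frames

padding = torch.zeros(batch, latent_frames - 3, ch, h, w)
middle_frame = cond[:, 2]
layout = torch.cat([cond[:, 0:1], padding, cond[:, -1:]], dim=1)
layout[:, layout.shape[1] // 2] = middle_frame        # in-place write into the middle slot
print(layout.shape)                                   # torch.Size([1, 12, 16, 30, 45])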
@@ -593,9 +611,25 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
counter = torch.zeros_like(latent_model_input)
noise_pred = torch.zeros_like(latent_model_input)
current_step_percentage = i / num_inference_steps
if image_cond_latents is not None:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
if not image_cond_start_percent <= current_step_percentage <= image_cond_end_percent:
latent_image_input = torch.zeros_like(latent_model_input)
else:
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
if fun_mask is not None: #for fun img2vid and interpolation
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
masks_input = torch.cat([fun_inpaint_mask, latent_image_input], dim=2)
latent_model_input = torch.cat([latent_model_input, masks_input], dim=2)
else:
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
else: # for Fun inpaint vid2vid
if fun_mask is not None:
fun_inpaint_mask = torch.cat([fun_mask] * 2) if do_classifier_free_guidance else fun_mask
fun_inpaint_masked_video_latents = torch.cat([fun_masked_video_latents] * 2) if do_classifier_free_guidance else fun_masked_video_latents
fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
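Editorial note: the reworked loop only feeds the image-conditioning latents inside the [image_cond_start_percent, image_cond_end_percent] window and substitutes zeros elsewhere, with the Fun-style masks concatenated on top when present. A toy check of which steps such a window covers; the values are made up:

num_inference_steps = 50
image_cond_start_percent, image_cond_end_percent = 0.0, 0.6

active_steps = [
    i for i in range(num_inference_steps)
    if image_cond_start_percent <= i / num_inference_steps <= image_cond_end_percent
]
print(active_steps[0], active_steps[-1])   # 0 30 -- outside this range the code above
                                           # concatenates zeros_like(latent_model_input) instead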