possible sdpa kernel fixes and add optional cfg scheduling

This commit is contained in:
kijai 2024-10-27 12:23:01 +02:00
parent e20eb66f93
commit 3613700752
4 changed files with 90 additions and 88 deletions

View File

@ -5,8 +5,6 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention import sdpa_kernel, SDPBackend
from .layers import (
FeedForward,
PatchEmbed,
@ -22,7 +20,7 @@ from .rope_mixed import (
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
AttentionPool,
pool_tokens,
modulate,
pad_and_split_xy,
unify_streams,
@ -39,16 +37,81 @@ try:
except ImportError:
SAGEATTN_IS_AVAILABLE = False
backends = []
if torch.cuda.get_device_properties(0).major < 7:
backends.append(SDPBackend.MATH)
backends.append(SDPBackend.EFFICIENT_ATTENTION)
if torch.cuda.get_device_properties(0).major >= 9.0:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.CUDNN_ATTENTION)
else:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
from torch.nn.attention import sdpa_kernel, SDPBackend
backends = []
backends.append(SDPBackend.CUDNN_ATTENTION)
backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)
import comfy.model_management as mm
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
output_dim: int = None,
device: Optional[torch.device] = None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
with sdpa_kernel(backends):
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
class AsymmetricAttention(nn.Module):
def __init__(
@ -77,6 +140,7 @@ class AsymmetricAttention(nn.Module):
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
self.attention_mode = attention_mode
self.device = device
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
@ -162,7 +226,7 @@ class AsymmetricAttention(nn.Module):
return qkv
def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
with torch.autocast("cuda", enabled=False):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens=cu_seqlens,
@ -174,7 +238,7 @@ class AsymmetricAttention(nn.Module):
def sdpa_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
with sdpa_kernel(backends):
out = F.scaled_dot_product_attention(
q,
@ -188,7 +252,7 @@ class AsymmetricAttention(nn.Module):
def sage_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = sageattn(
q,
k,
@ -202,7 +266,7 @@ class AsymmetricAttention(nn.Module):
def comfy_attention(self, qkv):
from comfy.ldm.modules.attention import optimized_attention
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = optimized_attention(
q,
k,

View File

@ -30,73 +30,6 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
return pooled
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
output_dim: int = None,
device: Optional[torch.device] = None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
class PadSplitXY(torch.autograd.Function):
"""
Merge heads, pad and extract visual and text tokens,

View File

@ -352,11 +352,10 @@ class T2VSynthMochiModel:
z = z + dsigma * pred
comfy_pbar.update(1)
#cp_rank, cp_size = get_cp_rank_size()
if batch_cfg:
z = z[:B]
#z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
self.dit.to(self.offload_device, non_blocking=True)
self.dit.to(self.offload_device)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
logging.info(f"samples shape: {samples.shape}")

View File

@ -354,6 +354,9 @@ class MochiSampler:
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
#"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
},
"optional": {
"cfg_schedule": ("FLOAT", {"forceInput": True,}),
}
}
RETURN_TYPES = ("LATENT",)
@ -361,19 +364,22 @@ class MochiSampler:
FUNCTION = "process"
CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None):
mm.soft_empty_cache()
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
cfg_schedule = cfg_schedule or [cfg] * steps
logging.info(f"Using cfg schedule: {cfg_schedule}")
args = {
"height": height,
"width": width,
"num_frames": num_frames,
"mochi_args": {
"sigma_schedule": linear_quadratic_schedule(steps, 0.025),
"cfg_schedule": [cfg] * steps,
"cfg_schedule": cfg_schedule,
"num_inference_steps": steps,
"batch_cfg": False,
},