Possible SDPA kernel fixes and optional CFG scheduling
commit 3613700752 (parent e20eb66f93)
@@ -5,8 +5,6 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from torch.nn.attention import sdpa_kernel, SDPBackend

from .layers import (
    FeedForward,
    PatchEmbed,
@@ -22,7 +20,7 @@ from .rope_mixed import (
)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
    AttentionPool,
    pool_tokens,
    modulate,
    pad_and_split_xy,
    unify_streams,
@@ -39,16 +37,81 @@ try:
except ImportError:
    SAGEATTN_IS_AVAILABLE = False

backends = []
if torch.cuda.get_device_properties(0).major < 7:
    backends.append(SDPBackend.MATH)
    backends.append(SDPBackend.EFFICIENT_ATTENTION)
if torch.cuda.get_device_properties(0).major >= 9.0:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)
    backends.append(SDPBackend.CUDNN_ATTENTION)
else:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)
from torch.nn.attention import sdpa_kernel, SDPBackend

backends = []
backends.append(SDPBackend.CUDNN_ATTENTION)
backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)

import comfy.model_management as mm
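A minimal sketch (not part of the commit) of how a module-level `backends` list like the one above is consumed: `torch.nn.attention.sdpa_kernel` takes a backend or a list of backends and restricts `F.scaled_dot_product_attention` to those implementations inside the context, so listing CUDNN_ATTENTION, EFFICIENT_ATTENTION and MATH effectively gives a fallback set. Shapes and the CUDA/fp16 choice below are illustrative assumptions.

```python
# Sketch only: restricting SDPA to an allowed backend set.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

# Illustrative (batch, heads, seq, head_dim) tensors on a CUDA device.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Inside this context only the listed backends may be dispatched to;
# PyTorch picks whichever of them supports the given inputs.
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v)  # (1, 8, 128, 64)
```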

class AttentionPool(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: int = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            spatial_dim (int): Number of tokens in sequence length.
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads.
            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
        """
        super().__init__()
        self.num_heads = num_heads
        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

        NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, D) tensor of pooled tokens.
        """
        D = x.size(2)

        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_heads
        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        with sdpa_kernel(backends):
            x = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0
            )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x

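For orientation, a hypothetical usage sketch of the relocated AttentionPool (all names, sizes and the padding pattern are illustrative, not from the commit): it collapses a padded (B, L, D) token sequence into a single (B, output_dim) vector by attending from the mean of the non-padding tokens.

```python
# Hypothetical usage; assumes the AttentionPool class and module-level
# `backends` / `pool_tokens` defined above are in scope.
import torch

B, L, D = 2, 77, 1536                       # batch, sequence length, embed dim (made up)
pool = AttentionPool(embed_dim=D, num_heads=8, output_dim=1024)

x = torch.randn(B, L, D)                    # token features
mask = torch.ones(B, L, dtype=torch.bool)   # True where tokens are real
mask[1, 40:] = False                        # second sample padded after 40 tokens

pooled = pool(x, mask)                      # (B, 1024) pooled conditioning vector
print(pooled.shape)
```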
class AsymmetricAttention(nn.Module):
    def __init__(
@@ -77,6 +140,7 @@ class AsymmetricAttention(nn.Module):
        self.attend_to_padding = attend_to_padding
        self.softmax_scale = softmax_scale
        self.attention_mode = attention_mode
        self.device = device
        if dim_x % num_heads != 0:
            raise ValueError(
                f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
@@ -162,7 +226,7 @@ class AsymmetricAttention(nn.Module):
        return qkv

    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
        with torch.autocast("cuda", enabled=False):
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
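A small illustrative sketch (my own assumption, not code from the repo) of the cumulative-sequence-length bookkeeping that a varlen packed-QKV kernel like `flash_attn_varlen_qkvpacked_func` expects: variable-length sequences are concatenated into one packed (total_tokens, 3, heads, head_dim) tensor, and `cu_seqlens` records where each sequence starts.

```python
# Sketch: building cu_seqlens / max_seqlen for packed variable-length attention.
# seq_lens is hypothetical; in the model it would come from the visual/text token counts.
import torch

seq_lens = torch.tensor([44520, 256], dtype=torch.int32)   # e.g. video tokens, text tokens
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 44520, 44776]): start offsets of each packed sequence
max_seqlen_in_batch = int(seq_lens.max())                   # longest sequence in the pack

total = int(cu_seqlens[-1])                                  # total packed tokens
# qkv would then be shaped (total, 3, num_heads, head_dim) before the varlen kernel call.
```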
@@ -174,7 +238,7 @@ class AsymmetricAttention(nn.Module):

    def sdpa_attention(self, qkv):
        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
        with torch.autocast("cuda", enabled=False):
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            with sdpa_kernel(backends):
                out = F.scaled_dot_product_attention(
                    q,
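As a reminder of what that einops pattern does (the numbers below are illustrative, not from the repo): a packed (total_tokens, 3, heads, head_dim) QKV tensor is split into three (1, heads, total_tokens, head_dim) tensors, i.e. the whole pack is treated as a single batch-1 sequence for SDPA.

```python
# Sketch of the '(b s) t h d -> t b h s d' rearrange with b=1; values are made up.
import torch
from einops import rearrange

total, heads, head_dim = 44776, 24, 128          # hypothetical packed token count
qkv = torch.randn(total, 3, heads, head_dim)     # (b*s, t=3, h, d) with b=1

q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)  # unpacks along t
print(q.shape)  # torch.Size([1, 24, 44776, 128]) -- same for k and v
```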
@@ -188,7 +252,7 @@ class AsymmetricAttention(nn.Module):

    def sage_attention(self, qkv):
        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
        with torch.autocast("cuda", enabled=False):
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = sageattn(
                q,
                k,
@@ -202,7 +266,7 @@ class AsymmetricAttention(nn.Module):
    def comfy_attention(self, qkv):
        from comfy.ldm.modules.attention import optimized_attention
        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
        with torch.autocast("cuda", enabled=False):
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = optimized_attention(
                q,
                k,

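The hunks above touch four alternative attention paths (flash, sdpa, sage, comfy). A hedged sketch of how an `attention_mode` switch over such methods might be dispatched; the method names mirror the diff, but the dispatcher itself and the mode strings are my assumptions, not code shown in the commit.

```python
# Hypothetical dispatcher over the attention implementations seen in the diff.
# The real class may route differently; this only illustrates the idea.
def run_attention(self, qkv, cu_seqlens=None, max_seqlen_in_batch=None,
                  total=None, local_dim=None):
    if self.attention_mode == "flash_attn" and cu_seqlens is not None:
        return self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
    if self.attention_mode == "sage_attn" and SAGEATTN_IS_AVAILABLE:
        return self.sage_attention(qkv)
    if self.attention_mode == "comfy":
        return self.comfy_attention(qkv)
    return self.sdpa_attention(qkv)  # default: PyTorch SDPA with the restricted backends
```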
@@ -30,73 +30,6 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
    return pooled


class AttentionPool(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: int = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            spatial_dim (int): Number of tokens in sequence length.
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads.
            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
        """
        super().__init__()
        self.num_heads = num_heads
        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

        NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, D) tensor of pooled tokens.
        """
        D = x.size(2)

        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_heads
        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x


class PadSplitXY(torch.autograd.Function):
    """
    Merge heads, pad and extract visual and text tokens,

@@ -352,11 +352,10 @@ class T2VSynthMochiModel:
                z = z + dsigma * pred
                comfy_pbar.update(1)

            #cp_rank, cp_size = get_cp_rank_size()
            if batch_cfg:
                z = z[:B]
            #z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
            self.dit.to(self.offload_device, non_blocking=True)

            self.dit.to(self.offload_device)

            samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
            logging.info(f"samples shape: {samples.shape}")

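For context on the `z = z + dsigma * pred` update and the new per-step CFG schedule, here is a minimal sketch of a classifier-free-guidance Euler step over a sigma schedule; the variable names, the model call signature, and the sign convention of `dsigma` are my approximations, not the wrapper's actual loop.

```python
# Sketch of a CFG-guided Euler update over a sigma schedule; not the repo's exact loop.
def sample_sketch(model, z, sigma_schedule, cfg_schedule, cond, uncond):
    for i in range(len(sigma_schedule) - 1):
        sigma, sigma_next = sigma_schedule[i], sigma_schedule[i + 1]
        dsigma = sigma_next - sigma

        # Two forward passes (or one batched pass) for conditional / unconditional.
        pred_cond = model(z, sigma, cond)
        pred_uncond = model(z, sigma, uncond)

        # Per-step guidance weight comes from cfg_schedule instead of a single cfg value.
        w = cfg_schedule[i]
        pred = pred_uncond + w * (pred_cond - pred_uncond)

        z = z + dsigma * pred  # Euler step along the predicted direction
    return z
```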
nodes.py
@@ -354,6 +354,9 @@ class MochiSampler:
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
            },
            "optional": {
                "cfg_schedule": ("FLOAT", {"forceInput": True,}),
            }
        }

    RETURN_TYPES = ("LATENT",)
@@ -361,19 +364,22 @@ class MochiSampler:
    FUNCTION = "process"
    CATEGORY = "MochiWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None):
        mm.soft_empty_cache()

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()

        cfg_schedule = cfg_schedule or [cfg] * steps
        logging.info(f"Using cfg schedule: {cfg_schedule}")

        args = {
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "mochi_args": {
                "sigma_schedule": linear_quadratic_schedule(steps, 0.025),
                "cfg_schedule": [cfg] * steps,
                "cfg_schedule": cfg_schedule,
                "num_inference_steps": steps,
                "batch_cfg": False,
            },

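A short sketch of what feeding the new optional `cfg_schedule` input amounts to: a list of per-step guidance weights, for example a linear ramp that relaxes guidance toward the end of sampling. The helper below is hypothetical; only the `cfg_schedule or [cfg] * steps` fallback comes from the diff.

```python
# Hypothetical way to build a per-step CFG schedule for the sampler's optional input.
def make_cfg_schedule(steps: int, start: float = 6.0, end: float = 3.0) -> list[float]:
    """Linearly interpolate the guidance weight from `start` to `end` over `steps`."""
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]

steps, cfg = 30, 4.5
cfg_schedule = make_cfg_schedule(steps)      # e.g. 6.0 -> 3.0 over 30 steps
# Inside MochiSampler.process, a missing or empty schedule falls back to a constant one:
cfg_schedule = cfg_schedule or [cfg] * steps
```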