possible sdpa kernel fixes and add optional cfg scheduling
This commit is contained in:
parent e20eb66f93
commit 3613700752
@@ -5,8 +5,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-
-from torch.nn.attention import sdpa_kernel, SDPBackend
 
 from .layers import (
     FeedForward,
     PatchEmbed,
@@ -22,7 +20,7 @@ from .rope_mixed import (
 )
 from .temporal_rope import apply_rotary_emb_qk_real
 from .utils import (
-    AttentionPool,
+    pool_tokens,
     modulate,
     pad_and_split_xy,
     unify_streams,
@@ -39,16 +37,81 @@ try:
 except ImportError:
     SAGEATTN_IS_AVAILABLE = False
 
-backends = []
-if torch.cuda.get_device_properties(0).major < 7:
-    backends.append(SDPBackend.MATH)
-    backends.append(SDPBackend.EFFICIENT_ATTENTION)
-if torch.cuda.get_device_properties(0).major >= 9.0:
-    backends.append(SDPBackend.EFFICIENT_ATTENTION)
-    backends.append(SDPBackend.CUDNN_ATTENTION)
-else:
-    backends.append(SDPBackend.EFFICIENT_ATTENTION)
+from torch.nn.attention import sdpa_kernel, SDPBackend
+backends = []
+backends.append(SDPBackend.CUDNN_ATTENTION)
+backends.append(SDPBackend.EFFICIENT_ATTENTION)
+backends.append(SDPBackend.MATH)
+
+import comfy.model_management as mm
+
+class AttentionPool(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        output_dim: int = None,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Args:
+            spatial_dim (int): Number of tokens in sequence length.
+            embed_dim (int): Dimensionality of input tokens.
+            num_heads (int): Number of attention heads.
+            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
+        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
+        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
+
+    def forward(self, x, mask):
+        """
+        Args:
+            x (torch.Tensor): (B, L, D) tensor of input tokens.
+            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
+
+        NOTE: We assume x does not require gradients.
+
+        Returns:
+            x (torch.Tensor): (B, D) tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x) # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0]) # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_heads
+        kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
+        q = q.unsqueeze(2) # (B, H, 1, head_dim)
+
+        # Compute attention.
+        with sdpa_kernel(backends):
+            x = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0
+            ) # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
 
 
 class AsymmetricAttention(nn.Module):
     def __init__(
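Note: backend selection no longer depends on the GPU's compute capability; the list is fixed to cuDNN, then memory-efficient, then math attention, and the SDPA calls are wrapped in torch.nn.attention.sdpa_kernel. A minimal sketch of that pattern on a recent PyTorch (the tensor shapes below are illustrative, not from the commit):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Only the listed backends are enabled inside the context; the dispatcher picks
# one that supports the inputs, and MATH works everywhere as the fallback.
backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v)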
@@ -77,6 +140,7 @@ class AsymmetricAttention(nn.Module):
         self.attend_to_padding = attend_to_padding
         self.softmax_scale = softmax_scale
         self.attention_mode = attention_mode
+        self.device = device
         if dim_x % num_heads != 0:
             raise ValueError(
                 f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
@@ -162,7 +226,7 @@ class AsymmetricAttention(nn.Module):
         return qkv
 
     def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
                 qkv,
                 cu_seqlens=cu_seqlens,
@@ -174,7 +238,7 @@ class AsymmetricAttention(nn.Module):
 
     def sdpa_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
                     q,
@@ -188,7 +252,7 @@ class AsymmetricAttention(nn.Module):
 
     def sage_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
                 k,
@@ -202,7 +266,7 @@ class AsymmetricAttention(nn.Module):
     def comfy_attention(self, qkv):
        from comfy.ldm.modules.attention import optimized_attention
        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-       with torch.autocast("cuda", enabled=False):
+       with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = optimized_attention(
                q,
                k,
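Note: the autocast-disabling wrappers previously hard-coded "cuda"; they now derive the device type from ComfyUI's model_management so the same code path works on CPU and MPS setups as well. A minimal sketch of the assumed behaviour (get_autocast_device is ComfyUI's helper, expected to map a torch.device to the device-type string torch.autocast accepts):

import torch
import comfy.model_management as mm

device = mm.get_torch_device()
autocast_device = mm.get_autocast_device(device)  # e.g. "cuda", "cpu" or "mps"

# Run the attention math in full precision even inside an outer autocast region.
with torch.autocast(autocast_device, enabled=False):
    pass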
@@ -30,73 +30,6 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
     return pooled
 
 
-class AttentionPool(nn.Module):
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        output_dim: int = None,
-        device: Optional[torch.device] = None,
-    ):
-        """
-        Args:
-            spatial_dim (int): Number of tokens in sequence length.
-            embed_dim (int): Dimensionality of input tokens.
-            num_heads (int): Number of attention heads.
-            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
-        """
-        super().__init__()
-        self.num_heads = num_heads
-        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
-        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
-        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
-
-    def forward(self, x, mask):
-        """
-        Args:
-            x (torch.Tensor): (B, L, D) tensor of input tokens.
-            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
-
-        NOTE: We assume x does not require gradients.
-
-        Returns:
-            x (torch.Tensor): (B, D) tensor of pooled tokens.
-        """
-        D = x.size(2)
-
-        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
-        attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
-        attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
-
-        # Average non-padding token features. These will be used as the query.
-        x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
-
-        # Concat pooled features to input sequence.
-        x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
-
-        # Compute queries, keys, values. Only the mean token is used to create a query.
-        kv = self.to_kv(x) # (B, L+1, 2 * D)
-        q = self.to_q(x[:, 0]) # (B, D)
-
-        # Extract heads.
-        head_dim = D // self.num_heads
-        kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
-        kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
-        k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
-        q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
-        q = q.unsqueeze(2) # (B, H, 1, head_dim)
-
-        # Compute attention.
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=0.0
-        ) # (B, H, 1, head_dim)
-
-        # Concatenate heads and run output.
-        x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
-        x = self.to_out(x)
-        return x
-
-
 class PadSplitXY(torch.autograd.Function):
     """
     Merge heads, pad and extract visual and text tokens,
@@ -352,11 +352,10 @@ class T2VSynthMochiModel:
             z = z + dsigma * pred
             comfy_pbar.update(1)
 
-        #cp_rank, cp_size = get_cp_rank_size()
         if batch_cfg:
             z = z[:B]
-        #z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
-        self.dit.to(self.offload_device, non_blocking=True)
+        self.dit.to(self.offload_device)
 
         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {samples.shape}")
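Note: the offload call drops non_blocking=True. An asynchronous CUDA-to-CPU copy can return before the data has actually arrived, so code that immediately reuses the offloaded weights would need an explicit synchronize; the plain .to() call is synchronous. A small sketch of the two patterns (the module here is illustrative):

import torch
import torch.nn as nn

offload_device = torch.device("cpu")
dit = nn.Linear(4, 4)

# Asynchronous offload: only safe to use the CPU copy after a barrier.
# dit.to(offload_device, non_blocking=True)
# torch.cuda.synchronize()

# Synchronous offload, as used after this change: no barrier required.
dit.to(offload_device)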
10 nodes.py
@@ -354,6 +354,9 @@ class MochiSampler:
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                 #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
             },
+            "optional": {
+                "cfg_schedule": ("FLOAT", {"forceInput": True,}),
+            }
         }
 
     RETURN_TYPES = ("LATENT",)
@@ -361,19 +364,22 @@ class MochiSampler:
     FUNCTION = "process"
     CATEGORY = "MochiWrapper"
 
-    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
+    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None):
         mm.soft_empty_cache()
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
+        cfg_schedule = cfg_schedule or [cfg] * steps
+        logging.info(f"Using cfg schedule: {cfg_schedule}")
+
         args = {
             "height": height,
             "width": width,
             "num_frames": num_frames,
             "mochi_args": {
                 "sigma_schedule": linear_quadratic_schedule(steps, 0.025),
-                "cfg_schedule": [cfg] * steps,
+                "cfg_schedule": cfg_schedule,
                 "num_inference_steps": steps,
                 "batch_cfg": False,
             },
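Note: the new optional cfg_schedule input takes one guidance value per step; when it is left unconnected the node falls back to the flat [cfg] * steps list it used before. A minimal sketch of building a custom schedule to feed into that input (the linear decay below is purely illustrative):

steps = 30

# Flat fallback, identical to what the node builds when cfg_schedule is None.
flat_schedule = [4.5] * steps

# Example custom schedule: decay guidance linearly from 6.0 down to 3.0.
decay_schedule = [6.0 + (3.0 - 6.0) * i / (steps - 1) for i in range(steps)]

assert len(decay_schedule) == steps  # length must match num_inference_steps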