Revert "works without flash_attn (thanks @juxtapoz!)"
This reverts commit a1b1f86aa3b6780a4981157b8e2e37b0a1017568.
This commit is contained in:
parent
a1b1f86aa3
commit
1ba3ac8e25
@ -36,11 +36,6 @@ try:
|
||||
FLASH_ATTN_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLASH_ATTN_IS_AVAILABLE = False
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
SAGEATTN_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAGEATTN_IS_AVAILABLE = False
|
||||
|
||||
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
|
||||
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
|
||||
@ -60,8 +55,9 @@ class AsymmetricAttention(nn.Module):
|
||||
attend_to_padding: bool = False,
|
||||
softmax_scale: Optional[float] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
attention_mode: str = "sdpa",
|
||||
|
||||
clip_feat_dim: Optional[int] = None,
|
||||
pooled_caption_mlp_bias: bool = True,
|
||||
use_transformer_engine: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_x = dim_x
|
||||
@ -72,7 +68,6 @@ class AsymmetricAttention(nn.Module):
|
||||
self.update_y = update_y
|
||||
self.attend_to_padding = attend_to_padding
|
||||
self.softmax_scale = softmax_scale
|
||||
self.attention_mode = attention_mode
|
||||
if dim_x % num_heads != 0:
|
||||
raise ValueError(
|
||||
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
|
||||
@ -165,43 +160,6 @@ class AsymmetricAttention(nn.Module):
|
||||
)
|
||||
|
||||
return qkv
|
||||
|
||||
def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen_in_batch,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.softmax_scale,
|
||||
) # (total, local_heads, head_dim)
|
||||
return out.view(total, local_dim)
|
||||
|
||||
def sdpa_attention(self, qkv):
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
|
||||
def sage_attention(self, qkv):
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out = sageattn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
|
||||
@torch.compiler.disable()
|
||||
def run_attention(
|
||||
@ -222,13 +180,34 @@ class AsymmetricAttention(nn.Module):
|
||||
local_dim = local_heads * self.head_dim
|
||||
total = qkv.size(0)
|
||||
|
||||
if self.attention_mode == "flash_attn":
|
||||
out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
|
||||
elif self.attention_mode == "sdpa":
|
||||
out = self.sdpa_attention(qkv)
|
||||
elif self.attention_mode == "sage_attn":
|
||||
out = self.sage_attention(qkv)
|
||||
|
||||
if FLASH_ATTN_IS_AVAILABLE:
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen_in_batch,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.softmax_scale,
|
||||
) # (total, local_heads, head_dim)
|
||||
out = out.view(total, local_dim)
|
||||
else:
|
||||
raise NotImplementedError("Flash attention is currently required.")
|
||||
print("qkv: ",qkv.shape, qkv.dtype, qkv.device)
|
||||
expected_size = 2 * 44520 * 3 * 24 * 128
|
||||
actual_size = qkv.numel()
|
||||
print(f"Expected size: {expected_size}, Actual size: {actual_size}")
|
||||
q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim)
|
||||
|
||||
x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
|
||||
assert x.size() == (B, N, local_dim)
|
||||
assert y.size() == (B, L, local_dim)
|
||||
@ -305,20 +284,18 @@ class AsymmetricJointBlock(nn.Module):
|
||||
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
|
||||
update_y: bool = True, # Whether to update text tokens in this block.
|
||||
device: Optional[torch.device] = None,
|
||||
attention_mode: str = "sdpa",
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.update_y = update_y
|
||||
self.hidden_size_x = hidden_size_x
|
||||
self.hidden_size_y = hidden_size_y
|
||||
self.attention_mode = attention_mode
|
||||
self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
|
||||
if self.update_y:
|
||||
self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
|
||||
else:
|
||||
self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
|
||||
|
||||
|
||||
# Self-attention:
|
||||
self.attn = AsymmetricAttention(
|
||||
hidden_size_x,
|
||||
@ -326,7 +303,6 @@ class AsymmetricJointBlock(nn.Module):
|
||||
num_heads=num_heads,
|
||||
update_y=update_y,
|
||||
device=device,
|
||||
attention_mode=attention_mode,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
@ -473,7 +449,6 @@ class AsymmDiTJoint(nn.Module):
|
||||
use_extended_posenc: bool = False,
|
||||
rope_theta: float = 10000.0,
|
||||
device: Optional[torch.device] = None,
|
||||
attention_mode: str = "sdpa",
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@ -541,7 +516,6 @@ class AsymmDiTJoint(nn.Module):
|
||||
mlp_ratio_y=mlp_ratio_y,
|
||||
update_y=update_y,
|
||||
device=device,
|
||||
attention_mode=attention_mode,
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -119,7 +119,6 @@ class T2VSynthMochiModel:
|
||||
dit_checkpoint_path: str,
|
||||
weight_dtype: torch.dtype = torch.float8_e4m3fn,
|
||||
fp8_fastmode: bool = False,
|
||||
attention_mode: str = "sdpa"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@ -145,7 +144,6 @@ class T2VSynthMochiModel:
|
||||
t5_feat_dim=4096,
|
||||
t5_token_length=256,
|
||||
rope_theta=10000.0,
|
||||
attention_mode=attention_mode,
|
||||
)
|
||||
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -60,8 +60,7 @@ class DownloadAndLoadMochiModel:
|
||||
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
|
||||
),
|
||||
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
|
||||
{"default": "fp8_e4m3fn", }),
|
||||
"attention_mode": (["sdpa","flash_attn","sage_attn"],
|
||||
{"default": "fp8_e4m3fn", }
|
||||
),
|
||||
},
|
||||
}
|
||||
@ -72,7 +71,7 @@ class DownloadAndLoadMochiModel:
|
||||
CATEGORY = "MochiWrapper"
|
||||
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
|
||||
|
||||
def loadmodel(self, model, vae, precision, attention_mode):
|
||||
def loadmodel(self, model, vae, precision):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -116,7 +115,6 @@ class DownloadAndLoadMochiModel:
|
||||
dit_checkpoint_path=model_path,
|
||||
weight_dtype=dtype,
|
||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||
attention_mode=attention_mode
|
||||
)
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
vae = Decoder(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user