Revert "works without flash_attn (thanks @juxtapoz!)"

This reverts commit a1b1f86aa3b6780a4981157b8e2e37b0a1017568.
This commit is contained in:
Jukka Seppänen 2024-10-24 02:23:46 +03:00
parent a1b1f86aa3
commit 1ba3ac8e25
3 changed files with 34 additions and 64 deletions

View File

@ -36,11 +36,6 @@ try:
FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
FLASH_ATTN_IS_AVAILABLE = False
try:
from sageattention import sageattn
SAGEATTN_IS_AVAILABLE = True
except ImportError:
SAGEATTN_IS_AVAILABLE = False
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
@ -60,8 +55,9 @@ class AsymmetricAttention(nn.Module):
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
clip_feat_dim: Optional[int] = None,
pooled_caption_mlp_bias: bool = True,
use_transformer_engine: bool = False,
):
super().__init__()
self.dim_x = dim_x
@ -72,7 +68,6 @@ class AsymmetricAttention(nn.Module):
self.update_y = update_y
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
self.attention_mode = attention_mode
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
@ -165,43 +160,6 @@ class AsymmetricAttention(nn.Module):
)
return qkv
def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
with torch.autocast("cuda", enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
) # (total, local_heads, head_dim)
return out.view(total, local_dim)
def sdpa_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
return rearrange(out, 'b h s d -> s (b h d)')
def sage_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
out = sageattn(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
return rearrange(out, 'b h s d -> s (b h d)')
@torch.compiler.disable()
def run_attention(
@ -222,13 +180,34 @@ class AsymmetricAttention(nn.Module):
local_dim = local_heads * self.head_dim
total = qkv.size(0)
if self.attention_mode == "flash_attn":
out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
elif self.attention_mode == "sdpa":
out = self.sdpa_attention(qkv)
elif self.attention_mode == "sage_attn":
out = self.sage_attention(qkv)
if FLASH_ATTN_IS_AVAILABLE:
with torch.autocast("cuda", enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
) # (total, local_heads, head_dim)
out = out.view(total, local_dim)
else:
raise NotImplementedError("Flash attention is currently required.")
print("qkv: ",qkv.shape, qkv.dtype, qkv.device)
expected_size = 2 * 44520 * 3 * 24 * 128
actual_size = qkv.numel()
print(f"Expected size: {expected_size}, Actual size: {actual_size}")
q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4)
with torch.autocast("cuda", enabled=False):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim)
x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
assert x.size() == (B, N, local_dim)
assert y.size() == (B, L, local_dim)
@ -305,20 +284,18 @@ class AsymmetricJointBlock(nn.Module):
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
**block_kwargs,
):
super().__init__()
self.update_y = update_y
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.attention_mode = attention_mode
self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
if self.update_y:
self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
else:
self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
# Self-attention:
self.attn = AsymmetricAttention(
hidden_size_x,
@ -326,7 +303,6 @@ class AsymmetricJointBlock(nn.Module):
num_heads=num_heads,
update_y=update_y,
device=device,
attention_mode=attention_mode,
**block_kwargs,
)
@ -473,7 +449,6 @@ class AsymmDiTJoint(nn.Module):
use_extended_posenc: bool = False,
rope_theta: float = 10000.0,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
**block_kwargs,
):
super().__init__()
@ -541,7 +516,6 @@ class AsymmDiTJoint(nn.Module):
mlp_ratio_y=mlp_ratio_y,
update_y=update_y,
device=device,
attention_mode=attention_mode,
**block_kwargs,
)

View File

@ -119,7 +119,6 @@ class T2VSynthMochiModel:
dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False,
attention_mode: str = "sdpa"
):
super().__init__()
self.device = device
@ -145,7 +144,6 @@ class T2VSynthMochiModel:
t5_feat_dim=4096,
t5_token_length=256,
rope_theta=10000.0,
attention_mode=attention_mode,
)
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}

View File

@ -60,8 +60,7 @@ class DownloadAndLoadMochiModel:
{"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
),
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
{"default": "fp8_e4m3fn", }),
"attention_mode": (["sdpa","flash_attn","sage_attn"],
{"default": "fp8_e4m3fn", }
),
},
}
@ -72,7 +71,7 @@ class DownloadAndLoadMochiModel:
CATEGORY = "MochiWrapper"
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
def loadmodel(self, model, vae, precision, attention_mode):
def loadmodel(self, model, vae, precision):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -116,7 +115,6 @@ class DownloadAndLoadMochiModel:
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode
)
with (init_empty_weights() if is_accelerate_available else nullcontext()):
vae = Decoder(