Revert "works without flash_attn (thanks @juxtapoz!)"
This reverts commit a1b1f86aa3b6780a4981157b8e2e37b0a1017568.
This commit is contained in:
parent
a1b1f86aa3
commit
1ba3ac8e25
@@ -36,11 +36,6 @@ try:
     FLASH_ATTN_IS_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_IS_AVAILABLE = False
-try:
-    from sageattention import sageattn
-    SAGEATTN_IS_AVAILABLE = True
-except ImportError:
-    SAGEATTN_IS_AVAILABLE = False
 
 COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
 COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
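Note: the hunk above drops the sageattention import guard and its availability flag. A minimal, self-contained sketch of the optional-dependency pattern involved, assuming flash-attn 2.x; the flag and module names mirror the diff, while pick_attention_backend() is a hypothetical helper added here only for illustration:

# Sketch of the optional-dependency guard pattern from the hunk above.
# Flag names mirror the diff; pick_attention_backend() is hypothetical.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func  # optional dependency
    FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
    FLASH_ATTN_IS_AVAILABLE = False

try:
    from sageattention import sageattn  # optional dependency dropped by this revert
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False

def pick_attention_backend() -> str:
    """Return the best available backend name (hypothetical helper)."""
    if FLASH_ATTN_IS_AVAILABLE:
        return "flash_attn"
    if SAGEATTN_IS_AVAILABLE:
        return "sage_attn"
    return "sdpa"  # torch.nn.functional fallback, always present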
@@ -60,8 +55,9 @@ class AsymmetricAttention(nn.Module):
         attend_to_padding: bool = False,
         softmax_scale: Optional[float] = None,
         device: Optional[torch.device] = None,
-        attention_mode: str = "sdpa",
+        clip_feat_dim: Optional[int] = None,
+        pooled_caption_mlp_bias: bool = True,
+        use_transformer_engine: bool = False,
     ):
         super().__init__()
         self.dim_x = dim_x
@@ -72,7 +68,6 @@ class AsymmetricAttention(nn.Module):
         self.update_y = update_y
         self.attend_to_padding = attend_to_padding
         self.softmax_scale = softmax_scale
-        self.attention_mode = attention_mode
         if dim_x % num_heads != 0:
             raise ValueError(
                 f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
@@ -165,43 +160,6 @@ class AsymmetricAttention(nn.Module):
         )
 
         return qkv
 
-    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
-        with torch.autocast("cuda", enabled=False):
-            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen_in_batch,
-                dropout_p=0.0,
-                softmax_scale=self.softmax_scale,
-            ) # (total, local_heads, head_dim)
-        return out.view(total, local_dim)
-
-    def sdpa_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
-            out = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-        return rearrange(out, 'b h s d -> s (b h d)')
-
-    def sage_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
-            out = sageattn(
-                q,
-                k,
-                v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-        return rearrange(out, 'b h s d -> s (b h d)')
-
     @torch.compiler.disable()
     def run_attention(
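Note: the removed helpers all operate on a packed qkv tensor of shape (total_tokens, 3, heads, head_dim). A standalone sketch of the einops reshuffle the removed sdpa_attention helper performed around F.scaled_dot_product_attention; the sizes here are made up, only the rearrange patterns are copied from the diff:

# Standalone sketch of the layout transform used by the removed sdpa_attention helper.
import torch
import torch.nn.functional as F
from einops import rearrange

total, heads, head_dim = 32, 4, 16            # made-up sizes for illustration
qkv = torch.randn(total, 3, heads, head_dim)  # packed tokens, one sequence per batch

# Split into q, k, v, each of shape (1, heads, total, head_dim).
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Back to (total, heads * head_dim), the layout run_attention expects downstream.
out = rearrange(out, 'b h s d -> s (b h d)')
assert out.shape == (total, heads * head_dim)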
@@ -222,13 +180,34 @@ class AsymmetricAttention(nn.Module):
         local_dim = local_heads * self.head_dim
         total = qkv.size(0)
 
-        if self.attention_mode == "flash_attn":
-            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
-        elif self.attention_mode == "sdpa":
-            out = self.sdpa_attention(qkv)
-        elif self.attention_mode == "sage_attn":
-            out = self.sage_attention(qkv)
+        if FLASH_ATTN_IS_AVAILABLE:
+            with torch.autocast("cuda", enabled=False):
+                out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
+                    qkv,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen_in_batch,
+                    dropout_p=0.0,
+                    softmax_scale=self.softmax_scale,
+                ) # (total, local_heads, head_dim)
+            out = out.view(total, local_dim)
+        else:
+            raise NotImplementedError("Flash attention is currently required.")
+            print("qkv: ",qkv.shape, qkv.dtype, qkv.device)
+            expected_size = 2 * 44520 * 3 * 24 * 128
+            actual_size = qkv.numel()
+            print(f"Expected size: {expected_size}, Actual size: {actual_size}")
+            q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            with torch.autocast("cuda", enabled=False):
+                out = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False
+                )
+            out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim)
 
         x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
         assert x.size() == (B, N, local_dim)
         assert y.size() == (B, L, local_dim)
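Note: the restored path calls flash-attn's varlen packed-QKV kernel, which describes the batch with cumulative sequence lengths instead of padding. A minimal sketch of how the cu_seqlens / max_seqlen_in_batch metadata is typically built, assuming flash-attn 2.x; the seq_lens values are examples taken from numbers visible in this diff (44520 visual tokens, 256 caption tokens, 24 heads, head_dim 128):

# Sketch: building the metadata that flash_attn_varlen_qkvpacked_func expects.
import torch

seq_lens = torch.tensor([44520, 256], dtype=torch.int32)   # example: visual + caption tokens
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)              # [0, 44520, 44776]
max_seqlen_in_batch = int(seq_lens.max())

# qkv is packed along the token dimension: (sum(seq_lens), 3, heads, head_dim)
total = int(cu_seqlens[-1])
heads, head_dim = 24, 128
qkv = torch.randn(total, 3, heads, head_dim, dtype=torch.float16)

# With flash-attn installed (CUDA only), the call restored by this revert looks like:
#   out = flash_attn_varlen_qkvpacked_func(qkv.cuda(), cu_seqlens.cuda(),
#                                          max_seqlen_in_batch, dropout_p=0.0)
#   # out: (total, heads, head_dim)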
@@ -305,20 +284,18 @@ class AsymmetricJointBlock(nn.Module):
         mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
         update_y: bool = True, # Whether to update text tokens in this block.
         device: Optional[torch.device] = None,
-        attention_mode: str = "sdpa",
         **block_kwargs,
     ):
         super().__init__()
         self.update_y = update_y
         self.hidden_size_x = hidden_size_x
         self.hidden_size_y = hidden_size_y
-        self.attention_mode = attention_mode
         self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
         if self.update_y:
             self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
         else:
             self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
 
         # Self-attention:
         self.attn = AsymmetricAttention(
             hidden_size_x,
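Note: this hunk and the ones that follow all remove the same kind of plumbing: an attention_mode string accepted by a constructor and forwarded one level down. A reduced sketch of that pattern, with class bodies trimmed to the relevant lines and names suffixed "Sketch" to mark them as illustrative:

# Reduced sketch of the option plumbing this revert removes.
import torch.nn as nn

class AsymmetricAttentionSketch(nn.Module):
    def __init__(self, dim_x: int, attention_mode: str = "sdpa"):
        super().__init__()
        self.attention_mode = attention_mode  # consulted later when running attention

class AsymmetricJointBlockSketch(nn.Module):
    def __init__(self, hidden_size_x: int, attention_mode: str = "sdpa", **block_kwargs):
        super().__init__()
        self.attention_mode = attention_mode
        # The same string is handed down so the attention module can dispatch on it.
        self.attn = AsymmetricAttentionSketch(hidden_size_x, attention_mode=attention_mode)

block = AsymmetricJointBlockSketch(3072, attention_mode="sage_attn")
assert block.attn.attention_mode == "sage_attn"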
@@ -326,7 +303,6 @@ class AsymmetricJointBlock(nn.Module):
             num_heads=num_heads,
             update_y=update_y,
             device=device,
-            attention_mode=attention_mode,
             **block_kwargs,
         )
 
@@ -473,7 +449,6 @@ class AsymmDiTJoint(nn.Module):
         use_extended_posenc: bool = False,
         rope_theta: float = 10000.0,
         device: Optional[torch.device] = None,
-        attention_mode: str = "sdpa",
         **block_kwargs,
     ):
         super().__init__()
@@ -541,7 +516,6 @@ class AsymmDiTJoint(nn.Module):
                 mlp_ratio_y=mlp_ratio_y,
                 update_y=update_y,
                 device=device,
-                attention_mode=attention_mode,
                 **block_kwargs,
             )
 
@@ -119,7 +119,6 @@ class T2VSynthMochiModel:
        dit_checkpoint_path: str,
        weight_dtype: torch.dtype = torch.float8_e4m3fn,
        fp8_fastmode: bool = False,
-       attention_mode: str = "sdpa"
    ):
        super().__init__()
        self.device = device
@@ -145,7 +144,6 @@ class T2VSynthMochiModel:
            t5_feat_dim=4096,
            t5_token_length=256,
            rope_theta=10000.0,
-           attention_mode=attention_mode,
            )
 
        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
nodes.py (6 changed lines)
@@ -60,8 +60,7 @@ class DownloadAndLoadMochiModel:
                 {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
                 ),
             "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
-                {"default": "fp8_e4m3fn", }),
-            "attention_mode": (["sdpa","flash_attn","sage_attn"],
+                {"default": "fp8_e4m3fn", }
                 ),
             },
         }
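Note: for context, a trimmed sketch of the ComfyUI INPUT_TYPES convention used by this node, showing the dropdown the revert removes. Only the "precision" and "attention_mode" entries mirror the diff; the other fields are simplified placeholders:

# Trimmed, illustrative sketch of a ComfyUI node exposing the removed dropdown.
class DownloadAndLoadMochiModelSketch:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "precision": (["fp8_e4m3fn", "fp8_e4m3fn_fast", "fp16", "fp32", "bf16"],
                              {"default": "fp8_e4m3fn"}),
                # Dropdown removed by this revert:
                "attention_mode": (["sdpa", "flash_attn", "sage_attn"],),
            },
        }

    RETURN_TYPES = ("MOCHIMODEL",)  # illustrative; the real node returns more outputs
    FUNCTION = "loadmodel"
    CATEGORY = "MochiWrapper"

    def loadmodel(self, precision, attention_mode="sdpa"):
        # The real node forwards attention_mode into T2VSynthMochiModel (see the hunks above).
        ...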
@@ -72,7 +71,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
 
-    def loadmodel(self, model, vae, precision, attention_mode):
+    def loadmodel(self, model, vae, precision):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -116,7 +115,6 @@ class DownloadAndLoadMochiModel:
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
-            attention_mode=attention_mode
             )
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(