diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index b0f89cc..4150dc8 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -6,7 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from torch.nn.attention import sdpa_kernel
+from torch.nn.attention import sdpa_kernel, SDPBackend
 
 from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
 from .layers import (
@@ -45,6 +45,12 @@ except ImportError:
 COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
 COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
 
+backends = []
+if torch.cuda.get_device_properties(0).major < 7:
+    backends.append(SDPBackend.MATH)
+else:
+    backends.append(SDPBackend.EFFICIENT_ATTENTION)
+
 
 class AsymmetricAttention(nn.Module):
     def __init__(
@@ -180,15 +186,16 @@ class AsymmetricAttention(nn.Module):
     def sdpa_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
         with torch.autocast("cuda", enabled=False):
-            out = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-        return rearrange(out, 'b h s d -> s (b h d)')
+            with sdpa_kernel(backends):
+                out = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False
+                )
+            return rearrange(out, 'b h s d -> s (b h d)')
 
     def sage_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
@@ -202,6 +209,19 @@ class AsymmetricAttention(nn.Module):
             is_causal=False
         )
         return rearrange(out, 'b h s d -> s (b h d)')
+
+    def comfy_attention(self, qkv):
+        from comfy.ldm.modules.attention import optimized_attention
+        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        with torch.autocast("cuda", enabled=False):
+            out = optimized_attention(
+                q,
+                k,
+                v,
+                heads = self.num_heads,
+                skip_reshape=True
+            )
+        return out.squeeze(0)
 
     @torch.compiler.disable()
     def run_attention(
@@ -228,6 +248,8 @@ class AsymmetricAttention(nn.Module):
             out = self.sdpa_attention(qkv)
         elif self.attention_mode == "sage_attn":
             out = self.sage_attention(qkv)
+        elif self.attention_mode == "comfy":
+            out = self.comfy_attention(qkv)
 
         x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
         assert x.size() == (B, N, local_dim)
@@ -642,7 +664,7 @@ class AsymmDiTJoint(nn.Module):
 
         # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
         # Have to call sdpa_kernel outside of a torch.compile region.
-        with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+        with sdpa_kernel(backends):
             x, c, y_feat, rope_cos, rope_sin = self.prepare(
                 x, sigma, y_feat[0], y_mask[0]
             )
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index de6ea8d..116665f 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -211,8 +211,7 @@ class T2VSynthMochiModel:
             if isinstance(sample[key], torch.Tensor):
                 sample[key] = sample[key].to(self.device, non_blocking=True)
 
-    @torch.inference_mode(mode=True)
-    def run(self, args, stream_results):
+    def run(self, args):
         torch.manual_seed(args["seed"])
         torch.cuda.manual_seed(args["seed"])
 
diff --git a/nodes.py b/nodes.py
index 43aa1b2..7d19db2 100644
--- a/nodes.py
+++ b/nodes.py
@@ -61,7 +61,7 @@ class DownloadAndLoadMochiModel:
                     ),
             "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
                     {"default": "fp8_e4m3fn", }),
-            "attention_mode": (["sdpa","flash_attn","sage_attn"],
+            "attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],
                     ),
             },
         }
@@ -167,7 +167,7 @@ class MochiTextEncode:
         max_tokens = 256
         load_device = mm.text_encoder_device()
         offload_device = mm.text_encoder_offload_device()
-        #print(clip.tokenizer.t5xxl)
+        clip.tokenizer.t5xxl.pad_to_max_length = True
         clip.tokenizer.t5xxl.max_length = max_tokens
         clip.cond_stage_model.t5xxl.return_attention_masks = True
 
@@ -176,8 +176,10 @@ class MochiTextEncode:
         clip.cond_stage_model.to(load_device)
         tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
 
-        embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
-
+        try:
+            embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
+        except:
+            raise NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
         if embeds.shape[1] > 256:
             raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
 
@@ -234,7 +236,7 @@ class MochiSampler:
             "negative_embeds": negative,
             "seed": seed,
         }
-        latents = model.run(args, stream_results=False)
+        latents = model.run(args)
 
         mm.soft_empty_cache()
 
@@ -265,6 +267,7 @@ class MochiDecode:
                 tile_overlap_factor_width, auto_tile_size, frame_batch_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
+        intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
         samples = samples.to(torch.bfloat16).to(device)
 
@@ -347,23 +350,23 @@ class MochiDecode:
         vae.to(device)
         with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
             if enable_vae_tiling and frame_batch_size > T:
-                logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
-                samples = vae(samples)
+                logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
+                frame_batch_size = T
+                frames = decode_tiled(samples)
             elif not enable_vae_tiling:
                 logging.warning("Attempting to decode without tiling, very memory intensive")
-                samples = vae(samples)
+                frames = vae(samples)
             else:
                 logging.info("Decoding with tiling")
-                samples = decode_tiled(samples)
+                frames = decode_tiled(samples)
         vae.to(offload_device)
 
-        samples = samples.float()
-        samples = (samples + 1.0) / 2.0
-        samples.clamp_(0.0, 1.0)
+        frames = frames.float()
+        frames = (frames + 1.0) / 2.0
+        frames.clamp_(0.0, 1.0)
 
-        frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
-        #print(frames.shape)
+        frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
 
         return (frames,)
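
Note (not part of the patch): a minimal standalone sketch of the backend-selection pattern the asymm_models_joint.py changes introduce. It assumes PyTorch 2.3 or newer, where torch.nn.attention.sdpa_kernel accepts a list of backends; the tensor shapes below are made up for illustration.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Mirror the patch: use the MATH backend on GPUs with compute capability < 7,
# otherwise the memory-efficient backend.
backends = []
if torch.cuda.get_device_properties(0).major < 7:
    backends.append(SDPBackend.MATH)
else:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)

q = torch.randn(1, 24, 256, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# sdpa_kernel restricts scaled_dot_product_attention to the listed backends
# for everything inside the context manager.
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 24, 256, 64])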
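
Note (not part of the patch): the new comfy_attention path returns out.squeeze(0) instead of the rearrange used by the sdpa and sage paths. The sketch below checks that the two layouts agree for batch size 1, under the assumption that ComfyUI's optimized_attention with skip_reshape=True returns tensors shaped (batch, seq, heads * dim_head).

import torch
from einops import rearrange

b, h, s, d = 1, 24, 128, 64
out = torch.randn(b, h, s, d)  # layout produced by F.scaled_dot_product_attention

sdpa_layout = rearrange(out, 'b h s d -> s (b h d)')                # sdpa / sage paths
comfy_layout = out.transpose(1, 2).reshape(b, s, h * d).squeeze(0)  # assumed optimized_attention output, then squeeze(0)

assert torch.equal(sdpa_layout, comfy_layout)  # identical when b == 1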
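
Note (not part of the patch): a self-contained sketch of the post-decode handling the nodes.py changes settle on: the decoder output in [-1, 1] is renamed to frames, mapped to [0, 1], clamped, and rearranged to (t b) h w c for downstream image nodes. The random tensor stands in for the VAE output; its shape is illustrative.

import torch
from einops import rearrange

frames = torch.rand(1, 3, 16, 60, 104) * 2.0 - 1.0  # stand-in for vae/decode_tiled output: b c t h w in [-1, 1]

frames = frames.float()
frames = (frames + 1.0) / 2.0   # [-1, 1] -> [0, 1]
frames.clamp_(0.0, 1.0)

frames = rearrange(frames, "b c t h w -> (t b) h w c")  # ComfyUI image batch layout
print(frames.shape)  # torch.Size([16, 60, 104, 3])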