cleanup, possibly support older GPUs

This commit is contained in:
kijai 2024-10-24 14:27:11 +03:00
parent 257c526125
commit d699fae213
3 changed files with 51 additions and 27 deletions

View File

@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch.nn.attention import sdpa_kernel from torch.nn.attention import sdpa_kernel, SDPBackend
from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
from .layers import ( from .layers import (
@ -45,6 +45,12 @@ except ImportError:
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1" COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1" COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
backends = []
if torch.cuda.get_device_properties(0).major < 7:
backends.append(SDPBackend.MATH)
else:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
class AsymmetricAttention(nn.Module): class AsymmetricAttention(nn.Module):
def __init__( def __init__(
@ -180,15 +186,16 @@ class AsymmetricAttention(nn.Module):
def sdpa_attention(self, qkv): def sdpa_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1) q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False): with torch.autocast("cuda", enabled=False):
out = F.scaled_dot_product_attention( with sdpa_kernel(backends):
q, out = F.scaled_dot_product_attention(
k, q,
v, k,
attn_mask=None, v,
dropout_p=0.0, attn_mask=None,
is_causal=False dropout_p=0.0,
) is_causal=False
return rearrange(out, 'b h s d -> s (b h d)') )
return rearrange(out, 'b h s d -> s (b h d)')
def sage_attention(self, qkv): def sage_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1) q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
@ -202,6 +209,19 @@ class AsymmetricAttention(nn.Module):
is_causal=False is_causal=False
) )
return rearrange(out, 'b h s d -> s (b h d)') return rearrange(out, 'b h s d -> s (b h d)')
def comfy_attention(self, qkv):
from comfy.ldm.modules.attention import optimized_attention
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
out = optimized_attention(
q,
k,
v,
heads = self.num_heads,
skip_reshape=True
)
return out.squeeze(0)
@torch.compiler.disable() @torch.compiler.disable()
def run_attention( def run_attention(
@ -228,6 +248,8 @@ class AsymmetricAttention(nn.Module):
out = self.sdpa_attention(qkv) out = self.sdpa_attention(qkv)
elif self.attention_mode == "sage_attn": elif self.attention_mode == "sage_attn":
out = self.sage_attention(qkv) out = self.sage_attention(qkv)
elif self.attention_mode == "comfy":
out = self.comfy_attention(qkv)
x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
assert x.size() == (B, N, local_dim) assert x.size() == (B, N, local_dim)
@ -642,7 +664,7 @@ class AsymmDiTJoint(nn.Module):
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask. # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region. # Have to call sdpa_kernel outside of a torch.compile region.
with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): with sdpa_kernel(backends):
x, c, y_feat, rope_cos, rope_sin = self.prepare( x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0] x, sigma, y_feat[0], y_mask[0]
) )

View File

@ -211,8 +211,7 @@ class T2VSynthMochiModel:
if isinstance(sample[key], torch.Tensor): if isinstance(sample[key], torch.Tensor):
sample[key] = sample[key].to(self.device, non_blocking=True) sample[key] = sample[key].to(self.device, non_blocking=True)
@torch.inference_mode(mode=True) def run(self, args):
def run(self, args, stream_results):
torch.manual_seed(args["seed"]) torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"]) torch.cuda.manual_seed(args["seed"])

View File

@ -61,7 +61,7 @@ class DownloadAndLoadMochiModel:
), ),
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"], "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
{"default": "fp8_e4m3fn", }), {"default": "fp8_e4m3fn", }),
"attention_mode": (["sdpa","flash_attn","sage_attn"], "attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],
), ),
}, },
} }
@ -167,7 +167,7 @@ class MochiTextEncode:
max_tokens = 256 max_tokens = 256
load_device = mm.text_encoder_device() load_device = mm.text_encoder_device()
offload_device = mm.text_encoder_offload_device() offload_device = mm.text_encoder_offload_device()
#print(clip.tokenizer.t5xxl)
clip.tokenizer.t5xxl.pad_to_max_length = True clip.tokenizer.t5xxl.pad_to_max_length = True
clip.tokenizer.t5xxl.max_length = max_tokens clip.tokenizer.t5xxl.max_length = max_tokens
clip.cond_stage_model.t5xxl.return_attention_masks = True clip.cond_stage_model.t5xxl.return_attention_masks = True
@ -176,8 +176,10 @@ class MochiTextEncode:
clip.cond_stage_model.to(load_device) clip.cond_stage_model.to(load_device)
tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True) tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens) try:
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
except:
NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
if embeds.shape[1] > 256: if embeds.shape[1] > 256:
raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}") raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
@ -234,7 +236,7 @@ class MochiSampler:
"negative_embeds": negative, "negative_embeds": negative,
"seed": seed, "seed": seed,
} }
latents = model.run(args, stream_results=False) latents = model.run(args)
mm.soft_empty_cache() mm.soft_empty_cache()
@ -265,6 +267,7 @@ class MochiDecode:
tile_overlap_factor_width, auto_tile_size, frame_batch_size): tile_overlap_factor_width, auto_tile_size, frame_batch_size):
device = mm.get_torch_device() device = mm.get_torch_device()
offload_device = mm.unet_offload_device() offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"] samples = samples["samples"]
samples = samples.to(torch.bfloat16).to(device) samples = samples.to(torch.bfloat16).to(device)
@ -347,23 +350,23 @@ class MochiDecode:
vae.to(device) vae.to(device)
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16): with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
if enable_vae_tiling and frame_batch_size > T: if enable_vae_tiling and frame_batch_size > T:
logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling") logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
samples = vae(samples) frame_batch_size = T
frames = decode_tiled(samples)
elif not enable_vae_tiling: elif not enable_vae_tiling:
logging.warning("Attempting to decode without tiling, very memory intensive") logging.warning("Attempting to decode without tiling, very memory intensive")
samples = vae(samples) frames = vae(samples)
else: else:
logging.info("Decoding with tiling") logging.info("Decoding with tiling")
samples = decode_tiled(samples) frames = decode_tiled(samples)
vae.to(offload_device) vae.to(offload_device)
samples = samples.float() frames = frames.float()
samples = (samples + 1.0) / 2.0 frames = (frames + 1.0) / 2.0
samples.clamp_(0.0, 1.0) frames.clamp_(0.0, 1.0)
frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float() frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
#print(frames.shape)
return (frames,) return (frames,)