cleanup, possibly support older GPUs

This commit is contained in:
kijai 2024-10-24 14:27:11 +03:00
parent 257c526125
commit d699fae213
3 changed files with 51 additions and 27 deletions

View File

@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention import sdpa_kernel
from torch.nn.attention import sdpa_kernel, SDPBackend
from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
from .layers import (
@ -45,6 +45,12 @@ except ImportError:
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
backends = []
if torch.cuda.get_device_properties(0).major < 7:
backends.append(SDPBackend.MATH)
else:
backends.append(SDPBackend.EFFICIENT_ATTENTION)
class AsymmetricAttention(nn.Module):
def __init__(
@ -180,15 +186,16 @@ class AsymmetricAttention(nn.Module):
def sdpa_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
return rearrange(out, 'b h s d -> s (b h d)')
with sdpa_kernel(backends):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
return rearrange(out, 'b h s d -> s (b h d)')
def sage_attention(self, qkv):
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
@ -202,6 +209,19 @@ class AsymmetricAttention(nn.Module):
is_causal=False
)
return rearrange(out, 'b h s d -> s (b h d)')
def comfy_attention(self, qkv):
from comfy.ldm.modules.attention import optimized_attention
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
with torch.autocast("cuda", enabled=False):
out = optimized_attention(
q,
k,
v,
heads = self.num_heads,
skip_reshape=True
)
return out.squeeze(0)
@torch.compiler.disable()
def run_attention(
@ -228,6 +248,8 @@ class AsymmetricAttention(nn.Module):
out = self.sdpa_attention(qkv)
elif self.attention_mode == "sage_attn":
out = self.sage_attention(qkv)
elif self.attention_mode == "comfy":
out = self.comfy_attention(qkv)
x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
assert x.size() == (B, N, local_dim)
@ -642,7 +664,7 @@ class AsymmDiTJoint(nn.Module):
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region.
with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
with sdpa_kernel(backends):
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0]
)

View File

@ -211,8 +211,7 @@ class T2VSynthMochiModel:
if isinstance(sample[key], torch.Tensor):
sample[key] = sample[key].to(self.device, non_blocking=True)
@torch.inference_mode(mode=True)
def run(self, args, stream_results):
def run(self, args):
torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"])

View File

@ -61,7 +61,7 @@ class DownloadAndLoadMochiModel:
),
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
{"default": "fp8_e4m3fn", }),
"attention_mode": (["sdpa","flash_attn","sage_attn"],
"attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],
),
},
}
@ -167,7 +167,7 @@ class MochiTextEncode:
max_tokens = 256
load_device = mm.text_encoder_device()
offload_device = mm.text_encoder_offload_device()
#print(clip.tokenizer.t5xxl)
clip.tokenizer.t5xxl.pad_to_max_length = True
clip.tokenizer.t5xxl.max_length = max_tokens
clip.cond_stage_model.t5xxl.return_attention_masks = True
@ -176,8 +176,10 @@ class MochiTextEncode:
clip.cond_stage_model.to(load_device)
tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
try:
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
except:
NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
if embeds.shape[1] > 256:
raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
@ -234,7 +236,7 @@ class MochiSampler:
"negative_embeds": negative,
"seed": seed,
}
latents = model.run(args, stream_results=False)
latents = model.run(args)
mm.soft_empty_cache()
@ -265,6 +267,7 @@ class MochiDecode:
tile_overlap_factor_width, auto_tile_size, frame_batch_size):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
intermediate_device = mm.intermediate_device()
samples = samples["samples"]
samples = samples.to(torch.bfloat16).to(device)
@ -347,23 +350,23 @@ class MochiDecode:
vae.to(device)
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
if enable_vae_tiling and frame_batch_size > T:
logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
samples = vae(samples)
logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
frame_batch_size = T
frames = decode_tiled(samples)
elif not enable_vae_tiling:
logging.warning("Attempting to decode without tiling, very memory intensive")
samples = vae(samples)
frames = vae(samples)
else:
logging.info("Decoding with tiling")
samples = decode_tiled(samples)
frames = decode_tiled(samples)
vae.to(offload_device)
samples = samples.float()
samples = (samples + 1.0) / 2.0
samples.clamp_(0.0, 1.0)
frames = frames.float()
frames = (frames + 1.0) / 2.0
frames.clamp_(0.0, 1.0)
frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
#print(frames.shape)
frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
return (frames,)