cleanup, possibly support older GPUs
This commit is contained in:
parent
257c526125
commit
d699fae213
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from torch.nn.attention import sdpa_kernel
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
|
||||
from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
|
||||
from .layers import (
|
||||
@ -45,6 +45,12 @@ except ImportError:
|
||||
COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
|
||||
COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
|
||||
|
||||
backends = []
|
||||
if torch.cuda.get_device_properties(0).major < 7:
|
||||
backends.append(SDPBackend.MATH)
|
||||
else:
|
||||
backends.append(SDPBackend.EFFICIENT_ATTENTION)
|
||||
|
||||
|
||||
class AsymmetricAttention(nn.Module):
|
||||
def __init__(
|
||||
@ -180,15 +186,16 @@ class AsymmetricAttention(nn.Module):
|
||||
def sdpa_attention(self, qkv):
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
with sdpa_kernel(backends):
|
||||
out = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
|
||||
def sage_attention(self, qkv):
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
@ -202,6 +209,19 @@ class AsymmetricAttention(nn.Module):
|
||||
is_causal=False
|
||||
)
|
||||
return rearrange(out, 'b h s d -> s (b h d)')
|
||||
|
||||
def comfy_attention(self, qkv):
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
out = optimized_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads = self.num_heads,
|
||||
skip_reshape=True
|
||||
)
|
||||
return out.squeeze(0)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def run_attention(
|
||||
@ -228,6 +248,8 @@ class AsymmetricAttention(nn.Module):
|
||||
out = self.sdpa_attention(qkv)
|
||||
elif self.attention_mode == "sage_attn":
|
||||
out = self.sage_attention(qkv)
|
||||
elif self.attention_mode == "comfy":
|
||||
out = self.comfy_attention(qkv)
|
||||
|
||||
x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
|
||||
assert x.size() == (B, N, local_dim)
|
||||
@ -642,7 +664,7 @@ class AsymmDiTJoint(nn.Module):
|
||||
|
||||
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
|
||||
# Have to call sdpa_kernel outside of a torch.compile region.
|
||||
with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
|
||||
with sdpa_kernel(backends):
|
||||
x, c, y_feat, rope_cos, rope_sin = self.prepare(
|
||||
x, sigma, y_feat[0], y_mask[0]
|
||||
)
|
||||
|
||||
@ -211,8 +211,7 @@ class T2VSynthMochiModel:
|
||||
if isinstance(sample[key], torch.Tensor):
|
||||
sample[key] = sample[key].to(self.device, non_blocking=True)
|
||||
|
||||
@torch.inference_mode(mode=True)
|
||||
def run(self, args, stream_results):
|
||||
def run(self, args):
|
||||
torch.manual_seed(args["seed"])
|
||||
torch.cuda.manual_seed(args["seed"])
|
||||
|
||||
|
||||
31
nodes.py
31
nodes.py
@ -61,7 +61,7 @@ class DownloadAndLoadMochiModel:
|
||||
),
|
||||
"precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
|
||||
{"default": "fp8_e4m3fn", }),
|
||||
"attention_mode": (["sdpa","flash_attn","sage_attn"],
|
||||
"attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],
|
||||
),
|
||||
},
|
||||
}
|
||||
@ -167,7 +167,7 @@ class MochiTextEncode:
|
||||
max_tokens = 256
|
||||
load_device = mm.text_encoder_device()
|
||||
offload_device = mm.text_encoder_offload_device()
|
||||
#print(clip.tokenizer.t5xxl)
|
||||
|
||||
clip.tokenizer.t5xxl.pad_to_max_length = True
|
||||
clip.tokenizer.t5xxl.max_length = max_tokens
|
||||
clip.cond_stage_model.t5xxl.return_attention_masks = True
|
||||
@ -176,8 +176,10 @@ class MochiTextEncode:
|
||||
clip.cond_stage_model.to(load_device)
|
||||
tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
|
||||
|
||||
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
|
||||
|
||||
try:
|
||||
embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
|
||||
except:
|
||||
NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
|
||||
|
||||
if embeds.shape[1] > 256:
|
||||
raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
|
||||
@ -234,7 +236,7 @@ class MochiSampler:
|
||||
"negative_embeds": negative,
|
||||
"seed": seed,
|
||||
}
|
||||
latents = model.run(args, stream_results=False)
|
||||
latents = model.run(args)
|
||||
|
||||
mm.soft_empty_cache()
|
||||
|
||||
@ -265,6 +267,7 @@ class MochiDecode:
|
||||
tile_overlap_factor_width, auto_tile_size, frame_batch_size):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
intermediate_device = mm.intermediate_device()
|
||||
samples = samples["samples"]
|
||||
samples = samples.to(torch.bfloat16).to(device)
|
||||
|
||||
@ -347,23 +350,23 @@ class MochiDecode:
|
||||
vae.to(device)
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
||||
if enable_vae_tiling and frame_batch_size > T:
|
||||
logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
|
||||
samples = vae(samples)
|
||||
logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
|
||||
frame_batch_size = T
|
||||
frames = decode_tiled(samples)
|
||||
elif not enable_vae_tiling:
|
||||
logging.warning("Attempting to decode without tiling, very memory intensive")
|
||||
samples = vae(samples)
|
||||
frames = vae(samples)
|
||||
else:
|
||||
logging.info("Decoding with tiling")
|
||||
samples = decode_tiled(samples)
|
||||
frames = decode_tiled(samples)
|
||||
|
||||
vae.to(offload_device)
|
||||
|
||||
samples = samples.float()
|
||||
samples = (samples + 1.0) / 2.0
|
||||
samples.clamp_(0.0, 1.0)
|
||||
frames = frames.float()
|
||||
frames = (frames + 1.0) / 2.0
|
||||
frames.clamp_(0.0, 1.0)
|
||||
|
||||
frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
|
||||
#print(frames.shape)
|
||||
frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
|
||||
|
||||
return (frames,)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user