cleanup, possibly support older GPUs
parent 257c526125
commit d699fae213
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from torch.nn.attention import sdpa_kernel
+from torch.nn.attention import sdpa_kernel, SDPBackend
 
 from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
 from .layers import (
@@ -45,6 +45,12 @@ except ImportError:
 COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
 COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
 
+backends = []
+if torch.cuda.get_device_properties(0).major < 7:
+    backends.append(SDPBackend.MATH)
+else:
+    backends.append(SDPBackend.EFFICIENT_ATTENTION)
+
 
 class AsymmetricAttention(nn.Module):
     def __init__(
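The `backends` list added above is what ties this commit to older GPUs: devices reporting a CUDA compute capability major version below 7 (pre-Volta) are routed to the MATH backend, presumably because the fused efficient-attention kernel is not usable on those cards. Below is a minimal, self-contained sketch of the same pattern; it is not taken from the repository, and the shapes and dtypes are arbitrary illustration values.

```python
# Standalone sketch of the backend-selection pattern used above.
# Assumes a CUDA build of PyTorch 2.3+ (torch.nn.attention module).
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

backends = []
if torch.cuda.get_device_properties(0).major < 7:
    backends.append(SDPBackend.MATH)          # pre-Volta: plain math implementation
else:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)

# Arbitrary illustrative shapes: (batch, heads, seq, head_dim)
q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Inside the context manager, scaled_dot_product_attention may only use the listed backends.
with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```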
@@ -180,15 +186,16 @@ class AsymmetricAttention(nn.Module):
     def sdpa_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
         with torch.autocast("cuda", enabled=False):
-            out = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-        return rearrange(out, 'b h s d -> s (b h d)')
+            with sdpa_kernel(backends):
+                out = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False
+                )
+            return rearrange(out, 'b h s d -> s (b h d)')
 
     def sage_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
@@ -202,6 +209,19 @@ class AsymmetricAttention(nn.Module):
             is_causal=False
         )
         return rearrange(out, 'b h s d -> s (b h d)')
 
+    def comfy_attention(self, qkv):
+        from comfy.ldm.modules.attention import optimized_attention
+        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
+        with torch.autocast("cuda", enabled=False):
+            out = optimized_attention(
+                q,
+                k,
+                v,
+                heads = self.num_heads,
+                skip_reshape=True
+            )
+        return out.squeeze(0)
+
     @torch.compiler.disable()
     def run_attention(
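The new `comfy_attention` path delegates to ComfyUI's `optimized_attention`, which dispatches to whatever attention implementation the running ComfyUI install provides. The sketch below is an illustration only (it assumes a working ComfyUI environment and the usual `optimized_attention(q, k, v, heads, ..., skip_reshape=...)` signature) and shows why the method ends with `out.squeeze(0)`: with `skip_reshape=True` the inputs stay in `(batch, heads, seq, head_dim)` layout and the result comes back as `(batch, seq, heads * head_dim)`, so squeezing the size-1 batch dimension matches the `s (b h d)` output of the sdpa path.

```python
# Illustrative only; requires ComfyUI on the Python path. Shapes are made up.
import torch
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention

num_heads, head_dim, seq_len = 8, 64, 128
# The wrapper stores packed qkv as '(b s) t h d' with b == 1.
qkv = torch.randn(seq_len, 3, num_heads, head_dim, device="cuda", dtype=torch.float16)

q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)               # each (1, heads, seq, head_dim)
out = optimized_attention(q, k, v, heads=num_heads, skip_reshape=True)  # (1, seq, heads * head_dim)
out = out.squeeze(0)                                                    # (seq, heads * head_dim)
```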
@@ -228,6 +248,8 @@ class AsymmetricAttention(nn.Module):
             out = self.sdpa_attention(qkv)
         elif self.attention_mode == "sage_attn":
             out = self.sage_attention(qkv)
+        elif self.attention_mode == "comfy":
+            out = self.comfy_attention(qkv)
 
         x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
         assert x.size() == (B, N, local_dim)
@@ -642,7 +664,7 @@ class AsymmDiTJoint(nn.Module):
 
         # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
         # Have to call sdpa_kernel outside of a torch.compile region.
-        with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+        with sdpa_kernel(backends):
             x, c, y_feat, rope_cos, rope_sin = self.prepare(
                 x, sigma, y_feat[0], y_mask[0]
             )
@@ -211,8 +211,7 @@ class T2VSynthMochiModel:
            if isinstance(sample[key], torch.Tensor):
                sample[key] = sample[key].to(self.device, non_blocking=True)
 
-    @torch.inference_mode(mode=True)
-    def run(self, args, stream_results):
+    def run(self, args):
         torch.manual_seed(args["seed"])
         torch.cuda.manual_seed(args["seed"])
 
nodes.py (31 changed lines)
@@ -61,7 +61,7 @@ class DownloadAndLoadMochiModel:
                 ),
                 "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
                     {"default": "fp8_e4m3fn", }),
-                "attention_mode": (["sdpa","flash_attn","sage_attn"],
+                "attention_mode": (["sdpa","flash_attn","sage_attn", "comfy"],
                 ),
             },
         }
@@ -167,7 +167,7 @@ class MochiTextEncode:
         max_tokens = 256
         load_device = mm.text_encoder_device()
         offload_device = mm.text_encoder_offload_device()
-        #print(clip.tokenizer.t5xxl)
+
         clip.tokenizer.t5xxl.pad_to_max_length = True
         clip.tokenizer.t5xxl.max_length = max_tokens
         clip.cond_stage_model.t5xxl.return_attention_masks = True
@@ -176,8 +176,10 @@ class MochiTextEncode:
         clip.cond_stage_model.to(load_device)
         tokens = clip.tokenizer.t5xxl.tokenize_with_weights(prompt, return_word_ids=True)
 
-        embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
+        try:
+            embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
+        except:
+            NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
 
         if embeds.shape[1] > 256:
             raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
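One caveat with the new `try`/`except` above: the `NotImplementedError` is constructed but never raised, and the bare `except` swallows the original error, so on an older ComfyUI the code would continue with `embeds` undefined and fail later with a less helpful message. An editorial sketch of a stricter variant (not part of the commit):

```python
try:
    embeds, _, attention_mask = clip.cond_stage_model.t5xxl.encode_token_weights(tokens)
except Exception as e:
    # Re-raise with a hint instead of silently continuing with undefined names.
    raise NotImplementedError(
        "Failed to get attention mask from T5, is your ComfyUI up to date?"
    ) from e
```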
@@ -234,7 +236,7 @@ class MochiSampler:
             "negative_embeds": negative,
             "seed": seed,
         }
-        latents = model.run(args, stream_results=False)
+        latents = model.run(args)
 
         mm.soft_empty_cache()
 
@@ -265,6 +267,7 @@ class MochiDecode:
                 tile_overlap_factor_width, auto_tile_size, frame_batch_size):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
+        intermediate_device = mm.intermediate_device()
         samples = samples["samples"]
         samples = samples.to(torch.bfloat16).to(device)
 
@@ -347,23 +350,23 @@ class MochiDecode:
         vae.to(device)
         with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
             if enable_vae_tiling and frame_batch_size > T:
-                logging.warning(f"Frame batch size is larger than the number of samples ({T}), disabling tiling")
-                samples = vae(samples)
+                logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
+                frame_batch_size = T
+                frames = decode_tiled(samples)
             elif not enable_vae_tiling:
                 logging.warning("Attempting to decode without tiling, very memory intensive")
-                samples = vae(samples)
+                frames = vae(samples)
             else:
                 logging.info("Decoding with tiling")
-                samples = decode_tiled(samples)
+                frames = decode_tiled(samples)
 
         vae.to(offload_device)
 
-        samples = samples.float()
-        samples = (samples + 1.0) / 2.0
-        samples.clamp_(0.0, 1.0)
+        frames = frames.float()
+        frames = (frames + 1.0) / 2.0
+        frames.clamp_(0.0, 1.0)
 
-        frames = rearrange(samples, "b c t h w -> (t b) h w c").cpu().float()
-        #print(frames.shape)
+        frames = rearrange(frames, "b c t h w -> (t b) h w c").to(intermediate_device)
 
         return (frames,)
 
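For context, the tail of `MochiDecode` now keeps the decoded tensor under the name `frames` and, instead of moving it to CPU, sends it to ComfyUI's intermediate device. The standalone sketch below (illustrative shapes only, not repository code) walks through the same post-processing: the VAE output in roughly [-1, 1] is rescaled to [0, 1] and rearranged into ComfyUI's `(N, H, W, C)` IMAGE layout with time and batch flattened together.

```python
import torch
from einops import rearrange

# Pretend VAE output: (batch, channels, frames, height, width), values roughly in [-1, 1].
frames = torch.randn(1, 3, 16, 480, 848, dtype=torch.bfloat16)

frames = frames.float()
frames = (frames + 1.0) / 2.0   # [-1, 1] -> [0, 1]
frames.clamp_(0.0, 1.0)

# Flatten time and batch into ComfyUI's IMAGE layout (N, H, W, C).
frames = rearrange(frames, "b c t h w -> (t b) h w c")
print(frames.shape)  # torch.Size([16, 480, 848, 3])
```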