cleanup, fix untiled spatial vae decode

kijai 2024-10-30 21:12:34 +02:00
parent 3dce06b28b
commit f0f939b20b
2 changed files with 28 additions and 70 deletions

View File

@ -234,7 +234,6 @@ class T2VSynthMochiModel:
height = args["height"] height = args["height"]
width = args["width"] width = args["width"]
batch_cfg = args["mochi_args"]["batch_cfg"]
sample_steps = args["mochi_args"]["num_inference_steps"] sample_steps = args["mochi_args"]["num_inference_steps"]
cfg_schedule = args["mochi_args"].get("cfg_schedule") cfg_schedule = args["mochi_args"].get("cfg_schedule")
assert ( assert (
@@ -247,14 +246,6 @@ class T2VSynthMochiModel:
        ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
        assert (num_frames - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {num_frames - 1}"
-       # if batch_cfg:
-       #     sample_batched = self.get_conditioning(
-       #         [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
-       #     )
-       # else:
-       #     sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
-       #     sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
        # create z
        spatial_downsample = 8
        temporal_downsample = 6
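As a quick sanity check on the frame-count assertion above (illustrative only, not part of the diff): valid num_frames values are of the form 6*k + 1, matching the temporal downsample factor of 6.

    # Illustrative check of the num_frames constraint: (t - 1) must be divisible by 6
    valid_frame_counts = [6 * k + 1 for k in range(1, 6)]  # [7, 13, 19, 25, 31]
    assert all((t - 1) % 6 == 0 for t in valid_frame_counts)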
@@ -273,45 +264,21 @@ class T2VSynthMochiModel:
            dtype=torch.float32,
        )
-       if batch_cfg: #WIP
-           pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
-           neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
-           pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
-           neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
-           print(neg_embeds.shape)
-           y_feat = torch.cat((pos_embeds, neg_embeds))
-           y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
-           zero_last_n_prompts = B# if neg_prompt == "" else 0
-           y_feat[-zero_last_n_prompts:] = 0
-           y_mask[-zero_last_n_prompts:] = False
-           sample_batched = {
-               "y_mask": [y_mask],
-               "y_feat": [y_feat]
-           }
-           sample_batched["packed_indices"] = self.get_packed_indices(
-               sample_batched["y_mask"], **latent_dims
-           )
-           z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
-           print("sample_batched y_mask", sample_batched["y_mask"])
-           print("y_mask type", type(sample_batched["y_mask"]))  # <class 'list'>
-           print("ymask 0 shape", sample_batched["y_mask"][0].shape)  # torch.Size([2, 256])
-       else:
-           sample = {
-               "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-               "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
-           }
-           sample_null = {
-               "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
-               "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-           }
-           sample["packed_indices"] = self.get_packed_indices(
-               sample["y_mask"], **latent_dims
-           )
-           sample_null["packed_indices"] = self.get_packed_indices(
-               sample_null["y_mask"], **latent_dims
-           )
+       sample = {
+           "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
+           "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+       }
+       sample_null = {
+           "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
+           "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
+       }
+       sample["packed_indices"] = self.get_packed_indices(
+           sample["y_mask"], **latent_dims
+       )
+       sample_null["packed_indices"] = self.get_packed_indices(
+           sample_null["y_mask"], **latent_dims
+       )

        def model_fn(*, z, sigma, cfg_scale):
            self.dit.to(self.device)
@@ -319,19 +286,15 @@ class T2VSynthMochiModel:
                autocast_dtype = torch.float16
            else:
                autocast_dtype = torch.bfloat16
-           if batch_cfg:
-               with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                   out = self.dit(z, sigma, **sample_batched)
-               out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
-           else:
-               nonlocal sample, sample_null
-               with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                   if cfg_scale > 1.0:
-                       out_cond = self.dit(z, sigma, **sample)
-                       out_uncond = self.dit(z, sigma, **sample_null)
-                   else:
-                       out_cond = self.dit(z, sigma, **sample)
-                       return out_cond
+           nonlocal sample, sample_null
+           with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+               if cfg_scale > 1.0:
+                   out_cond = self.dit(z, sigma, **sample)
+                   out_uncond = self.dit(z, sigma, **sample_null)
+               else:
+                   out_cond = self.dit(z, sigma, **sample)
+                   return out_cond

            return out_uncond + cfg_scale * (out_cond - out_uncond)
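The removed batch_cfg path ran one DiT forward on a doubled batch and split the result with torch.chunk; the retained path makes two sequential calls (conditional, then unconditional) and combines them with the standard classifier-free guidance formula shown on the last line. A minimal self-contained sketch of that combination (the function name is illustrative, not the module's API):

    import torch

    def cfg_combine(out_cond: torch.Tensor, out_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # Standard classifier-free guidance: start from the unconditional prediction
        # and push toward the conditional one; cfg_scale == 1.0 returns out_cond exactly.
        return out_uncond + cfg_scale * (out_cond - out_uncond)

Running the two passes sequentially doubles the DiT calls per step but keeps peak activation memory at single-batch size, whereas the dropped batched variant paid double-batch memory for a single call.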
@@ -343,18 +306,13 @@ class T2VSynthMochiModel:
                # `pred` estimates `z_0 - eps`.
                pred = model_fn(
                    z=z,
-                   sigma=torch.full(
-                       [B] if not batch_cfg else [B * 2], sigma, device=z.device
-                   ),
+                   sigma=torch.full([B], sigma, device=z.device),
                    cfg_scale=cfg_schedule[i],
                )
                pred = pred.to(z)
                z = z + dsigma * pred
                comfy_pbar.update(1)
-           if batch_cfg:
-               z = z[:B]
            self.dit.to(self.offload_device)
            samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
            logging.info(f"samples shape: {samples.shape}")

View File

@@ -328,16 +328,16 @@ class MochiTextEncode:
            except:
                NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
        except:
-           clip.cond_stage_model.to(load_device)
+           clip.cond_stage_model.to(offload_device)
            tokens = clip.tokenizer.tokenize_with_weights(prompt, return_word_ids=True)
            embeds, _, attention_mask = clip.cond_stage_model.encode_token_weights(tokens)
        if embeds.shape[1] > 256:
            raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
        embeds *= strength
        if force_offload:
            clip.cond_stage_model.to(offload_device)
+           mm.soft_empty_cache()
        t5_embeds = {
            "embeds": embeds,
@@ -409,7 +409,6 @@ class MochiSampler:
                "sigma_schedule": sigma_schedule,
                "cfg_schedule": cfg_schedule,
                "num_inference_steps": steps,
-               "batch_cfg": False,
            },
            "positive_embeds": positive,
            "negative_embeds": negative,
@@ -597,10 +596,11 @@ class MochiDecodeSpatialTiling:
                decoded_list[-1][:, :, -1:, :, :] = blended_frames
                decoded_list.append(frames)
+           frames = torch.cat(decoded_list, dim=2)
        else:
            logging.info("Decoding without tiling...")
            frames = vae(samples)
-       frames = torch.cat(decoded_list, dim=2)

        vae.to(offload_device)
        frames = frames.float()
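This last hunk is the "fix untiled spatial vae decode" from the commit title: previously the torch.cat over decoded_list ran after the if/else, so it executed even on the untiled path where no tiles had been collected; the concatenation now lives inside the tiled branch only. A minimal sketch of the corrected control flow (the flag name and tiling loop are illustrative, not the node's exact implementation):

    import torch

    def decode(vae, samples, enable_vae_tiling, tiles=None):
        if enable_vae_tiling:
            # tiled path: decode each tile, then join along dim=2 (the frame axis)
            decoded_list = [vae(tile) for tile in tiles]
            frames = torch.cat(decoded_list, dim=2)
        else:
            # untiled path: the VAE already returns the full tensor,
            # so no concatenation over decoded_list is needed (or possible)
            frames = vae(samples)
        return frames.float()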