cleanup, fix untiled spatial vae decode

kijai 2024-10-30 21:12:34 +02:00
parent 3dce06b28b
commit f0f939b20b
2 changed files with 28 additions and 70 deletions


@@ -234,7 +234,6 @@ class T2VSynthMochiModel:
height = args["height"]
width = args["width"]
batch_cfg = args["mochi_args"]["batch_cfg"]
sample_steps = args["mochi_args"]["num_inference_steps"]
cfg_schedule = args["mochi_args"].get("cfg_schedule")
assert (
@@ -247,14 +246,6 @@
), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
assert (num_frames - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {num_frames - 1}"
# if batch_cfg:
# sample_batched = self.get_conditioning(
# [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
# )
# else:
# sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
# sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
# create z
spatial_downsample = 8
temporal_downsample = 6
@@ -273,45 +264,21 @@
dtype=torch.float32,
)
if batch_cfg: #WIP
pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
print(neg_embeds.shape)
y_feat = torch.cat((pos_embeds, neg_embeds))
y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
zero_last_n_prompts = B# if neg_prompt == "" else 0
y_feat[-zero_last_n_prompts:] = 0
y_mask[-zero_last_n_prompts:] = False
sample_batched = {
"y_mask": [y_mask],
"y_feat": [y_feat]
}
sample_batched["packed_indices"] = self.get_packed_indices(
sample_batched["y_mask"], **latent_dims
)
z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
print("sample_batched y_mask",sample_batched["y_mask"])
print("y_mask type",type(sample_batched["y_mask"])) #<class 'list'>"
print("ymask 0 shape",sample_batched["y_mask"][0].shape)#torch.Size([2, 256])
else:
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
}
sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
}
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
}
sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
}
sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims
)
sample_null["packed_indices"] = self.get_packed_indices(
sample_null["y_mask"], **latent_dims
)
sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims
)
sample_null["packed_indices"] = self.get_packed_indices(
sample_null["y_mask"], **latent_dims
)
def model_fn(*, z, sigma, cfg_scale):
self.dit.to(self.device)
@@ -319,19 +286,15 @@
autocast_dtype = torch.float16
else:
autocast_dtype = torch.bfloat16
if batch_cfg:
with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
out = self.dit(z, sigma, **sample_batched)
out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
else:
nonlocal sample, sample_null
with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
if cfg_scale > 1.0:
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
else:
out_cond = self.dit(z, sigma, **sample)
return out_cond
nonlocal sample, sample_null
with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
if cfg_scale > 1.0:
out_cond = self.dit(z, sigma, **sample)
out_uncond = self.dit(z, sigma, **sample_null)
else:
out_cond = self.dit(z, sigma, **sample)
return out_cond
return out_uncond + cfg_scale * (out_cond - out_uncond)
@@ -343,17 +306,12 @@
# `pred` estimates `z_0 - eps`.
pred = model_fn(
z=z,
sigma=torch.full(
[B] if not batch_cfg else [B * 2], sigma, device=z.device
),
sigma=torch.full([B], sigma, device=z.device),
cfg_scale=cfg_schedule[i],
)
pred = pred.to(z)
z = z + dsigma * pred
comfy_pbar.update(1)
if batch_cfg:
z = z[:B]
self.dit.to(self.offload_device)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
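
For reference, a minimal sketch of what the simplified sampling path above computes (the dit call signature, the sample / sample_null conditioning dicts, and the Euler-style update are taken from the diff; the standalone helper below is illustrative, not the repository's API):

import torch

def cfg_euler_step(dit, z, sigma, dsigma, cfg_scale, sample, sample_null):
    # Conditional prediction from the positive text conditioning.
    sigma_t = torch.full([z.shape[0]], sigma, device=z.device)
    out_cond = dit(z, sigma_t, **sample)
    if cfg_scale > 1.0:
        # Unconditional prediction, then the usual classifier-free guidance mix.
        out_uncond = dit(z, sigma_t, **sample_null)
        pred = out_uncond + cfg_scale * (out_cond - out_uncond)
    else:
        pred = out_cond
    # Euler-style update; `pred` estimates z_0 - eps, as noted in the diff.
    return z + dsigma * pred.to(z)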


@@ -328,16 +328,16 @@ class MochiTextEncode:
except:
NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
except:
clip.cond_stage_model.to(load_device)
clip.cond_stage_model.to(offload_device)
tokens = clip.tokenizer.tokenize_with_weights(prompt, return_word_ids=True)
embeds, _, attention_mask = clip.cond_stage_model.encode_token_weights(tokens)
if embeds.shape[1] > 256:
raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
embeds *= strength
if force_offload:
clip.cond_stage_model.to(offload_device)
mm.soft_empty_cache()
t5_embeds = {
"embeds": embeds,
@@ -409,7 +409,6 @@ class MochiSampler:
"sigma_schedule": sigma_schedule,
"cfg_schedule": cfg_schedule,
"num_inference_steps": steps,
"batch_cfg": False,
},
"positive_embeds": positive,
"negative_embeds": negative,
@@ -597,10 +596,11 @@ class MochiDecodeSpatialTiling:
decoded_list[-1][:, :, -1:, :, :] = blended_frames
decoded_list.append(frames)
frames = torch.cat(decoded_list, dim=2)
else:
logging.info("Decoding without tiling...")
frames = vae(samples)
frames = torch.cat(decoded_list, dim=2)
vae.to(offload_device)
frames = frames.float()
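
The untiled branch previously re-assigned frames from torch.cat(decoded_list, dim=2), even though decoded_list is only built on the tiled path; the fix keeps the plain vae(samples) result. A minimal sketch of the corrected control flow (vae, samples, and the tiled helper decode_tiled are stand-ins for the node's actual objects):

import logging
import torch

def decode_latents(vae, samples, use_tiling, decode_tiled):
    if use_tiling:
        # Tiled path: decode chunks, blend the overlapping frames, then
        # join the pieces along the frame (time) dimension.
        decoded_list = decode_tiled(vae, samples)  # list of [B, C, T_i, H, W] tensors
        frames = torch.cat(decoded_list, dim=2)
    else:
        # Untiled path: a single full decode; there is no decoded_list here,
        # so the stray torch.cat over it had to be removed.
        logging.info("Decoding without tiling...")
        frames = vae(samples)
    return frames.float()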