cleanup, fix untiled spatial vae decode
parent 3dce06b28b
commit f0f939b20b
@@ -234,7 +234,6 @@ class T2VSynthMochiModel:
         height = args["height"]
         width = args["width"]
 
-        batch_cfg = args["mochi_args"]["batch_cfg"]
         sample_steps = args["mochi_args"]["num_inference_steps"]
         cfg_schedule = args["mochi_args"].get("cfg_schedule")
         assert (
@@ -247,14 +246,6 @@ class T2VSynthMochiModel:
         ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
         assert (num_frames - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {num_frames - 1}"
 
-        # if batch_cfg:
-        #     sample_batched = self.get_conditioning(
-        #         [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
-        #     )
-        # else:
-        #     sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
-        #     sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
-
         # create z
         spatial_downsample = 8
         temporal_downsample = 6
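
For context: with the 8x spatial and 6x temporal compression above, the latent shape follows directly from the requested output size, which is also why the `(num_frames - 1) % 6 == 0` assertion exists. A minimal sketch; the variable names and the 163/480/848 defaults are illustrative, not taken from this diff:

```python
# Sketch: how the downsample factors map output size to latent size.
num_frames, height, width = 163, 480, 848
spatial_downsample, temporal_downsample = 8, 6

assert (num_frames - 1) % temporal_downsample == 0  # mirrors the assert above
latent_t = (num_frames - 1) // temporal_downsample + 1  # 28
latent_h = height // spatial_downsample                 # 60
latent_w = width // spatial_downsample                  # 106
```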
@@ -273,45 +264,21 @@ class T2VSynthMochiModel:
             dtype=torch.float32,
         )
 
-        if batch_cfg: #WIP
-            pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
-            neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
-            pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
-            neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
-            print(neg_embeds.shape)
-            y_feat = torch.cat((pos_embeds, neg_embeds))
-            y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
-            zero_last_n_prompts = B# if neg_prompt == "" else 0
-            y_feat[-zero_last_n_prompts:] = 0
-            y_mask[-zero_last_n_prompts:] = False
-
-            sample_batched = {
-                "y_mask": [y_mask],
-                "y_feat": [y_feat]
-            }
-            sample_batched["packed_indices"] = self.get_packed_indices(
-                sample_batched["y_mask"], **latent_dims
-            )
-            z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
-            print("sample_batched y_mask",sample_batched["y_mask"])
-            print("y_mask type",type(sample_batched["y_mask"])) #<class 'list'>
-            print("ymask 0 shape",sample_batched["y_mask"][0].shape) #torch.Size([2, 256])
-        else:
-            sample = {
-                "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-                "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
-            }
-            sample_null = {
-                "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
-                "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-            }
+        sample = {
+            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+        }
+        sample_null = {
+            "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
+        }
 
-            sample["packed_indices"] = self.get_packed_indices(
-                sample["y_mask"], **latent_dims
-            )
-            sample_null["packed_indices"] = self.get_packed_indices(
-                sample_null["y_mask"], **latent_dims
-            )
+        sample["packed_indices"] = self.get_packed_indices(
+            sample["y_mask"], **latent_dims
+        )
+        sample_null["packed_indices"] = self.get_packed_indices(
+            sample_null["y_mask"], **latent_dims
+        )
 
         def model_fn(*, z, sigma, cfg_scale):
             self.dit.to(self.device)
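
With the WIP batched path removed, the conditional and unconditional conditioning are always two dicts of the same shape. A toy sketch of what the DiT receives, assuming T5-XXL features (4096-dim, at most 256 tokens; both assumptions, not stated in this hunk). `get_packed_indices` is repo-specific and omitted:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Stand-ins for args["positive_embeds"] / args["negative_embeds"]:
pos = {"embeds": torch.randn(1, 256, 4096),
       "attention_mask": torch.ones(1, 256, dtype=torch.bool)}
neg = {"embeds": torch.zeros(1, 256, 4096),
       "attention_mask": torch.zeros(1, 256, dtype=torch.bool)}

sample = {"y_mask": [pos["attention_mask"].to(device)],
          "y_feat": [pos["embeds"].to(device)]}
sample_null = {"y_mask": [neg["attention_mask"].to(device)],
               "y_feat": [neg["embeds"].to(device)]}
```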
@@ -319,19 +286,15 @@ class T2VSynthMochiModel:
                 autocast_dtype = torch.float16
             else:
                 autocast_dtype = torch.bfloat16
-            if batch_cfg:
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    out = self.dit(z, sigma, **sample_batched)
-                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
-            else:
-                nonlocal sample, sample_null
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    if cfg_scale > 1.0:
-                        out_cond = self.dit(z, sigma, **sample)
-                        out_uncond = self.dit(z, sigma, **sample_null)
-                    else:
-                        out_cond = self.dit(z, sigma, **sample)
-                        return out_cond
+            nonlocal sample, sample_null
+            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+                if cfg_scale > 1.0:
+                    out_cond = self.dit(z, sigma, **sample)
+                    out_uncond = self.dit(z, sigma, **sample_null)
+                else:
+                    out_cond = self.dit(z, sigma, **sample)
+                    return out_cond
 
             return out_uncond + cfg_scale * (out_cond - out_uncond)
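
The final line is the standard classifier-free guidance blend. Note that at cfg_scale = 1.0 the blend degenerates to the conditional prediction alone, which is why the `cfg_scale > 1.0` check can return early and skip the unconditional forward pass, halving the DiT work per step:

```python
def cfg_combine(out_cond, out_uncond, cfg_scale):
    # Same formula as the return above; cfg_scale = 1.0 yields out_cond,
    # so the early return in model_fn is mathematically equivalent.
    return out_uncond + cfg_scale * (out_cond - out_uncond)
```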
@@ -343,17 +306,12 @@ class T2VSynthMochiModel:
             # `pred` estimates `z_0 - eps`.
             pred = model_fn(
                 z=z,
-                sigma=torch.full(
-                    [B] if not batch_cfg else [B * 2], sigma, device=z.device
-                ),
+                sigma=torch.full([B], sigma, device=z.device),
                 cfg_scale=cfg_schedule[i],
             )
             pred = pred.to(z)
             z = z + dsigma * pred
             comfy_pbar.update(1)
 
-        if batch_cfg:
-            z = z[:B]
-
         self.dit.to(self.offload_device)
         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
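
For orientation, this hunk sits inside the sampling loop, and `z = z + dsigma * pred` is one Euler step over the sigma schedule. A minimal sketch of the surrounding loop, assuming `sigma_schedule` has `sample_steps + 1` entries (as asserted earlier) and decreases toward 0; the loop structure itself is an assumption, as it is not shown in the hunk:

```python
for i in range(sample_steps):
    sigma = sigma_schedule[i]
    dsigma = sigma - sigma_schedule[i + 1]  # > 0 for a decreasing schedule
    pred = model_fn(
        z=z,
        sigma=torch.full([B], sigma, device=z.device),
        cfg_scale=cfg_schedule[i],
    )
    z = z + dsigma * pred.to(z)  # Euler step; pred estimates z_0 - eps
```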
nodes.py (8 changed lines)
@@ -328,16 +328,16 @@ class MochiTextEncode:
-        except:
-            clip.cond_stage_model.to(load_device)
+        except:
+            NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
         clip.cond_stage_model.to(offload_device)
         tokens = clip.tokenizer.tokenize_with_weights(prompt, return_word_ids=True)
         embeds, _, attention_mask = clip.cond_stage_model.encode_token_weights(tokens)
 
 
         if embeds.shape[1] > 256:
             raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
         embeds *= strength
         if force_offload:
             clip.cond_stage_model.to(offload_device)
             mm.soft_empty_cache()
 
         t5_embeds = {
             "embeds": embeds,
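
For reference, the dict built here is what the sampler later reads as args["positive_embeds"] / args["negative_embeds"] in the model hunks above, so the two keys form the contract between the nodes. A self-contained sketch; the 4096 feature size is an assumption based on T5-XXL, not stated in this diff:

```python
import torch

embeds = torch.randn(1, 200, 4096)  # [batch, tokens <= 256, features], scaled by `strength`
attention_mask = torch.ones(1, 200, dtype=torch.bool)  # True for real (non-pad) tokens

t5_embeds = {
    "embeds": embeds,
    "attention_mask": attention_mask,
}
```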
@@ -409,7 +409,6 @@ class MochiSampler:
             "sigma_schedule": sigma_schedule,
             "cfg_schedule": cfg_schedule,
             "num_inference_steps": steps,
-            "batch_cfg": False,
         },
         "positive_embeds": positive,
         "negative_embeds": negative,
@@ -597,10 +596,11 @@ class MochiDecodeSpatialTiling:
                     decoded_list[-1][:, :, -1:, :, :] = blended_frames
 
                 decoded_list.append(frames)
+            frames = torch.cat(decoded_list, dim=2)
         else:
             logging.info("Decoding without tiling...")
-            frames = torch.cat(decoded_list, dim=2)
+            frames = vae(samples)
 
         vae.to(offload_device)
 
         frames = frames.float()
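
This is the fix named in the commit message: as reconstructed above, `frames = torch.cat(decoded_list, dim=2)` previously ran on the untiled path too, even though `decoded_list` is only populated by the tiled path, so untiled decodes broke. After the change the concatenation lives inside the tiled branch and the untiled branch decodes the whole latent in one VAE call. A condensed sketch of the corrected control flow; `tile_chunks` and the plain `chunk()` split are illustrative stand-ins for the node's real tiling and overlap-blending logic:

```python
import logging
import torch

def decode(vae, samples, tile_chunks=None):
    # Tensors are [B, C, T, H, W]; vae is a callable decoder.
    if tile_chunks:
        decoded_list = [vae(chunk) for chunk in samples.chunk(tile_chunks, dim=2)]
        frames = torch.cat(decoded_list, dim=2)  # cat only on the tiled path now
    else:
        logging.info("Decoding without tiling...")
        frames = vae(samples)  # single full decode
    return frames.float()
```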