diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 18d67e6..476ab12 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -234,7 +234,6 @@ class T2VSynthMochiModel:
 
         height = args["height"]
         width = args["width"]
-        batch_cfg = args["mochi_args"]["batch_cfg"]
         sample_steps = args["mochi_args"]["num_inference_steps"]
         cfg_schedule = args["mochi_args"].get("cfg_schedule")
         assert (
@@ -247,14 +246,6 @@ class T2VSynthMochiModel:
         ), f"sigma_schedule must have length {sample_steps + 1}, got {len(sigma_schedule)}"
         assert (num_frames - 1) % 6 == 0, f"t - 1 must be divisible by 6, got {num_frames - 1}"
 
-        # if batch_cfg:
-        #     sample_batched = self.get_conditioning(
-        #         [prompt] + [neg_prompt], zero_last_n_prompts=B if neg_prompt == "" else 0
-        #     )
-        # else:
-        #     sample = self.get_conditioning([prompt], zero_last_n_prompts=0)
-        #     sample_null = self.get_conditioning([neg_prompt] * B, zero_last_n_prompts=B if neg_prompt == "" else 0)
-
         # create z
         spatial_downsample = 8
         temporal_downsample = 6
@@ -273,45 +264,21 @@ class T2VSynthMochiModel:
             dtype=torch.float32,
         )
 
-        if batch_cfg: #WIP
-            pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
-            neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
-            pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
-            neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
-            print(neg_embeds.shape)
-            y_feat = torch.cat((pos_embeds, neg_embeds))
-            y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
-            zero_last_n_prompts = B# if neg_prompt == "" else 0
-            y_feat[-zero_last_n_prompts:] = 0
-            y_mask[-zero_last_n_prompts:] = False
-
-            sample_batched = {
-                "y_mask": [y_mask],
-                "y_feat": [y_feat]
-            }
-            sample_batched["packed_indices"] = self.get_packed_indices(
-                sample_batched["y_mask"], **latent_dims
-            )
-            z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
-            print("sample_batched y_mask",sample_batched["y_mask"])
-            print("y_mask type",type(sample_batched["y_mask"])) #"
-            print("ymask 0 shape",sample_batched["y_mask"][0].shape)#torch.Size([2, 256])
-        else:
-            sample = {
-                "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-                "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
-            }
-            sample_null = {
-                "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
-                "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-            }
+        sample = {
+            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+        }
+        sample_null = {
+            "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
+        }
 
-            sample["packed_indices"] = self.get_packed_indices(
-                sample["y_mask"], **latent_dims
-            )
-            sample_null["packed_indices"] = self.get_packed_indices(
-                sample_null["y_mask"], **latent_dims
-            )
+        sample["packed_indices"] = self.get_packed_indices(
+            sample["y_mask"], **latent_dims
+        )
+        sample_null["packed_indices"] = self.get_packed_indices(
+            sample_null["y_mask"], **latent_dims
+        )
 
         def model_fn(*, z, sigma, cfg_scale):
             self.dit.to(self.device)
@@ -319,19 +286,15 @@ class T2VSynthMochiModel:
                 autocast_dtype = torch.float16
             else:
                 autocast_dtype = torch.bfloat16
-            if batch_cfg:
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    out = self.dit(z, sigma, **sample_batched)
-                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
-            else:
-                nonlocal sample, sample_null
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                    if cfg_scale > 1.0:
-                        out_cond = self.dit(z, sigma, **sample)
-                        out_uncond = self.dit(z, sigma, **sample_null)
-                    else:
-                        out_cond = self.dit(z, sigma, **sample)
-                        return out_cond
+
+            nonlocal sample, sample_null
+            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+                if cfg_scale > 1.0:
+                    out_cond = self.dit(z, sigma, **sample)
+                    out_uncond = self.dit(z, sigma, **sample_null)
+                else:
+                    out_cond = self.dit(z, sigma, **sample)
+                    return out_cond
 
             return out_uncond + cfg_scale * (out_cond - out_uncond)
 
@@ -343,17 +306,12 @@ class T2VSynthMochiModel:
             # `pred` estimates `z_0 - eps`.
             pred = model_fn(
                 z=z,
-                sigma=torch.full(
-                    [B] if not batch_cfg else [B * 2], sigma, device=z.device
-                ),
+                sigma=torch.full([B], sigma, device=z.device),
                 cfg_scale=cfg_schedule[i],
             )
             pred = pred.to(z)
             z = z + dsigma * pred
             comfy_pbar.update(1)
-
-            if batch_cfg:
-                z = z[:B]
 
         self.dit.to(self.offload_device)
         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
diff --git a/nodes.py b/nodes.py
index fa567dc..b58fbab 100644
--- a/nodes.py
+++ b/nodes.py
@@ -328,16 +328,16 @@ class MochiTextEncode:
             except:
                 NotImplementedError("Failed to get attention mask from T5, is your ComfyUI up to date?")
         except:
-            clip.cond_stage_model.to(load_device)
+            clip.cond_stage_model.to(offload_device)
             tokens = clip.tokenizer.tokenize_with_weights(prompt, return_word_ids=True)
             embeds, _, attention_mask = clip.cond_stage_model.encode_token_weights(tokens)
-
         if embeds.shape[1] > 256:
             raise ValueError(f"Prompt is too long, max tokens supported is {max_tokens} or less, got {embeds.shape[1]}")
 
         embeds *= strength
         if force_offload:
             clip.cond_stage_model.to(offload_device)
+            mm.soft_empty_cache()
 
         t5_embeds = {
             "embeds": embeds,
@@ -409,7 +409,6 @@ class MochiSampler:
                 "sigma_schedule": sigma_schedule,
                 "cfg_schedule": cfg_schedule,
                 "num_inference_steps": steps,
-                "batch_cfg": False,
             },
             "positive_embeds": positive,
             "negative_embeds": negative,
@@ -597,10 +596,11 @@ class MochiDecodeSpatialTiling:
                     decoded_list[-1][:, :, -1:, :, :] = blended_frames
                     decoded_list.append(frames)
 
+            frames = torch.cat(decoded_list, dim=2)
         else:
             logging.info("Decoding without tiling...")
             frames = vae(samples)
-        frames = torch.cat(decoded_list, dim=2)
+
         vae.to(offload_device)
 
         frames = frames.float()