diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py
index aa40a67..b41d66c 100644
--- a/mochi_preview/dit/joint_model/layers.py
+++ b/mochi_preview/dit/joint_model/layers.py
@@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
         return t_emb
 
 
-class PooledCaptionEmbedder(nn.Module):
-    def __init__(
-        self,
-        caption_feature_dim: int,
-        hidden_size: int,
-        *,
-        bias: bool = True,
-        device: Optional[torch.device] = None,
-    ):
-        super().__init__()
-        self.caption_feature_dim = caption_feature_dim
-        self.hidden_size = hidden_size
-        self.mlp = nn.Sequential(
-            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
-            nn.SiLU(),
-            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
-        )
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
 class FeedForward(nn.Module):
     def __init__(
         self,
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index a1c6fb7..c5c99cd 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -286,41 +286,40 @@ class T2VSynthMochiModel:
             sample_null["y_mask"], **latent_dims
         )
 
-        def model_fn(*, z, sigma, cfg_scale):
-            self.dit.to(self.device)
-            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                autocast_dtype = torch.float16
-            else:
-                autocast_dtype = torch.bfloat16
-
-            nonlocal sample, sample_null
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                if cfg_scale > 1.0:
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
-                else:
-                    out_cond = self.dit(z, sigma, **sample)
-                    return out_cond
+        self.dit.to(self.device)
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+            autocast_dtype = torch.float16
+        else:
+            autocast_dtype = torch.bfloat16
+        def model_fn(*, z, sigma, cfg_scale):
+            nonlocal sample, sample_null
+            if cfg_scale > 1.0:
+                out_cond = self.dit(z, sigma, **sample)
+                out_uncond = self.dit(z, sigma, **sample_null)
+            else:
+                out_cond = self.dit(z, sigma, **sample)
+                return out_cond
 
             return out_uncond + cfg_scale * (out_cond - out_uncond)
 
         comfy_pbar = ProgressBar(sample_steps)
-        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
-            sigma = sigma_schedule[i]
-            dsigma = sigma - sigma_schedule[i + 1]
+        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
+                sigma = sigma_schedule[i]
+                dsigma = sigma - sigma_schedule[i + 1]
 
-            # `pred` estimates `z_0 - eps`.
-            pred = model_fn(
-                z=z,
-                sigma=torch.full([B], sigma, device=z.device),
-                cfg_scale=cfg_schedule[i],
-            )
-            pred = pred.to(z)
-            z = z + dsigma * pred
-            if callback is not None:
-                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
-            else:
-                comfy_pbar.update(1)
+                # `pred` estimates `z_0 - eps`.
+                pred = model_fn(
+                    z=z,
+                    sigma=torch.full([B], sigma, device=z.device),
+                    cfg_scale=cfg_schedule[i],
+                )
+                pred = pred.to(z)
+                z = z + dsigma * pred
+                if callback is not None:
+                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
+                else:
+                    comfy_pbar.update(1)
 
         self.dit.to(self.offload_device)
         logging.info(f"samples shape: {z.shape}")