fix onediff

This commit is contained in:
Jukka Seppänen 2024-09-03 19:26:49 +03:00
parent c8c55256b2
commit 3d536025e0

View File

@ -172,7 +172,7 @@ class CogVideoXPipeline(DiffusionPipeline):
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
if latents is None:
latents = noise
latents = noise
else:
latents = latents.to(device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
@ -538,7 +538,7 @@ class CogVideoXPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_tile = self.scheduler.step(noise_pred, t, latents_tile.to(self.vae.dtype), **extra_step_kwargs, return_dict=False)[0]
latents_all_list.append(latents_tile)
# ==========================================
@ -617,6 +617,7 @@ class CogVideoXPipeline(DiffusionPipeline):
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# start diff diff
if i < len(timesteps) - 1 and self.original_mask is not None:
noise_timestep = timesteps[i + 1]