partially working temporal tiling for 5B

This commit is contained in:
kijai 2024-08-27 21:47:30 +03:00
parent 9f1c2d7d80
commit a8251b3b93

View File

@ -267,7 +267,9 @@ class CogVideoXPipeline(DiffusionPipeline):
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
start_frame: int = None,
end_frame: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@ -284,14 +286,20 @@ class CogVideoXPipeline(DiffusionPipeline):
use_real=True,
)
freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
if start_frame is not None:
freqs_cos = freqs_cos[start_frame:end_frame]
freqs_sin = freqs_sin[start_frame:end_frame]
freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
def guidance_scale(self):
    """Return the guidance scale stored in ``self._guidance_scale``."""
    return self._guidance_scale
# Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf. `guidance_scale = 1`
# corresponds to doing no classifier-free guidance.
@ -436,12 +444,12 @@ class CogVideoXPipeline(DiffusionPipeline):
print("latents.device", latents.device)
# 6.5. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# # 6.5. Create rotary embeds if required
# image_rotary_emb = (
# self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
# if self.transformer.config.use_rotary_positional_embeddings
# else None
# )
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@ -477,6 +485,12 @@ class CogVideoXPipeline(DiffusionPipeline):
#latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
#latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
@ -496,7 +510,7 @@ class CogVideoXPipeline(DiffusionPipeline):
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
@ -527,6 +541,11 @@ class CogVideoXPipeline(DiffusionPipeline):
comfy_pbar.update(1)
# ==========================================
else:
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)