mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)
partially working temporal tiling for 5B
This commit is contained in:
parent 9f1c2d7d80
commit a8251b3b93
```diff
@@ -267,7 +267,9 @@ class CogVideoXPipeline(DiffusionPipeline):
         width: int,
         num_frames: int,
         device: torch.device,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        start_frame: int = None,
+        end_frame: int = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
```
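The point of the new optional `start_frame`/`end_frame` window is that the temporal-tiling path can fetch just the rotary-embedding rows covering one tile instead of the whole video. A hypothetical call-shape sketch (`pipe` is a constructed CogVideoXPipeline; the sizes are placeholders, not values from this commit):

```python
# Hypothetical sketch of the two call shapes enabled by this signature change.
def rope_for_tile(pipe, device, start=None, end=None):
    # Full tables when no window is given; only frames [start, end) otherwise.
    return pipe._prepare_rotary_positional_embeddings(
        480, 720, 13, device, start_frame=start, end_frame=end
    )
```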
```diff
@@ -284,14 +286,20 @@ class CogVideoXPipeline(DiffusionPipeline):
             use_real=True,
         )
 
+        freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
+        freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
+
+        if start_frame is not None:
+            freqs_cos = freqs_cos[start_frame:end_frame]
+            freqs_sin = freqs_sin[start_frame:end_frame]
+
+        freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
+        freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
+
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
     def guidance_scale(self):
         return self._guidance_scale
 
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
```
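The reshape/slice/flatten dance above is easy to verify in isolation. Here is a hedged standalone restatement (`slice_rope_frames` is a made-up helper, not in the repo), assuming the flat `(num_frames * grid_h * grid_w, dim)` layout that `get_3d_rotary_pos_embed(..., use_real=True)` returns:

```python
import torch

def slice_rope_frames(freqs_cos, freqs_sin, num_frames, grid_hw,
                      start_frame=None, end_frame=None):
    # Make the temporal axis explicit, cut the requested frame window,
    # then flatten back to the (tokens, dim) shape the attention expects.
    freqs_cos = freqs_cos.view(num_frames, grid_hw, -1)
    freqs_sin = freqs_sin.view(num_frames, grid_hw, -1)
    if start_frame is not None:
        freqs_cos = freqs_cos[start_frame:end_frame]
        freqs_sin = freqs_sin[start_frame:end_frame]
    return (freqs_cos.reshape(-1, freqs_cos.shape[-1]),
            freqs_sin.reshape(-1, freqs_sin.shape[-1]))

# Toy check: 13 latent frames, a 2x2 patch grid (grid_hw=4), head dim 8.
cos, sin = torch.randn(13 * 4, 8), torch.randn(13 * 4, 8)
tile_cos, tile_sin = slice_rope_frames(cos, sin, 13, 4, start_frame=4, end_frame=8)
assert tile_cos.shape == (4 * 4, 8)  # rows for frames 4..7 only
```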
```diff
@@ -436,12 +444,12 @@ class CogVideoXPipeline(DiffusionPipeline):
         print("latents.device", latents.device)
 
 
-        # 6.5. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )
+        # # 6.5. Create rotary embeds if required
+        # image_rotary_emb = (
+        #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+        #     if self.transformer.config.use_rotary_positional_embeddings
+        #     else None
+        # )
 
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
```
```diff
@@ -477,6 +485,12 @@ class CogVideoXPipeline(DiffusionPipeline):
                     #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                     #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
                     latents_tile = latents[:, input_start_t:input_end_t, :, :, :]
                     latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
                     latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
```
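Inside the loop, each tile now gets rotary embeddings for its absolute frame window (`input_start_t`/`input_end_t`), so positions stay globally consistent across tiles. How those bounds are generated is not shown in this hunk; a plausible sketch, assuming fixed-size windows with a small overlap (the sizes are illustrative, not from the commit):

```python
def temporal_windows(num_latent_frames: int, tile_size: int, overlap: int):
    # Yield (start, end) windows covering [0, num_latent_frames), each
    # overlapping the previous one by `overlap` frames.
    stride = tile_size - overlap
    start = 0
    while True:
        end = min(start + tile_size, num_latent_frames)
        yield start, end
        if end == num_latent_frames:
            break
        start += stride

# e.g. 13 latent frames, tiles of 4 with a 1-frame overlap:
print(list(temporal_windows(13, 4, 1)))  # [(0, 4), (3, 7), (6, 10), (9, 13)]
```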
```diff
@@ -496,7 +510,7 @@ class CogVideoXPipeline(DiffusionPipeline):
 
                     if self.do_classifier_free_guidance:
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                     # compute the previous noisy sample x_t -> x_t-1
                     latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
```
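The guidance step in this hunk is the standard classifier-free-guidance recombination referenced in the Imagen comment earlier: the batch carries the unconditional and text-conditioned predictions stacked along dim 0 (from `torch.cat([latents_tile] * 2)`), `chunk(2)` splits them, and `w = guidance_scale` blends them. A self-contained toy check (shapes and scale are arbitrary):

```python
import torch

# (2*B, F, C, H, W): uncond prediction stacked on top of the text prediction.
noise_pred = torch.randn(2, 4, 16, 8, 8)
guidance_scale = 6.0  # arbitrary example value

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 16, 8, 8)
# guidance_scale = 1.0 would return noise_pred_text exactly (no CFG).
```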
```diff
@@ -527,6 +541,11 @@ class CogVideoXPipeline(DiffusionPipeline):
                             comfy_pbar.update(1)
                 # ==========================================
                 else:
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
                     latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
```
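The hunks end before showing how a denoised tile is written back into the full `latents`, and the commit message itself calls the tiling only partially working. One common way to merge overlapping temporal tiles is a linear cross-fade over the overlap region; the sketch below is an assumption about that missing step, not code from the repo:

```python
import torch

def blend_tile(latents, tile, start, overlap):
    # latents: (B, F, C, H, W) full video; tile: (B, f, C, H, W) denoised tile
    # starting at frame `start`. Ramp the tile in over `overlap` frames.
    end = start + tile.shape[1]
    if start > 0 and overlap > 0:
        w = torch.linspace(0.0, 1.0, overlap, device=tile.device).view(1, -1, 1, 1, 1)
        latents[:, start:start + overlap] = (
            (1 - w) * latents[:, start:start + overlap] + w * tile[:, :overlap]
        )
        latents[:, start + overlap:end] = tile[:, overlap:]
    else:
        latents[:, start:end] = tile
    return latents
```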