Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git

commit a8251b3b93 (parent 9f1c2d7d80)

    partially working temporal tiling for 5B
@@ -267,7 +267,9 @@ class CogVideoXPipeline(DiffusionPipeline):
         width: int,
         num_frames: int,
         device: torch.device,
+        start_frame: int = None,
+        end_frame: int = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
@@ -284,14 +286,20 @@ class CogVideoXPipeline(DiffusionPipeline):
             use_real=True,
         )
 
+        freqs_cos = freqs_cos.view(num_frames, grid_height * grid_width, -1)
+        freqs_sin = freqs_sin.view(num_frames, grid_height * grid_width, -1)
+
+        if start_frame is not None:
+            freqs_cos = freqs_cos[start_frame:end_frame]
+            freqs_sin = freqs_sin[start_frame:end_frame]
+
+        freqs_cos = freqs_cos.view(-1, freqs_cos.shape[-1])
+        freqs_sin = freqs_sin.view(-1, freqs_sin.shape[-1])
+
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
-    @property
-    def guidance_scale(self):
-        return self._guidance_scale
-
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
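Taken together, the two hunks above make the rotary tables sliceable per temporal tile: the frequencies are built for the full clip, viewed as (num_frames, tokens_per_frame, dim) so a contiguous frame window can be cut out, then flattened back to the (tokens, dim) layout the attention layers expect. A self-contained sketch of that slicing logic, with illustrative shapes rather than the pipeline's real tables:

```python
import torch

def slice_freqs(freqs: torch.Tensor, num_frames: int, tokens_per_frame: int,
                start_frame=None, end_frame=None) -> torch.Tensor:
    # Mirror of the added logic: view as (frames, tokens, dim), window, flatten.
    freqs = freqs.view(num_frames, tokens_per_frame, -1)
    if start_frame is not None:
        freqs = freqs[start_frame:end_frame]
    return freqs.view(-1, freqs.shape[-1])

# Illustrative: 13 latent frames, a 30x45 latent grid (1350 tokens/frame), 64 rotary channels.
freqs_cos = torch.randn(13 * 1350, 64)
full = slice_freqs(freqs_cos, 13, 1350)        # no window: (17550, 64), as before
tile = slice_freqs(freqs_cos, 13, 1350, 4, 9)  # frames 4..8 only: (6750, 64)
print(full.shape, tile.shape)
```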
@@ -436,12 +444,12 @@ class CogVideoXPipeline(DiffusionPipeline):
         print("latents.device", latents.device)
 
 
-        # 6.5. Create rotary embeds if required
-        image_rotary_emb = (
-            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
-            if self.transformer.config.use_rotary_positional_embeddings
-            else None
-        )
+        # # 6.5. Create rotary embeds if required
+        # image_rotary_emb = (
+        #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+        #     if self.transformer.config.use_rotary_positional_embeddings
+        #     else None
+        # )
 
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
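Commenting out this block moves RoPE construction from once per generation to once per tile per denoising step (see the `-477,6` hunk below). A rough, assumed-numbers accounting of that trade:

```python
# Assumed numbers, for scale only: the precomputed path built the tables once,
# while the tiled path rebuilds them for every tile of every denoising step.
num_inference_steps = 50
num_tiles = 3                              # e.g. three overlapping frame windows
builds_before = 1
builds_after = num_inference_steps * num_tiles
print(builds_before, builds_after)         # 1 150
```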
@@ -477,6 +485,12 @@ class CogVideoXPipeline(DiffusionPipeline):
                    #latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
+
                    latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
                    latent_model_input_tile = torch.cat([latents_tile] * 2) if do_classifier_free_guidance else latents_tile
                    latent_model_input_tile = self.scheduler.scale_model_input(latent_model_input_tile, t)
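Note that `num_frames` is still passed as the full `latents.size(1)`; the window is applied inside `_prepare_rotary_positional_embeddings` by the slicing added above. The diff does not show where `input_start_t` and `input_end_t` come from; a plausible sliding-window scheduler for such tile bounds might look like the following (entirely an assumption, names hypothetical):

```python
def temporal_tiles(num_frames: int, tile_size: int, overlap: int):
    """Yield (start, end) frame windows covering [0, num_frames) with overlap.

    Hypothetical helper; the commit computes input_start_t/input_end_t elsewhere.
    """
    assert 0 <= overlap < tile_size, "overlap must be smaller than the tile"
    stride = tile_size - overlap
    start = 0
    while True:
        end = min(start + tile_size, num_frames)
        yield start, end
        if end == num_frames:
            break
        start += stride

# e.g. 13 latent frames, tiles of 8 frames with 4 frames of overlap
print(list(temporal_tiles(13, 8, 4)))  # [(0, 8), (4, 12), (8, 13)]
```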
@@ -496,7 +510,7 @@ class CogVideoXPipeline(DiffusionPipeline):
 
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                    # compute the previous noisy sample x_t -> x_t-1
                    latents_tile = self.scheduler.step(noise_pred, t, latents_tile, **extra_step_kwargs, return_dict=False)[0]
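The only change in this hunk is `self.guidance_scale` becoming `self._guidance_scale`, a consequence of the `guidance_scale` property being deleted in the `-284,14` hunk. The combine itself is standard classifier-free guidance, eps = eps_uncond + w * (eps_text - eps_uncond), with w the guidance weight from equation (2) of the Imagen paper; a toy check of the formula:

```python
import torch

w = 6.0  # guidance scale (illustrative value)
noise_pred_uncond = torch.tensor([0.1, 0.2])
noise_pred_text = torch.tensor([0.3, 0.1])
# w = 1 reduces to the text prediction; larger w pushes further from unconditional.
noise_pred = noise_pred_uncond + w * (noise_pred_text - noise_pred_uncond)
print(noise_pred)  # tensor([ 1.3000, -0.4000])
```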
@@ -527,6 +541,11 @@ class CogVideoXPipeline(DiffusionPipeline):
                        comfy_pbar.update(1)
                # ==========================================
                else:
+                    image_rotary_emb = (
+                        self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+                        if self.transformer.config.use_rotary_positional_embeddings
+                        else None
+                    )
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
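In the untiled `else:` branch the embeds are rebuilt for the full clip on every step, restoring what the commented-out block in the `-436,12` hunk used to do once up front; with `start_frame`/`end_frame` left at their `None` defaults the new slicing is skipped, so behavior matches the original. One invariant worth keeping in mind across these hunks: the tile `latents[:, input_start_t:input_end_t]` must yield exactly as many tokens as the sliced RoPE tables, or positional addition in the transformer will misalign. A runnable consistency check under illustrative shapes:

```python
import torch

B, F, C, H, W = 1, 13, 16, 60, 90       # illustrative latent shape (B, frames, C, H, W)
patch = 2                                # transformer patch size (assumed)
tokens_per_frame = (H // patch) * (W // patch)   # 1350

latents = torch.randn(B, F, C, H, W)
s, e = 4, 9                              # tile window, as in the hunks above

latents_tile = latents[:, s:e]           # (1, 5, 16, 60, 90)
rope_tokens = (e - s) * tokens_per_frame # tokens in the sliced RoPE tables

# Token count the transformer will see for this tile:
seq_len = latents_tile.shape[1] * tokens_per_frame
assert seq_len == rope_tokens, "RoPE table and latent tile disagree"
print(seq_len)  # 6750
```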