use proper length RoPE embeds with context windowing

commit 2be8f694b0
parent 1801c65e97
kijai, 2024-10-05 22:04:51 +03:00
2 changed files with 29 additions and 23 deletions
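
Both pipelines previously rebuilt the rotary positional embeddings (RoPE) inside their inner loops, sized to the full latent sequence (latents.size(1)) and then offset or indexed per temporal tile or context window. Every tile and every window spans the same fixed number of frames, so the table can instead be built once at the window length and reused; positions within each window then line up with the table directly. A minimal sketch of the pattern, using a hypothetical 1D stand-in for the pipeline's _prepare_rotary_positional_embeddings (the real method also encodes height and width), with assumed values for context_frames, total_frames, and the window schedule:

    import torch

    # Hypothetical stand-in: a 1D temporal RoPE table covering exactly num_frames positions.
    def prepare_rope(num_frames: int, dim: int = 64, device: str = "cpu"):
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(num_frames, device=device).float()
        freqs = torch.outer(t, inv_freq)
        return freqs.cos(), freqs.sin()

    context_frames = 12   # assumed window length
    total_frames = 48     # assumed latent frame count
    # Illustrative half-overlapping schedule; every window is exactly context_frames long.
    context_queue = [list(range(i, i + context_frames))
                     for i in range(0, total_frames - context_frames + 1, context_frames // 2)]

    # After this commit: one window-length table shared by every window,
    # instead of a full-length table rebuilt and offset inside the loop.
    cos, sin = prepare_rope(context_frames)
    for c in context_queue:
        assert len(c) == context_frames  # window positions 0..context_frames-1 match the table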

@@ -742,6 +742,12 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
         latents_all_list = []
         # =====================================================
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
         for t_i in range(grid_ts):
             if t_i < grid_ts - 1:
                 ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
@@ -751,12 +757,6 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             input_start_t = ofs_t
             input_end_t = ofs_t + t_tile_length
-            image_rotary_emb = (
-                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
-                if self.transformer.config.use_rotary_positional_embeddings
-                else None
-            )
             latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
             control_latents_tile = control_latents[:, input_start_t:input_end_t, :, :, :]
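
The net effect in the tiling path, with both calls taken verbatim from the hunks above:

    # before: full-length table, then per-tile start/end offsets into it
    self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
    # after: one table at the window length, built once before the tile loop
    self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
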
@@ -832,15 +832,21 @@ class CogVideoX_Fun_Pipeline_Control(VideoSysPipeline):
             if do_classifier_free_guidance:
                 noise_uncond = torch.zeros_like(latent_model_input)
-            for c in context_queue:
-                partial_latent_model_input = latent_model_input[:, c, :, :, :]
-                partial_control_latents = current_control_latents[:, c, :, :, :]
-                image_rotary_emb = (
-                    self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
-                    if self.transformer.config.use_rotary_positional_embeddings
-                    else None
-                )
+            image_rotary_emb = (
+                self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+                if self.transformer.config.use_rotary_positional_embeddings
+                else None
+            )
+            for c in context_queue:
+                partial_latent_model_input = latent_model_input[:, c, :, :, :]
+                partial_control_latents = current_control_latents[:, c, :, :, :]
+                # image_rotary_emb = (
+                #     self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
+                #     if self.transformer.config.use_rotary_positional_embeddings
+                #     else None
+                # )
                 # predict noise model_output
                 noise_pred[:, c, :, :, :] += self.transformer(
                     hidden_states=partial_latent_model_input,
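
In the context-windowed loop the same hoist applies: the per-window prep that sized the table to the full latents.size(1) and indexed it with the window c is retired (kept as a comment in the control pipeline), and every iteration reuses the single context_frames-length table prepared before the loop.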

@@ -903,6 +903,12 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
         latents_all_list = []
         # =====================================================
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, t_tile_length, device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
         for t_i in range(grid_ts):
             if t_i < grid_ts - 1:
                 ofs_t = max(t_i * t_tile_length - t_tile_overlap * t_i, 0)
@@ -912,12 +918,6 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             input_start_t = ofs_t
             input_end_t = ofs_t + t_tile_length
-            image_rotary_emb = (
-                self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, input_start_t, input_end_t)
-                if self.transformer.config.use_rotary_positional_embeddings
-                else None
-            )
             latents_tile = latents[:, input_start_t:input_end_t,:, :, :]
             inpaint_latents_tile = inpaint_latents[:, input_start_t:input_end_t, :, :, :]
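
One asymmetry worth noting: the inpaint pipeline sizes its tiling-path table with t_tile_length where the control pipeline passes context_frames. Each tile spans input_end_t - input_start_t = t_tile_length frames in both pipelines, so the two values presumably coincide in this path.
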
@@ -989,17 +989,17 @@ class CogVideoX_Fun_Pipeline_Inpaint(VideoSysPipeline):
             if do_classifier_free_guidance:
                 noise_uncond = torch.zeros_like(latent_model_input)
+            image_rotary_emb = (
+                self._prepare_rotary_positional_embeddings(height, width, context_frames, device)
+                if self.transformer.config.use_rotary_positional_embeddings
+                else None
+            )
             for c in context_queue:
                 partial_latent_model_input = latent_model_input[:, c, :, :, :]
                 partial_inpaint_latents = inpaint_latents[:, c, :, :, :]
                 partial_inpaint_latents[:, 0, :, :, :] = inpaint_latents[:, 0, :, :, :]
-                image_rotary_emb = (
-                    self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device, context_frames=c)
-                    if self.transformer.config.use_rotary_positional_embeddings
-                    else None
-                )
                 # predict noise model_output
                 noise_pred[:, c, :, :, :] += self.transformer(
                     hidden_states=partial_latent_model_input,
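
With the table built at window length, every window now sees RoPE positions 0..context_frames-1 rather than its absolute offsets into the full sequence, matching the frame count the transformer actually receives per call; this appears to be the "proper length" the commit title refers to.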