Compare commits

..

1 Commit

2 changed files with 2 additions and 5 deletions

View File

@@ -67,9 +67,8 @@ class CogVideoXPatchEmbed(nn.Module):
 post_time_compression_frames,
 self.spatial_interpolation_scale,
 self.temporal_interpolation_scale,
-output_type="pt",
 )
-pos_embedding = pos_embedding.flatten(0, 1)
+pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
 joint_pos_embedding = torch.zeros(
 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
 )
@@ -174,8 +173,6 @@ def get_3d_rotary_pos_embed(
 grid_t = np.arange(temporal_size, dtype=np.float32)
 grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
 elif grid_type == "slice":
-if max_size is None:
-raise ValueError("`max_size` must be provided when `grid_type` is 'slice'")
 max_h, max_w = max_size
 grid_size_h, grid_size_w = grid_size
 grid_h = np.arange(max_h, dtype=np.float32)

View File

@@ -1,5 +1,5 @@
 huggingface_hub
-diffusers>=0.33.1
+diffusers>=0.31.0
 accelerate>=0.33.0
 einops
 peft