Update pipeline_cogvideox.py

This commit is contained in:
kijai 2024-08-06 04:33:22 +03:00
parent 58a6802ab4
commit 320af57bb3

View File

@@ -166,7 +166,7 @@ class CogVideoXPipeline(DiffusionPipeline):
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
#text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -174,7 +174,7 @@ class CogVideoXPipeline(DiffusionPipeline):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -197,7 +197,6 @@ class CogVideoXPipeline(DiffusionPipeline):
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
@@ -220,7 +219,7 @@ class CogVideoXPipeline(DiffusionPipeline):
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
#prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method