From 320af57bb3e0058d44b0ba6cebd48c56c953a712 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 6 Aug 2024 04:33:22 +0300 Subject: [PATCH] Update pipeline_cogvideox.py --- pipeline_cogvideox.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index bbadc63..d9fea75 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -166,7 +166,7 @@ class CogVideoXPipeline(DiffusionPipeline): def __init__( self, tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, + #text_encoder: T5EncoderModel, vae: AutoencoderKLCogVideoX, transformer: CogVideoXTransformer3DModel, scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], @@ -174,7 +174,7 @@ class CogVideoXPipeline(DiffusionPipeline): super().__init__() self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor_spatial = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -197,7 +197,6 @@ class CogVideoXPipeline(DiffusionPipeline): dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -220,7 +219,7 @@ class CogVideoXPipeline(DiffusionPipeline): f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + #prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method