Update pipeline_cogvideox.py

This commit is contained in:
kijai 2024-08-06 04:33:22 +03:00
parent 58a6802ab4
commit 320af57bb3

View File

@@ -166,7 +166,7 @@ class CogVideoXPipeline(DiffusionPipeline):
     def __init__(
         self,
         tokenizer: T5Tokenizer,
-        text_encoder: T5EncoderModel,
+        #text_encoder: T5EncoderModel,
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -174,7 +174,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         super().__init__()
         self.register_modules(
-            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+            tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
         )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -197,7 +197,6 @@ class CogVideoXPipeline(DiffusionPipeline):
         dtype: Optional[torch.dtype] = None,
     ):
         device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype

         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
@@ -220,7 +219,7 @@ class CogVideoXPipeline(DiffusionPipeline):
             f" {max_sequence_length} tokens: {removed_text}"
         )
-        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        #prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

         # duplicate text embeddings for each generation per prompt, using mps friendly method