mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
Update pipeline_cogvideox.py
This commit is contained in:
parent
58a6802ab4
commit
320af57bb3
@ -166,7 +166,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
#text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
@ -174,7 +174,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
@ -197,7 +197,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
@ -220,7 +219,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
#prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user