Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 21:04:23 +08:00)
Update pipeline_cogvideox.py

commit 320af57bb3
parent 58a6802ab4
@@ -166,7 +166,7 @@ class CogVideoXPipeline(DiffusionPipeline):
     def __init__(
         self,
         tokenizer: T5Tokenizer,
-        text_encoder: T5EncoderModel,
+        #text_encoder: T5EncoderModel,
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -174,7 +174,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         super().__init__()
 
         self.register_modules(
-            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+            tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
         )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -197,7 +197,6 @@ class CogVideoXPipeline(DiffusionPipeline):
         dtype: Optional[torch.dtype] = None,
     ):
         device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
@@ -220,7 +219,7 @@ class CogVideoXPipeline(DiffusionPipeline):
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        #prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
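Note on the change (editorial, not part of the commit): all four hunks detach the T5 text encoder from the pipeline object, so the pipeline no longer registers or calls `text_encoder` itself. A plausible reading, an assumption rather than anything stated in the commit, is that prompt embeddings are now computed outside the pipeline and passed in ready-made. A minimal sketch of such an external encoding step, assuming a standard CogVideoX checkpoint layout and its usual 226-token text length (both assumptions):

# Hypothetical external prompt encoding; checkpoint id and max_length are assumptions.
import torch
from transformers import T5Tokenizer, T5EncoderModel

model_id = "THUDM/CogVideoX-2b"  # assumed: any CogVideoX repo with tokenizer/ and text_encoder/ subfolders
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")

text_inputs = tokenizer(
    "a panda playing guitar in a bamboo forest",
    padding="max_length",
    max_length=226,  # assumed CogVideoX max_sequence_length
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    # T5EncoderModel returns last_hidden_state at index 0, mirroring the
    # self.text_encoder(...)[0] call this commit comments out.
    prompt_embeds = text_encoder(text_inputs.input_ids)[0]

# prompt_embeds would then be handed to the pipeline, which after this
# commit no longer owns a text_encoder module of its own.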