Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-15 07:54:24 +08:00)
support input latents for vid2vid
parent: 5755a8c7f3
commit: 0d69d73c60
nodes.py (104)
@@ -122,21 +122,82 @@ class CogVideoTextEncode:
         embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)

         return (embeds, )

-class CogVideoSampler:
+class CogVideoImageEncode:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "pipeline": ("COGVIDEOPIPE",),
-            "positive": ("CONDITIONING", ),
-            "negative": ("CONDITIONING", ),
-            "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
-            "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
-            "num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
-            "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
-            "steps": ("INT", {"default": 25, "min": 1}),
-            "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
-            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            "image": ("IMAGE", ),
+            },
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    RETURN_NAMES = ("samples",)
+    FUNCTION = "encode"
+    CATEGORY = "CogVideoWrapper"
+
+    def encode(self, pipeline, image):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        generator = torch.Generator(device=device).manual_seed(0)
+        vae = pipeline["pipe"].vae
+        vae.to(device)
+
+        image = image * 2.0 - 1.0
+        image = image.to(vae.dtype).to(device)
+        image = image.unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
+        B, C, T, H, W = image.shape
+        chunk_size = 16
+        latents_list = []
+        # Loop through the temporal dimension in chunks of 16
+        for i in range(0, T, chunk_size):
+            # Get the chunk of 16 frames (or remaining frames if less than 16 are left)
+            end_index = min(i + chunk_size, T)
+            image_chunk = image[:, :, i:end_index, :, :] # Shape: [B, C, chunk_size, H, W]
+
+            # Encode the chunk of images
+            latents = vae.encode(image_chunk)
+
+            sample_mode = "sample"
+            if hasattr(latents, "latent_dist") and sample_mode == "sample":
+                latents = latents.latent_dist.sample(generator)
+            elif hasattr(latents, "latent_dist") and sample_mode == "argmax":
+                latents = latents.latent_dist.mode()
+            elif hasattr(latents, "latents"):
+                latents = latents.latents
+
+            latents = vae.config.scaling_factor * latents
+            latents = latents.permute(0, 2, 1, 3, 4) # B, T_chunk, C, H, W
+            latents_list.append(latents)
+
+        # Concatenate all the chunks along the temporal dimension
+        final_latents = torch.cat(latents_list, dim=1)
+        print("final latents: ", final_latents.shape)
+
+        vae.to(offload_device)
+
+        return ({"samples": final_latents}, )
+
+class CogVideoSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pipeline": ("COGVIDEOPIPE",),
+                "positive": ("CONDITIONING", ),
+                "negative": ("CONDITIONING", ),
+                "height": ("INT", {"default": 480, "min": 128, "max": 2048, "step": 8}),
+                "width": ("INT", {"default": 720, "min": 128, "max": 2048, "step": 8}),
+                "num_frames": ("INT", {"default": 48, "min": 8, "max": 100, "step": 8}),
+                "fps": ("INT", {"default": 8, "min": 1, "max": 100, "step": 1}),
+                "steps": ("INT", {"default": 25, "min": 1}),
+                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            },
+            "optional": {
+                "samples": ("LATENT", ),
+                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
             }
         }

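The new CogVideoImageEncode node is what makes vid2vid possible: input frames (ComfyUI IMAGE tensors in the 0..1 range, shaped [T, H, W, C]) are rescaled to -1..1, reordered to [B, C, T, H, W], pushed through the CogVideoX VAE 16 frames at a time to keep memory bounded, scaled by vae.config.scaling_factor, and returned as a LATENT. Stripped of the node plumbing, the pattern is roughly the following (a minimal sketch, not the node itself; encode_frames_in_chunks is an illustrative name and an already-loaded AutoencoderKLCogVideoX is assumed):

import torch

def encode_frames_in_chunks(vae, frames, chunk_size=16, seed=0):
    # frames: [T, H, W, C] floats in 0..1 (ComfyUI IMAGE) -> CogVideoX latents [B, T_latent, C, H, W]
    device = next(vae.parameters()).device
    generator = torch.Generator(device=device).manual_seed(seed)

    # Rescale to -1..1 and reorder to the [B, C, T, H, W] layout the VAE expects
    video = (frames * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3)

    latents_list = []
    for i in range(0, video.shape[2], chunk_size):
        chunk = video[:, :, i:i + chunk_size]                 # at most chunk_size frames per VAE call
        posterior = vae.encode(chunk).latent_dist             # the node's default sample_mode="sample" path
        latents = posterior.sample(generator) * vae.config.scaling_factor
        latents_list.append(latents.permute(0, 2, 1, 3, 4))   # -> [B, T_latent, C, H, W]

    return torch.cat(latents_list, dim=1)                     # concatenate chunks along the temporal axis

The node additionally supports an argmax-style path via latent_dist.mode(), which the sketch above omits.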
@@ -145,7 +206,7 @@ class CogVideoSampler:
     FUNCTION = "process"
     CATEGORY = "CogVideoWrapper"

-    def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames):
+    def process(self, pipeline, positive, negative, fps, steps, cfg, seed, height, width, num_frames, samples=None, denoise_strength=1.0):
         mm.soft_empty_cache()
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -162,6 +223,8 @@ class CogVideoSampler:
             num_frames = num_frames,
             fps = fps,
             guidance_scale=cfg,
+            latents=samples["samples"] if samples is not None else None,
+            denoise_strength=denoise_strength,
             prompt_embeds=positive.to(dtype).to(device),
             negative_prompt_embeds=negative.to(dtype).to(device),
             #negative_prompt_embeds=torch.zeros_like(embeds),
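In a workflow, the LATENT output of CogVideoImageEncode plugs into the sampler's new optional samples input, and denoise_strength controls how far the input video is re-noised before sampling. The effect on the underlying pipeline call is roughly this (illustrative only; pipe, positive and negative stand for an already-loaded pipeline and text embeddings, and the keyword names are the ones visible in the hunk above):

# txt2vid: no input latents, run the full schedule
out = pipe(
    height=480, width=720, num_frames=48, fps=8,
    num_inference_steps=25, guidance_scale=6.0,
    prompt_embeds=positive, negative_prompt_embeds=negative,
    latents=None, denoise_strength=1.0,
)

# vid2vid: start from the encoded input video and only partially re-noise it
out = pipe(
    height=480, width=720, num_frames=48, fps=8,
    num_inference_steps=25, guidance_scale=6.0,
    prompt_embeds=positive, negative_prompt_embeds=negative,
    latents=encoded["samples"],   # LATENT produced by CogVideoImageEncode
    denoise_strength=0.6,         # lower values stay closer to the input video
)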
@@ -198,11 +261,15 @@ class CogVideoDecode:
         vae = pipeline["pipe"].vae
         vae.to(device)

-        num_frames = pipeline["num_frames"]
-        fps = pipeline["fps"]
+        if "num_frames" in pipeline:
+            num_frames = pipeline["num_frames"]
+            fps = pipeline["fps"]
+        else:
+            num_frames = latents.shape[2]
+            fps = 8
         num_seconds = num_frames // fps

         latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents

@@ -217,6 +284,7 @@ class CogVideoDecode:
         vae.to(offload_device)

         frames = torch.cat(frames, dim=2)
+        print(frames.min(), frames.max())
         video = pipeline["pipe"].video_processor.postprocess_video(video=frames, output_type="pt")
         print(video.shape)
         video = video[0].permute(0, 2, 3, 1).cpu().float()
@@ -229,11 +297,13 @@ NODE_CLASS_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": DownloadAndLoadCogVideoModel,
     "CogVideoSampler": CogVideoSampler,
     "CogVideoDecode": CogVideoDecode,
-    "CogVideoTextEncode": CogVideoTextEncode
+    "CogVideoTextEncode": CogVideoTextEncode,
+    "CogVideoImageEncode": CogVideoImageEncode
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadCogVideoModel": "(Down)load CogVideo Model",
     "CogVideoSampler": "CogVideo Sampler",
     "CogVideoDecode": "CogVideo Decode",
-    "CogVideoTextEncode": "CogVideo TextEncode"
+    "CogVideoTextEncode": "CogVideo TextEncode",
+    "CogVideoImageEncode": "CogVideo ImageEncode"
 }
@@ -18,7 +18,6 @@ from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import T5EncoderModel, T5Tokenizer

 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
@@ -165,8 +164,6 @@ class CogVideoXPipeline(DiffusionPipeline):

     def __init__(
         self,
-        tokenizer: T5Tokenizer,
-        #text_encoder: T5EncoderModel,
         vae: AutoencoderKLCogVideoX,
         transformer: CogVideoXTransformer3DModel,
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
@@ -174,7 +171,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         super().__init__()

         self.register_modules(
-            tokenizer=tokenizer, vae=vae, transformer=transformer, scheduler=scheduler
+            vae=vae, transformer=transformer, scheduler=scheduler
         )
         self.vae_scale_factor_spatial = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -182,136 +179,11 @@ class CogVideoXPipeline(DiffusionPipeline):
         self.vae_scale_factor_temporal = (
             self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
         )
-        self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 226
-        )

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

-    def _get_t5_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        num_videos_per_prompt: int = 1,
-        max_sequence_length: int = 226,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=max_sequence_length,
-            truncation=True,
-            add_special_tokens=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because `max_sequence_length` is set to "
-                f" {max_sequence_length} tokens: {removed_text}"
-            )
-
-        #prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
-        return prompt_embeds
-
-    def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        do_classifier_free_guidance: bool = True,
-        num_videos_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 226,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
-                Whether to use classifier free guidance or not.
-            num_videos_per_prompt (`int`, *optional*, defaults to 1):
-                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            device: (`torch.device`, *optional*):
-                torch device
-            dtype: (`torch.dtype`, *optional*):
-                torch dtype
-        """
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
-            )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=negative_prompt,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
-            )
-
-        return prompt_embeds, negative_prompt_embeds
-
     def prepare_latents(
-        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, timesteps, denoise_strength, num_inference_steps, latents=None,
     ):
         shape = (
             batch_size,
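These deletions follow from how the wrapper is structured: prompt embeddings are produced on the ComfyUI side by the CogVideoTextEncode node (visible at the top of this diff) and handed to the pipeline as prompt_embeds, so the pipeline no longer needs its own T5 tokenizer or text encoder. On the node side the whole encode step is roughly (a sketch assuming a ComfyUI CLIP object wrapping the T5 text encoder):

tokens = clip.tokenize(prompt)
embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
# embeds is what CogVideoSampler later passes to the pipeline as prompt_embeds / negative_prompt_embeds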
@@ -328,12 +200,27 @@ class CogVideoXPipeline(DiffusionPipeline):

         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # scale the initial noise by the standard deviation required by the scheduler
+
         else:
             latents = latents.to(device)
+            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
+            latent_timestep = timesteps[:1]
+
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+            frames_needed = noise.shape[1]
+            current_frames = latents.shape[1]
+
+            if frames_needed > current_frames:
+                repeat_factor = frames_needed // current_frames
+                additional_frame = torch.randn((latents.size(0), repeat_factor, latents.size(2), latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
+                latents = torch.cat((latents, additional_frame), dim=1)
+            elif frames_needed < current_frames:
+                latents = latents[:, :frames_needed, :, :, :]
+
-        # scale the initial noise by the standard deviation required by the scheduler
+            latents = self.scheduler.add_noise(latents, noise, latent_timestep)
         latents = latents * self.scheduler.init_noise_sigma
-        return latents
+        return latents, timesteps

     def decode_latents(self, latents: torch.Tensor, num_seconds: int):
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
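The else branch above is the standard img2img recipe applied to video latents: truncate the schedule according to denoise_strength via get_timesteps, repeat or crop the supplied latents along the temporal dimension (dim 1) so they match the frame count the sampler expects, then noise them to the first kept timestep with scheduler.add_noise. For the DDIM-style schedulers used here, add_noise is essentially the closed-form forward process; a toy sketch of that mixing (not the wrapper's code, shapes are only indicative):

import torch

def noise_to_timestep(x0, noise, alphas_cumprod, t):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # which is what diffusers schedulers implement in add_noise()
    a = alphas_cumprod[t]
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)   # toy schedule
x0 = torch.randn(1, 13, 16, 60, 90)                   # [B, T_latent, C, H, W], roughly 48 frames at 480x720
noise = torch.randn_like(x0)
partially_noised = noise_to_timestep(x0, noise, alphas_cumprod, t=600)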
@@ -372,10 +259,8 @@ class CogVideoXPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
     def check_inputs(
         self,
-        prompt,
         height,
         width,
-        negative_prompt,
         callback_on_step_end_tensor_inputs,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -389,29 +274,6 @@ class CogVideoXPipeline(DiffusionPipeline):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
-        if prompt is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
-        elif prompt is None and prompt_embeds is None:
-            raise ValueError(
-                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
-            )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
-
-        if negative_prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )

         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -420,6 +282,16 @@ class CogVideoXPipeline(DiffusionPipeline):
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps.to(device), num_inference_steps - t_start

     @property
     def guidance_scale(self):
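The strength-to-steps mapping mirrors the img2img pipelines in diffusers: with num_inference_steps=25 and denoise_strength=0.6, init_timestep = min(int(25 * 0.6), 25) = 15 and t_start = 25 - 15 = 10, so the first 10 scheduler steps are skipped and the remaining 15 run on the noised input latents; denoise_strength=1.0 degenerates to the full schedule. The same arithmetic, detached from any scheduler instance (get_timestep_window is an illustrative name):

def get_timestep_window(num_inference_steps, strength, scheduler_order=1):
    # Mirrors get_timesteps() above, minus the scheduler bookkeeping
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start * scheduler_order, num_inference_steps - t_start

print(get_timestep_window(25, 0.6))  # (10, 15): start index into scheduler.timesteps, steps actually run
print(get_timestep_window(25, 1.0))  # (0, 25): full schedule, i.e. plain txt2vid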
@@ -444,8 +316,6 @@ class CogVideoXPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 480,
         width: int = 720,
         num_frames: int = 48,
@@ -453,6 +323,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
+        denoise_strength: float = 1.0,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -553,10 +424,8 @@ class CogVideoXPipeline(DiffusionPipeline):

         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
             height,
             width,
-            negative_prompt,
             callback_on_step_end_tensor_inputs,
             prompt_embeds,
             negative_prompt_embeds,
@@ -565,12 +434,8 @@ class CogVideoXPipeline(DiffusionPipeline):
         self._interrupt = False

         # 2. Default call parameters
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
+        batch_size = prompt_embeds.shape[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -587,7 +452,7 @@ class CogVideoXPipeline(DiffusionPipeline):
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
         num_frames += 1
-        latents = self.prepare_latents(
+        latents, timesteps = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
             num_frames,
@@ -596,7 +461,10 @@ class CogVideoXPipeline(DiffusionPipeline):
             prompt_embeds.dtype,
             device,
             generator,
-            latents,
+            timesteps,
+            denoise_strength,
+            num_inference_steps,
+            latents
         )

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline