mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)

support transformer at fp8

parent 46fe99f102
commit 65e9e995e8
@@ -818,7 +818,7 @@
         "widgets_values": {
             "frame_rate": 8,
             "loop_count": 0,
-            "filename_prefix": "AnimateDiff",
+            "filename_prefix": "CogVideoX_vid2vid",
             "format": "video/nvenc_h264-mp4",
             "pix_fmt": "yuv420p",
             "bitrate": 10,
nodes.py (54 changed lines)
@@ -25,15 +25,11 @@ class DownloadAndLoadCogVideoModel:

            },
            "optional": {
-                "precision": (
-                    [
-                        "fp16",
-                        "fp32",
-                        "bf16",
-                    ],
-                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"},
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                ),
-            },
+                "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+            }
        }

    RETURN_TYPES = ("COGVIDEOPIPE",)
@@ -41,12 +37,16 @@ class DownloadAndLoadCogVideoModel:
    FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"

-    def loadmodel(self, model, precision):
+    def loadmodel(self, model, precision, fp8_transformer):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        mm.soft_empty_cache()

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        if fp8_transformer:
+            transformer_dtype = torch.float8_e4m3fn
+        else:
+            transformer_dtype = dtype

        if "2b" in model:
            base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
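The new toggle only changes the storage dtype of the transformer weights; the VAE and scheduler still load in the selected precision. A minimal sketch of that dtype selection, with a stand-in nn.Linear in place of CogVideoXTransformer3DModel (names and sizes are illustrative; fp8 dtypes require PyTorch 2.1 or newer):

import torch
import torch.nn as nn

# Stand-in for the transformer; the real node loads CogVideoXTransformer3DModel.
block = nn.Linear(64, 64)

precision = "bf16"        # existing node input
fp8_transformer = True    # mirrors the new BOOLEAN input

dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
transformer_dtype = torch.float8_e4m3fn if fp8_transformer else dtype

block.to(transformer_dtype)
# float8_e4m3fn stores 1 byte per weight vs. 2 bytes for fp16/bf16.
print(block.weight.dtype, block.weight.element_size())  # torch.float8_e4m3fn 1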
@@ -63,7 +63,7 @@ class DownloadAndLoadCogVideoModel:
                local_dir=base_path,
                local_dir_use_symlinks=False,
            )
-        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(dtype).to(offload_device)
+        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
        vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
        scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
@@ -247,22 +247,22 @@ class CogVideoSampler:
            pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
        elif scheduler == "DPM":
            pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-
-        latents = pipeline["pipe"](
-            num_inference_steps=steps,
-            height = height,
-            width = width,
-            num_frames = num_frames,
-            t_tile_length = t_tile_length,
-            t_tile_overlap = t_tile_overlap,
-            guidance_scale=cfg,
-            latents=samples["samples"] if samples is not None else None,
-            denoise_strength=denoise_strength,
-            prompt_embeds=positive.to(dtype).to(device),
-            negative_prompt_embeds=negative.to(dtype).to(device),
-            generator=generator,
-            device=device
-        )
+        with torch.autocast(mm.get_autocast_device(device)):
+            latents = pipeline["pipe"](
+                num_inference_steps=steps,
+                height = height,
+                width = width,
+                num_frames = num_frames,
+                t_tile_length = t_tile_length,
+                t_tile_overlap = t_tile_overlap,
+                guidance_scale=cfg,
+                latents=samples["samples"] if samples is not None else None,
+                denoise_strength=denoise_strength,
+                prompt_embeds=positive.to(dtype).to(device),
+                negative_prompt_embeds=negative.to(dtype).to(device),
+                generator=generator,
+                device=device
+            )
        pipe.transformer.to(offload_device)
        mm.soft_empty_cache()
        print(latents.shape)
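Wrapping the sampler call in torch.autocast is what makes the fp8 weights usable at inference time: float8 tensors cannot be fed straight into matmul/linear kernels, so autocast casts each op's inputs up to the autocast dtype on the fly while the weights stay stored in fp8 between ops. mm.get_autocast_device(device) resolves the torch device to the device-type string that torch.autocast expects. A rough sketch of the pattern with a stand-in module, assuming bf16 compute:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

block = nn.Linear(64, 64).to(device, torch.float8_e4m3fn)  # weights stored in fp8
x = torch.randn(1, 64, device=device)

# Outside autocast this matmul would fail on the fp8 weight; inside the region
# the input and the weight are cast to bf16 per op, and the result comes back
# in bf16 while the stored weights remain fp8.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = block(x)

print(y.dtype)  # torch.bfloat16

Storing in fp8 and computing in bf16 trades a per-op cast for roughly halving the transformer's memory footprint.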
@@ -297,7 +297,7 @@ class CogVideoDecode:
            tile_overlap_factor_height=1 / 12,
            tile_overlap_factor_width=1 / 12,
        )
-
+        latents = latents.to(vae.dtype)
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / vae.config.scaling_factor * latents
@@ -170,13 +170,13 @@ class CogVideoXPipeline(DiffusionPipeline):
        )

        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
        else:
            latents = latents.to(device)
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
            latent_timestep = timesteps[:1]

-            noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
            frames_needed = noise.shape[1]
            current_frames = latents.shape[1]
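This hunk and the remaining ones sit in the CogVideoXPipeline implementation rather than in nodes.py. Here prepare_latents now draws both the initial noise and the vid2vid noise in the VAE's dtype instead of the prompt-embedding dtype, so later arithmetic against VAE-encoded latents stays in one precision. diffusers' randn_tensor is a thin wrapper over torch.randn with generator and device handling; a minimal seeded-noise sketch in plain torch, with an illustrative shape and dtype:

import torch

vae_dtype = torch.bfloat16     # stand-in for self.vae.dtype
shape = (1, 13, 16, 60, 90)    # (batch, latent_frames, channels, h/8, w/8), illustrative

generator = torch.Generator().manual_seed(42)
noise = torch.randn(shape, generator=generator, dtype=vae_dtype)

print(noise.dtype, tuple(noise.shape))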
@@ -400,6 +400,7 @@ class CogVideoXPipeline(DiffusionPipeline):

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        prompt_embeds = prompt_embeds.to(self.transformer.dtype)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -425,6 +426,7 @@ class CogVideoXPipeline(DiffusionPipeline):
            num_inference_steps,
            latents
        )
+        latents = latents.to(self.transformer.dtype)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
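The two added casts, prompt_embeds and latents to self.transformer.dtype, keep the transformer's inputs in the same precision as its weights before the denoising loop runs. Outside an autocast region, feeding a module a tensor of a different floating-point type typically fails with a dtype-mismatch error; a tiny illustration with a stand-in bf16 layer (hypothetical names, not code from the repo):

import torch
import torch.nn as nn

block = nn.Linear(8, 8).to(torch.bfloat16)   # stand-in for the transformer
x = torch.randn(1, 8)                        # fp32 input, e.g. latents or embeds

try:
    block(x)                                 # fp32 input vs. bf16 weight
except RuntimeError as err:
    print("dtype mismatch:", err)

y = block(x.to(block.weight.dtype))          # align the input first, as the diff does
print(y.dtype)                               # torch.bfloat16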
@@ -551,18 +553,15 @@ class CogVideoXPipeline(DiffusionPipeline):
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        noise_pred,
-                        old_pred_original_sample,
-                        t,
-                        timesteps[i - 1] if i > 0 else None,
-                        latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
-                    )
+                latents, old_pred_original_sample = self.scheduler.step(
+                    noise_pred,
+                    old_pred_original_sample,
+                    t,
+                    timesteps[i - 1] if i > 0 else None,
+                    latents.to(self.vae.dtype),
+                    **extra_step_kwargs,
+                    return_dict=False,
+                )
                latents = latents.to(prompt_embeds.dtype)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):