Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-08 20:34:23 +08:00)
support transformer at fp8

parent 46fe99f102, commit 65e9e995e8
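As background for the change below (not part of the diff itself): torch.float8_e4m3fn is used here purely as a storage dtype, halving per-element memory versus fp16/bf16, and the values have to be cast back to a compute dtype before ordinary matmul kernels accept them. A minimal, hedged sketch, assuming a PyTorch build that exposes float8_e4m3fn; all tensors and shapes are illustrative:

import torch

w_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16)
w_fp8 = w_bf16.to(torch.float8_e4m3fn)               # lossy storage cast, 1 byte per element
print(w_bf16.element_size(), w_fp8.element_size())   # 2 1

x = torch.randn(8, 4096, dtype=torch.bfloat16)
y = x @ w_fp8.to(torch.bfloat16).T                   # upcast before the actual compute
print(y.shape)                                       # torch.Size([8, 4096])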
@@ -818,7 +818,7 @@
       "widgets_values": {
         "frame_rate": 8,
         "loop_count": 0,
-        "filename_prefix": "AnimateDiff",
+        "filename_prefix": "CogVideoX_vid2vid",
         "format": "video/nvenc_h264-mp4",
         "pix_fmt": "yuv420p",
         "bitrate": 10,
nodes.py (54 changed lines)
@@ -25,15 +25,11 @@ class DownloadAndLoadCogVideoModel:
 
             },
             "optional": {
-                "precision": (
-                    [
-                        "fp16",
-                        "fp32",
-                        "bf16",
-                    ],
-                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"},
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-            },
+                "fp8_transformer": ("BOOLEAN", {"default": False, "tooltip": "cast the transformer to torch.float8_e4m3fn"}),
+            }
         }
 
     RETURN_TYPES = ("COGVIDEOPIPE",)
@@ -41,12 +37,16 @@ class DownloadAndLoadCogVideoModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, precision):
+    def loadmodel(self, model, precision, fp8_transformer):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
 
         dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        if fp8_transformer:
+            transformer_dtype = torch.float8_e4m3fn
+        else:
+            transformer_dtype = dtype
 
         if "2b" in model:
             base_path = os.path.join(folder_paths.models_dir, "CogVideo", "CogVideo2B")
@@ -63,7 +63,7 @@ class DownloadAndLoadCogVideoModel:
                 local_dir=base_path,
                 local_dir_use_symlinks=False,
             )
-        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(dtype).to(offload_device)
+        transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(transformer_dtype).to(offload_device)
         vae = AutoencoderKLCogVideoX.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
         scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
 
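A hedged sketch of the dtype split the two hunks above introduce: the precision string still picks the base dtype used for the VAE, while the fp8_transformer toggle swaps only the transformer weights to torch.float8_e4m3fn. The helper name below is illustrative, not part of the node:

import torch

def pick_dtypes(precision: str, fp8_transformer: bool):
    # mirrors the loadmodel branch in the diff above
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
    transformer_dtype = torch.float8_e4m3fn if fp8_transformer else dtype
    return dtype, transformer_dtype

print(pick_dtypes("bf16", True))   # (torch.bfloat16, torch.float8_e4m3fn)
print(pick_dtypes("fp16", False))  # (torch.float16, torch.float16)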
@@ -247,22 +247,22 @@ class CogVideoSampler:
             pipe.scheduler = CogVideoXDDIMScheduler.from_pretrained(base_path, subfolder="scheduler")
         elif scheduler == "DPM":
             pipe.scheduler = CogVideoXDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
 
-        latents = pipeline["pipe"](
-            num_inference_steps=steps,
-            height = height,
-            width = width,
-            num_frames = num_frames,
-            t_tile_length = t_tile_length,
-            t_tile_overlap = t_tile_overlap,
-            guidance_scale=cfg,
-            latents=samples["samples"] if samples is not None else None,
-            denoise_strength=denoise_strength,
-            prompt_embeds=positive.to(dtype).to(device),
-            negative_prompt_embeds=negative.to(dtype).to(device),
-            generator=generator,
-            device=device
-        )
+        with torch.autocast(mm.get_autocast_device(device)):
+            latents = pipeline["pipe"](
+                num_inference_steps=steps,
+                height = height,
+                width = width,
+                num_frames = num_frames,
+                t_tile_length = t_tile_length,
+                t_tile_overlap = t_tile_overlap,
+                guidance_scale=cfg,
+                latents=samples["samples"] if samples is not None else None,
+                denoise_strength=denoise_strength,
+                prompt_embeds=positive.to(dtype).to(device),
+                negative_prompt_embeds=negative.to(dtype).to(device),
+                generator=generator,
+                device=device
+            )
         pipe.transformer.to(offload_device)
         mm.soft_empty_cache()
         print(latents.shape)
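The sampler hunk above wraps the pipeline call in torch.autocast(mm.get_autocast_device(device)), so the denoising matmuls run under autocast's dtype selection rather than in whatever dtype the individual tensors were loaded in. A minimal sketch of the same pattern, with ComfyUI's model-management helper replaced by a plain device string so the snippet is self-contained:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(8, 64, device=device)
b = torch.randn(64, 32, device=device)

with torch.autocast(device):   # fp16 by default on CUDA, bf16 on CPU
    out = a @ b
print(out.dtype)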
@@ -297,7 +297,7 @@ class CogVideoDecode:
             tile_overlap_factor_height=1 / 12,
             tile_overlap_factor_width=1 / 12,
         )
 
         latents = latents.to(vae.dtype)
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / vae.config.scaling_factor * latents
@@ -170,13 +170,13 @@ class CogVideoXPipeline(DiffusionPipeline):
         )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
         else:
             latents = latents.to(device)
             timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
             latent_timestep = timesteps[:1]
 
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=self.vae.dtype)
             frames_needed = noise.shape[1]
             current_frames = latents.shape[1]
 
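The hunk above draws the initial latents and the vid2vid noise directly in the VAE's dtype instead of the caller-supplied dtype, so they match the VAE-side math without a later cast. A hedged sketch using diffusers' randn_tensor helper (the same function the hunk calls); the latent shape here is hypothetical:

import torch
from diffusers.utils.torch_utils import randn_tensor

vae_dtype = torch.bfloat16                       # stand-in for self.vae.dtype
shape = (1, 13, 16, 60, 90)                      # hypothetical (batch, frames, channels, height, width)
generator = torch.Generator().manual_seed(0)
latents = randn_tensor(shape, generator=generator, device=torch.device("cpu"), dtype=vae_dtype)
print(latents.dtype)                             # torch.bfloat16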
@@ -400,6 +400,7 @@ class CogVideoXPipeline(DiffusionPipeline):
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        prompt_embeds = prompt_embeds.to(self.transformer.dtype)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -425,6 +426,7 @@ class CogVideoXPipeline(DiffusionPipeline):
             num_inference_steps,
             latents
         )
+        latents = latents.to(self.transformer.dtype)
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -551,18 +553,15 @@ class CogVideoXPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        noise_pred,
-                        old_pred_original_sample,
-                        t,
-                        timesteps[i - 1] if i > 0 else None,
-                        latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
-                    )
+                latents, old_pred_original_sample = self.scheduler.step(
+                    noise_pred,
+                    old_pred_original_sample,
+                    t,
+                    timesteps[i - 1] if i > 0 else None,
+                    latents.to(self.vae.dtype),
+                    **extra_step_kwargs,
+                    return_dict=False,
+                )
                 latents = latents.to(prompt_embeds.dtype)
 
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):