mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-02 14:20:13 +08:00
fix fp8 for I2V
This commit is contained in:
parent
f298ac84b5
commit
5a6bcfd612
20
nodes.py
20
nodes.py
@ -57,8 +57,8 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
offload_device = mm.unet_offload_device()
|
offload_device = mm.unet_offload_device()
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
if "I2V" in model and fp8_transformer != "disabled":
|
#if "I2V" in model and fp8_transformer != "disabled":
|
||||||
raise NotImplementedError("fp8_transformer is not implemented yet for I2V -model")
|
# raise NotImplementedError("fp8_transformer is not implemented yet for I2V -model")
|
||||||
|
|
||||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||||
|
|
||||||
@ -99,7 +99,15 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if name != "pos_embedding":
|
if name != "pos_embedding":
|
||||||
param.data = param.data.to(torch.float8_e4m3fn)
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
else:
|
else:
|
||||||
transformer.to(torch.float8_e4m3fn)
|
for name, param in transformer.named_parameters():
|
||||||
|
|
||||||
|
if "patch_embed" not in name:
|
||||||
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(name)
|
||||||
|
print(param.data.dtype)
|
||||||
|
#transformer.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
if fp8_transformer == "fastmode":
|
if fp8_transformer == "fastmode":
|
||||||
from .fp8_optimization import convert_fp8_linear
|
from .fp8_optimization import convert_fp8_linear
|
||||||
@ -238,6 +246,9 @@ class CogVideoTextEncode:
|
|||||||
return {"required": {
|
return {"required": {
|
||||||
"clip": ("CLIP",),
|
"clip": ("CLIP",),
|
||||||
"prompt": ("STRING", {"default": "", "multiline": True} ),
|
"prompt": ("STRING", {"default": "", "multiline": True} ),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -246,7 +257,7 @@ class CogVideoTextEncode:
|
|||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "CogVideoWrapper"
|
CATEGORY = "CogVideoWrapper"
|
||||||
|
|
||||||
def process(self, clip, prompt):
|
def process(self, clip, prompt, strength=1.0):
|
||||||
load_device = mm.text_encoder_device()
|
load_device = mm.text_encoder_device()
|
||||||
offload_device = mm.text_encoder_offload_device()
|
offload_device = mm.text_encoder_offload_device()
|
||||||
clip.tokenizer.t5xxl.pad_to_max_length = True
|
clip.tokenizer.t5xxl.pad_to_max_length = True
|
||||||
@ -255,6 +266,7 @@ class CogVideoTextEncode:
|
|||||||
tokens = clip.tokenize(prompt, return_word_ids=True)
|
tokens = clip.tokenize(prompt, return_word_ids=True)
|
||||||
|
|
||||||
embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
|
embeds = clip.encode_from_tokens(tokens, return_pooled=False, return_dict=False)
|
||||||
|
embeds *= strength
|
||||||
clip.cond_stage_model.to(offload_device)
|
clip.cond_stage_model.to(offload_device)
|
||||||
|
|
||||||
return (embeds, )
|
return (embeds, )
|
||||||
|
|||||||
@ -412,7 +412,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||||
prompt_embeds = prompt_embeds.to(self.transformer.dtype)
|
prompt_embeds = prompt_embeds.to(self.vae.dtype)
|
||||||
|
|
||||||
# 4. Prepare timesteps
|
# 4. Prepare timesteps
|
||||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||||
@ -442,14 +442,11 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
num_inference_steps,
|
num_inference_steps,
|
||||||
latents
|
latents
|
||||||
)
|
)
|
||||||
latents = latents.to(self.transformer.dtype)
|
latents = latents.to(self.vae.dtype)
|
||||||
print("latents", latents.shape)
|
print("latents", latents.shape)
|
||||||
|
|
||||||
# 5.5.
|
# 5.5.
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
print("image_cond_latents", image_cond_latents.shape)
|
|
||||||
#image_cond_latents = torch.cat(image_cond_latents, dim=0).to(self.transformer.dtype)#.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
|
||||||
|
|
||||||
padding_shape = (
|
padding_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
(latents.shape[1] - 1),
|
(latents.shape[1] - 1),
|
||||||
@ -457,10 +454,8 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
height // self.vae_scale_factor_spatial,
|
height // self.vae_scale_factor_spatial,
|
||||||
width // self.vae_scale_factor_spatial,
|
width // self.vae_scale_factor_spatial,
|
||||||
)
|
)
|
||||||
print("padding_shape", padding_shape)
|
|
||||||
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.transformer.dtype)
|
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.transformer.dtype)
|
||||||
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
|
image_cond_latents = torch.cat([image_cond_latents, latent_padding], dim=1)
|
||||||
print("image_cond_latents", image_cond_latents.shape)
|
|
||||||
|
|
||||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
@ -602,11 +597,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
|||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
if image_cond_latents is not None:
|
if image_cond_latents is not None:
|
||||||
|
|
||||||
|
|
||||||
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
|
latent_image_input = torch.cat([image_cond_latents] * 2) if do_classifier_free_guidance else image_cond_latents
|
||||||
print("latent_model_input",latent_model_input.shape)
|
|
||||||
print("image_cond_latents",image_cond_latents.shape)
|
|
||||||
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
|
latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
|
||||||
|
|
||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user