Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-08 20:34:23 +08:00)
cast controlnet in fp8 if fp8 is enabled
This commit is contained in:
parent bf0f6888c2
commit 34475221d1
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from einops import rearrange
 import torch.nn.functional as F
-from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+from .custom_cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
 from diffusers.utils import is_torch_version
 from diffusers.loaders import PeftAdapterMixin
 from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
@@ -37,6 +37,7 @@ def fp8_linear_forward(cls, original_dtype, input):
         return cls.original_forward(input)
 
 def convert_fp8_linear(module, original_dtype):
+    setattr(module, "fp8_matmul_enabled", True)
     for name, module in module.named_modules():
         if isinstance(module, nn.Linear):
             original_forward = module.forward
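For context, convert_fp8_linear (the hunk above) follows a monkey-patching pattern: it flags the parent module with fp8_matmul_enabled, then replaces the forward of every nn.Linear with an fp8-aware wrapper while keeping a handle to the original. Below is a minimal sketch of that pattern, assuming functools.partial for the binding and leaving the wrapper body as a plain fallback; the real fp8 matmul dispatch is not shown in this diff, so the wrapper body here is only a placeholder.

from functools import partial
from torch import nn

def fp8_linear_forward(cls, original_dtype, input):
    # Placeholder body: a real wrapper would run the matmul in fp8 when
    # possible and cast the result back to original_dtype; here we simply
    # fall back to the saved, unpatched forward.
    return cls.original_forward(input)

def convert_fp8_linear(module, original_dtype):
    # Flag checked later by the pipeline via getattr(..., 'fp8_matmul_enabled', False).
    setattr(module, "fp8_matmul_enabled", True)
    for name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            # Keep the unpatched forward, then bind the wrapper to this layer.
            submodule.original_forward = submodule.forward
            submodule.forward = partial(fp8_linear_forward, submodule, original_dtype)

With this in place, the pipeline change below can call convert_fp8_linear(self.controlnet, torch.float16) so the controlnet's linear layers go through the same wrapper path as the transformer's.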
@@ -568,9 +568,21 @@ class CogVideoXPipeline(VideoSysPipeline):
 
         if controlnet is not None:
             self.controlnet = controlnet["control_model"].to(device)
-            control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
+            if self.transformer.dtype == torch.float8_e4m3fn:
+                for name, param in self.controlnet.named_parameters():
+                    if "patch_embed" not in name:
+                        param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                self.controlnet.to(self.transformer.dtype)
+
+            if getattr(self.transformer, 'fp8_matmul_enabled', False):
+                from .fp8_optimization import convert_fp8_linear
+                convert_fp8_linear(self.controlnet, torch.float16)
+
+            control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
             control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
             control_weights = controlnet["control_weights"]
             print("Controlnet enabled with weights: ", control_weights)
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
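The hunk above casts the controlnet into the transformer's precision: when the transformer runs in float8_e4m3fn, every controlnet parameter except the patch_embed weights is stored in fp8, since the input patch projection is typically kept in higher precision. A standalone sketch of that selective cast follows; the helper name is made up for illustration, and it requires a PyTorch build with float8 dtypes (2.1+).

import torch
from torch import nn

def cast_to_fp8_except_patch_embed(model: nn.Module, fp8_dtype=torch.float8_e4m3fn):
    # Store most weights in fp8, but leave anything under "patch_embed"
    # in its original dtype, mirroring the pipeline change above.
    for name, param in model.named_parameters():
        if "patch_embed" not in name:
            param.data = param.data.to(fp8_dtype)
    return model

The fp8 tensors here act as weight storage only; how they are used at matmul time depends on the patched linear forward (see the previous sketch).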
@@ -725,9 +737,9 @@ class CogVideoXPipeline(VideoSysPipeline):
                     return_dict=False,
                 )[0]
                 if isinstance(controlnet_states, (tuple, list)):
-                    controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                    controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
                 else:
-                    controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+                    controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
 
                 # predict noise model_output
                 noise_pred[:, c, :, :, :] += self.transformer(
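The last hunk changes the dtype target for the controlnet residuals from self.vae.dtype to self.controlnet.dtype, so the returned states match whatever precision the controlnet was cast to earlier. The same normalization as a small helper (the function name is illustrative only):

import torch

def match_controlnet_dtype(controlnet_states, dtype: torch.dtype):
    # The controlnet may return a single tensor or a tuple/list of per-block
    # residuals; cast every tensor to the requested dtype before the states
    # are passed to the transformer.
    if isinstance(controlnet_states, (tuple, list)):
        return [x.to(dtype=dtype) for x in controlnet_states]
    return controlnet_states.to(dtype=dtype)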