cast controlnet in fp8 if fp8 is enabled

kijai 2024-10-08 20:17:46 +03:00
parent bf0f6888c2
commit 34475221d1
3 changed files with 17 additions and 4 deletions


@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from einops import rearrange
 import torch.nn.functional as F
-from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+from .custom_cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
 from diffusers.utils import is_torch_version
 from diffusers.loaders import PeftAdapterMixin
 from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps


@@ -37,6 +37,7 @@ def fp8_linear_forward(cls, original_dtype, input):
         return cls.original_forward(input)
 
 def convert_fp8_linear(module, original_dtype):
+    setattr(module, "fp8_matmul_enabled", True)
     for name, module in module.named_modules():
         if isinstance(module, nn.Linear):
            original_forward = module.forward
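The `fp8_matmul_enabled` flag set here is what the pipeline change below checks via `getattr(self.transformer, 'fp8_matmul_enabled', False)`. As a point of reference, here is a minimal, hypothetical sketch of the patching idea, not the repository's exact fp8_optimization.py (the full function body is not shown in this hunk): mark the module, then swap every `nn.Linear.forward` for a wrapper that handles the fp8-stored weights.

# Hypothetical sketch only: mark the module, then wrap each nn.Linear so its
# fp8 weights are upcast to `original_dtype` at call time. The real code may
# instead dispatch to fused fp8 matmul kernels.
import torch
import torch.nn.functional as F
from torch import nn

def convert_fp8_linear_sketch(module: nn.Module, original_dtype: torch.dtype) -> nn.Module:
    setattr(module, "fp8_matmul_enabled", True)  # lets callers detect the patch
    for submodule in module.modules():
        if isinstance(submodule, nn.Linear):
            # keep a handle to the unpatched forward, as the diff context suggests
            submodule.original_forward = submodule.forward

            def patched_forward(input, linear=submodule, dtype=original_dtype):
                # upcast weights (and bias, if any) before the matmul
                weight = linear.weight.to(dtype)
                bias = linear.bias.to(dtype) if linear.bias is not None else None
                return F.linear(input.to(dtype), weight, bias)

            submodule.forward = patched_forward
    return module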


@@ -568,9 +568,21 @@ class CogVideoXPipeline(VideoSysPipeline):
         if controlnet is not None:
             self.controlnet = controlnet["control_model"].to(device)
-            control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
+            if self.transformer.dtype == torch.float8_e4m3fn:
+                for name, param in self.controlnet.named_parameters():
+                    if "patch_embed" not in name:
+                        param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                self.controlnet.to(self.transformer.dtype)
+            if getattr(self.transformer, 'fp8_matmul_enabled', False):
+                from .fp8_optimization import convert_fp8_linear
+                convert_fp8_linear(self.controlnet, torch.float16)
+
+            control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
             control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
             control_weights = controlnet["control_weights"]
             print("Controlnet enabled with weights: ", control_weights)
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
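The selective cast above mirrors how the transformer itself is quantized: every controlnet parameter except those under `patch_embed` is moved to `torch.float8_e4m3fn` (the fp8 dtype requires PyTorch 2.1+), and the control frames then follow the controlnet's dtype instead of the VAE's. A standalone sketch of that cast, with illustrative function and module names:

# Illustrative sketch of the selective fp8 cast (PyTorch >= 2.1):
# cast every parameter except those whose name contains "patch_embed".
import torch
from torch import nn

def cast_controlnet_to_fp8(model: nn.Module) -> nn.Module:
    for name, param in model.named_parameters():
        if "patch_embed" not in name:
            param.data = param.data.to(torch.float8_e4m3fn)
    return model

# Toy usage with a stand-in module (module names here are illustrative only):
toy = nn.ModuleDict({"patch_embed": nn.Linear(4, 4), "blocks": nn.Linear(4, 4)})
cast_controlnet_to_fp8(toy)
print({name: p.dtype for name, p in toy.named_parameters()})
# patch_embed.* stay float32, blocks.* become torch.float8_e4m3fn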
@@ -725,9 +737,9 @@ class CogVideoXPipeline(VideoSysPipeline):
                     return_dict=False,
                 )[0]
                 if isinstance(controlnet_states, (tuple, list)):
-                    controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                    controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
                 else:
-                    controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+                    controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
 
                 # predict noise model_output
                 noise_pred[:, c, :, :, :] += self.transformer(
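This last hunk is dtype plumbing for the new setup: the controlnet residual states are now cast to the controlnet's own dtype rather than the VAE's before being fed into the transformer. A small illustrative helper (the name is hypothetical, not from the repo) that captures the same normalization:

# Illustrative helper: normalize controlnet residuals, which may be a single
# tensor or a list/tuple of tensors, to a target dtype.
import torch

def align_controlnet_states(states, target_dtype: torch.dtype):
    if isinstance(states, (tuple, list)):
        return [x.to(dtype=target_dtype) for x in states]
    return states.to(dtype=target_dtype)

# e.g. align_controlnet_states([torch.randn(2, 16), torch.randn(2, 16)], torch.float16)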