diff --git a/cogvideo_controlnet.py b/cogvideo_controlnet.py
index b7f7399..04334e9 100644
--- a/cogvideo_controlnet.py
+++ b/cogvideo_controlnet.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from einops import rearrange
 import torch.nn.functional as F
-from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
+from .custom_cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
 from diffusers.utils import is_torch_version
 from diffusers.loaders import PeftAdapterMixin
 from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
diff --git a/fp8_optimization.py b/fp8_optimization.py
index 1e7c42b..b01ac91 100644
--- a/fp8_optimization.py
+++ b/fp8_optimization.py
@@ -37,6 +37,7 @@ def fp8_linear_forward(cls, original_dtype, input):
         return cls.original_forward(input)
 
 def convert_fp8_linear(module, original_dtype):
+    setattr(module, "fp8_matmul_enabled", True)
     for name, module in module.named_modules():
         if isinstance(module, nn.Linear):
             original_forward = module.forward
diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 1c98d04..842fee8 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -568,9 +568,21 @@ class CogVideoXPipeline(VideoSysPipeline):
         if controlnet is not None:
             self.controlnet = controlnet["control_model"].to(device)
-            control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
+            if self.transformer.dtype == torch.float8_e4m3fn:
+                for name, param in self.controlnet.named_parameters():
+                    if "patch_embed" not in name:
+                        param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                self.controlnet.to(self.transformer.dtype)
+
+            if getattr(self.transformer, 'fp8_matmul_enabled', False):
+                from .fp8_optimization import convert_fp8_linear
+                convert_fp8_linear(self.controlnet, torch.float16)
+
+            control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
             control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
             control_weights = controlnet["control_weights"]
+            print("Controlnet enabled with weights: ", control_weights)
             control_start = controlnet["control_start"]
             control_end = controlnet["control_end"]
@@ -725,9 +737,9 @@ class CogVideoXPipeline(VideoSysPipeline):
                         return_dict=False,
                     )[0]
                     if isinstance(controlnet_states, (tuple, list)):
-                        controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
+                        controlnet_states = [x.to(dtype=self.controlnet.dtype) for x in controlnet_states]
                     else:
-                        controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
+                        controlnet_states = controlnet_states.to(dtype=self.controlnet.dtype)
 
                 # predict noise model_output
                 noise_pred[:, c, :, :, :] += self.transformer(
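# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch). The hunks above rely on two fp8
# patterns; the standalone code below spells them out under stated
# assumptions. `cast_to_fp8_except_patch_embed` and
# `convert_fp8_linear_sketch` are hypothetical names used for illustration;
# the real helper is convert_fp8_linear in fp8_optimization.py, which may
# dispatch to fused fp8 matmuls rather than the up-cast-at-call-time
# fallback sketched here.

import torch
import torch.nn.functional as F
from torch import nn


def cast_to_fp8_except_patch_embed(model: nn.Module) -> None:
    # Mirrors the pipeline change: store controlnet weights as
    # float8_e4m3fn, but leave the patch embedding in its original dtype,
    # matching the "patch_embed" exclusion in the hunk above (presumably
    # because quantizing the first projection of the latents costs the
    # most quality).
    for name, param in model.named_parameters():
        if "patch_embed" not in name:
            param.data = param.data.to(torch.float8_e4m3fn)


def convert_fp8_linear_sketch(module: nn.Module, original_dtype: torch.dtype) -> None:
    # Mirrors convert_fp8_linear: tag the root module first (the pipeline
    # checks this flag with getattr(..., 'fp8_matmul_enabled', False)),
    # then monkey-patch every nn.Linear. Assigning to `forward` on the
    # instance shadows the class method, so Module.__call__ picks up the
    # patched version transparently.
    setattr(module, "fp8_matmul_enabled", True)
    for _, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            def patched_forward(x, m=submodule):  # bind m per layer, not late
                # Up-cast the fp8-stored weight at call time so activations
                # stay in `original_dtype` (torch.float16 in the pipeline).
                weight = m.weight.to(original_dtype)
                bias = m.bias.to(original_dtype) if m.bias is not None else None
                return F.linear(x.to(original_dtype), weight, bias)
            submodule.forward = patched_forward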