From 032a849bc6c2aa91c3a9d1a8b4eca49c0a22c3e5 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 8 Oct 2024 20:53:43 +0300
Subject: [PATCH] fp8 controlnet fixes

---
 pipeline_cogvideox.py                |  8 +++++---
 videosys/cogvideox_transformer_3d.py | 11 +++++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py
index 842fee8..be97777 100644
--- a/pipeline_cogvideox.py
+++ b/pipeline_cogvideox.py
@@ -570,14 +570,16 @@ class CogVideoXPipeline(VideoSysPipeline):
         self.controlnet = controlnet["control_model"].to(device)
         if self.transformer.dtype == torch.float8_e4m3fn:
             for name, param in self.controlnet.named_parameters():
-                if "patch_embed" not in name:
-                    param.data = param.data.to(torch.float8_e4m3fn)
+                if "patch_embed" not in name and param.data.dtype != torch.float8_e4m3fn:
+                    param.data = param.data.to(torch.float8_e4m3fn)
         else:
             self.controlnet.to(self.transformer.dtype)
 
         if getattr(self.transformer, 'fp8_matmul_enabled', False):
             from .fp8_optimization import convert_fp8_linear
-            convert_fp8_linear(self.controlnet, torch.float16)
+            if not hasattr(self.controlnet, 'fp8_matmul_enabled') or not self.controlnet.fp8_matmul_enabled:
+                convert_fp8_linear(self.controlnet, torch.float16)
+                setattr(self.controlnet, "fp8_matmul_enabled", True)
 
         control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
         control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
diff --git a/videosys/cogvideox_transformer_3d.py b/videosys/cogvideox_transformer_3d.py
index 6fe697d..6a482fa 100644
--- a/videosys/cogvideox_transformer_3d.py
+++ b/videosys/cogvideox_transformer_3d.py
@@ -522,6 +522,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
+        controlnet_states: torch.Tensor = None,
+        controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
     ):
         # if self.parallel_manager.cp_size > 1:
         #     (
@@ -597,6 +599,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
                 image_rotary_emb=image_rotary_emb,
                 timestep=timesteps if enable_pab() else None,
             )
+            if (controlnet_states is not None) and (i < len(controlnet_states)):
+                controlnet_states_block = controlnet_states[i]
+                controlnet_block_weight = 1.0
+                if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
+                    controlnet_block_weight = controlnet_weights[i]
+                elif isinstance(controlnet_weights, (float, int)):
+                    controlnet_block_weight = controlnet_weights
+
+                hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
 
         #if self.parallel_manager.sp_size > 1:
         #    hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))