From bf0f6888c2e89616f031bef1464ba5449fb43980 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:05:28 +0300 Subject: [PATCH] fix fp8 with controlnet --- nodes.py | 2 +- pipeline_cogvideox.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index 0d06eea..82c593b 100644 --- a/nodes.py +++ b/nodes.py @@ -1358,7 +1358,7 @@ class CogVideoControlNet: controlnet = { "control_model": controlnet, "control_frames": control_frames, - "control_strength": control_strength, + "control_weights": control_strength, "control_start": control_start_percent, "control_end": control_end_percent, } diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 3ec28a2..1c98d04 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -570,7 +570,7 @@ class CogVideoXPipeline(VideoSysPipeline): self.controlnet = controlnet["control_model"].to(device) control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous() control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames - control_strength = controlnet["control_strength"] + control_weights = controlnet["control_weights"] control_start = controlnet["control_start"] control_end = controlnet["control_end"] @@ -737,7 +737,7 @@ class CogVideoXPipeline(VideoSysPipeline): image_rotary_emb=image_rotary_emb, return_dict=False, controlnet_states=controlnet_states, - controlnet_weights=control_strength, + controlnet_weights=control_weights, )[0] counter[:, c, :, :, :] += 1 @@ -792,9 +792,9 @@ class CogVideoXPipeline(VideoSysPipeline): return_dict=False, )[0] if isinstance(controlnet_states, (tuple, list)): - controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states] + controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states] else: - controlnet_states = controlnet_states.to(dtype=self.transformer.dtype) + controlnet_states = controlnet_states.to(dtype=self.vae.dtype) # predict noise model_output @@ -805,7 +805,7 @@ class CogVideoXPipeline(VideoSysPipeline): image_rotary_emb=image_rotary_emb, return_dict=False, controlnet_states=controlnet_states, - controlnet_weights=control_strength, + controlnet_weights=control_weights, )[0] noise_pred = noise_pred.float()