fix fp8 with controlnet

This commit is contained in:
kijai 2024-10-08 18:05:28 +03:00
parent e047e6f07f
commit bf0f6888c2
2 changed files with 6 additions and 6 deletions

View File

@ -1358,7 +1358,7 @@ class CogVideoControlNet:
controlnet = {
"control_model": controlnet,
"control_frames": control_frames,
"control_strength": control_strength,
"control_weights": control_strength,
"control_start": control_start_percent,
"control_end": control_end_percent,
}

View File

@ -570,7 +570,7 @@ class CogVideoXPipeline(VideoSysPipeline):
self.controlnet = controlnet["control_model"].to(device)
control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
control_strength = controlnet["control_strength"]
control_weights = controlnet["control_weights"]
control_start = controlnet["control_start"]
control_end = controlnet["control_end"]
@ -737,7 +737,7 @@ class CogVideoXPipeline(VideoSysPipeline):
image_rotary_emb=image_rotary_emb,
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_strength,
controlnet_weights=control_weights,
)[0]
counter[:, c, :, :, :] += 1
@ -792,9 +792,9 @@ class CogVideoXPipeline(VideoSysPipeline):
return_dict=False,
)[0]
if isinstance(controlnet_states, (tuple, list)):
controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states]
controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
else:
controlnet_states = controlnet_states.to(dtype=self.transformer.dtype)
controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
# predict noise model_output
@ -805,7 +805,7 @@ class CogVideoXPipeline(VideoSysPipeline):
image_rotary_emb=image_rotary_emb,
return_dict=False,
controlnet_states=controlnet_states,
controlnet_weights=control_strength,
controlnet_weights=control_weights,
)[0]
noise_pred = noise_pred.float()