mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-10 05:14:22 +08:00
fix fp8 with controlnet
This commit is contained in:
parent
e047e6f07f
commit
bf0f6888c2
2
nodes.py
2
nodes.py
@ -1358,7 +1358,7 @@ class CogVideoControlNet:
|
|||||||
controlnet = {
|
controlnet = {
|
||||||
"control_model": controlnet,
|
"control_model": controlnet,
|
||||||
"control_frames": control_frames,
|
"control_frames": control_frames,
|
||||||
"control_strength": control_strength,
|
"control_weights": control_strength,
|
||||||
"control_start": control_start_percent,
|
"control_start": control_start_percent,
|
||||||
"control_end": control_end_percent,
|
"control_end": control_end_percent,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -570,7 +570,7 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
self.controlnet = controlnet["control_model"].to(device)
|
self.controlnet = controlnet["control_model"].to(device)
|
||||||
control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
|
control_frames = controlnet["control_frames"].to(device).to(self.vae.dtype).contiguous()
|
||||||
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
|
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
|
||||||
control_strength = controlnet["control_strength"]
|
control_weights = controlnet["control_weights"]
|
||||||
control_start = controlnet["control_start"]
|
control_start = controlnet["control_start"]
|
||||||
control_end = controlnet["control_end"]
|
control_end = controlnet["control_end"]
|
||||||
|
|
||||||
@ -737,7 +737,7 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
controlnet_states=controlnet_states,
|
controlnet_states=controlnet_states,
|
||||||
controlnet_weights=control_strength,
|
controlnet_weights=control_weights,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
counter[:, c, :, :, :] += 1
|
counter[:, c, :, :, :] += 1
|
||||||
@ -792,9 +792,9 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
if isinstance(controlnet_states, (tuple, list)):
|
if isinstance(controlnet_states, (tuple, list)):
|
||||||
controlnet_states = [x.to(dtype=self.transformer.dtype) for x in controlnet_states]
|
controlnet_states = [x.to(dtype=self.vae.dtype) for x in controlnet_states]
|
||||||
else:
|
else:
|
||||||
controlnet_states = controlnet_states.to(dtype=self.transformer.dtype)
|
controlnet_states = controlnet_states.to(dtype=self.vae.dtype)
|
||||||
|
|
||||||
|
|
||||||
# predict noise model_output
|
# predict noise model_output
|
||||||
@ -805,7 +805,7 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
controlnet_states=controlnet_states,
|
controlnet_states=controlnet_states,
|
||||||
controlnet_weights=control_strength,
|
controlnet_weights=control_weights,
|
||||||
)[0]
|
)[0]
|
||||||
noise_pred = noise_pred.float()
|
noise_pred = noise_pred.float()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user