mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 12:24:23 +08:00
fp8 controlnet fixes
This commit is contained in:
parent
ba694115c6
commit
032a849bc6
@ -570,14 +570,16 @@ class CogVideoXPipeline(VideoSysPipeline):
|
||||
self.controlnet = controlnet["control_model"].to(device)
|
||||
if self.transformer.dtype == torch.float8_e4m3fn:
|
||||
for name, param in self.controlnet.named_parameters():
|
||||
if "patch_embed" not in name:
|
||||
param.data = param.data.to(torch.float8_e4m3fn)
|
||||
if "patch_embed" not in name and param.data.dtype != torch.float8_e4m3fn:
|
||||
param.data = param.data.to(torch.float8_e4m3fn)
|
||||
else:
|
||||
self.controlnet.to(self.transformer.dtype)
|
||||
|
||||
if getattr(self.transformer, 'fp8_matmul_enabled', False):
|
||||
from .fp8_optimization import convert_fp8_linear
|
||||
convert_fp8_linear(self.controlnet, torch.float16)
|
||||
if not hasattr(self.controlnet, 'fp8_matmul_enabled') or not self.controlnet.fp8_matmul_enabled:
|
||||
convert_fp8_linear(self.controlnet, torch.float16)
|
||||
setattr(self.controlnet, "fp8_matmul_enabled", True)
|
||||
|
||||
control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
|
||||
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
|
||||
|
||||
@ -522,6 +522,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
controlnet_states: torch.Tensor = None,
|
||||
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
|
||||
):
|
||||
# if self.parallel_manager.cp_size > 1:
|
||||
# (
|
||||
@ -597,6 +599,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep=timesteps if enable_pab() else None,
|
||||
)
|
||||
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||
controlnet_states_block = controlnet_states[i]
|
||||
controlnet_block_weight = 1.0
|
||||
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
|
||||
controlnet_block_weight = controlnet_weights[i]
|
||||
elif isinstance(controlnet_weights, (float, int)):
|
||||
controlnet_block_weight = controlnet_weights
|
||||
|
||||
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
|
||||
|
||||
#if self.parallel_manager.sp_size > 1:
|
||||
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user