mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-05-02 09:26:44 +08:00
fp8 controlnet fixes
This commit is contained in:
parent
ba694115c6
commit
032a849bc6
@@ -570,14 +570,16 @@ class CogVideoXPipeline(VideoSysPipeline):
|
|||||||
self.controlnet = controlnet["control_model"].to(device)
|
self.controlnet = controlnet["control_model"].to(device)
|
||||||
if self.transformer.dtype == torch.float8_e4m3fn:
|
if self.transformer.dtype == torch.float8_e4m3fn:
|
||||||
for name, param in self.controlnet.named_parameters():
|
for name, param in self.controlnet.named_parameters():
|
||||||
if "patch_embed" not in name:
|
if "patch_embed" not in name and param.data.dtype != torch.float8_e4m3fn:
|
||||||
param.data = param.data.to(torch.float8_e4m3fn)
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
else:
|
else:
|
||||||
self.controlnet.to(self.transformer.dtype)
|
self.controlnet.to(self.transformer.dtype)
|
||||||
|
|
||||||
if getattr(self.transformer, 'fp8_matmul_enabled', False):
|
if getattr(self.transformer, 'fp8_matmul_enabled', False):
|
||||||
from .fp8_optimization import convert_fp8_linear
|
from .fp8_optimization import convert_fp8_linear
|
||||||
convert_fp8_linear(self.controlnet, torch.float16)
|
if not hasattr(self.controlnet, 'fp8_matmul_enabled') or not self.controlnet.fp8_matmul_enabled:
|
||||||
|
convert_fp8_linear(self.controlnet, torch.float16)
|
||||||
|
setattr(self.controlnet, "fp8_matmul_enabled", True)
|
||||||
|
|
||||||
control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
|
control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
|
||||||
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
|
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames
|
||||||
|
|||||||
@@ -522,6 +522,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
timestep_cond: Optional[torch.Tensor] = None,
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
|
controlnet_states: torch.Tensor = None,
|
||||||
|
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
|
||||||
):
|
):
|
||||||
# if self.parallel_manager.cp_size > 1:
|
# if self.parallel_manager.cp_size > 1:
|
||||||
# (
|
# (
|
||||||
@@ -597,6 +599,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
timestep=timesteps if enable_pab() else None,
|
timestep=timesteps if enable_pab() else None,
|
||||||
)
|
)
|
||||||
|
if (controlnet_states is not None) and (i < len(controlnet_states)):
|
||||||
|
controlnet_states_block = controlnet_states[i]
|
||||||
|
controlnet_block_weight = 1.0
|
||||||
|
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
|
||||||
|
controlnet_block_weight = controlnet_weights[i]
|
||||||
|
elif isinstance(controlnet_weights, (float, int)):
|
||||||
|
controlnet_block_weight = controlnet_weights
|
||||||
|
|
||||||
|
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
|
||||||
|
|
||||||
#if self.parallel_manager.sp_size > 1:
|
#if self.parallel_manager.sp_size > 1:
|
||||||
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user