fp8 controlnet fixes

This commit is contained in:
kijai 2024-10-08 20:53:43 +03:00
parent ba694115c6
commit 032a849bc6
2 changed files with 16 additions and 3 deletions

View File

@@ -570,14 +570,16 @@ class CogVideoXPipeline(VideoSysPipeline):
self.controlnet = controlnet["control_model"].to(device)
if self.transformer.dtype == torch.float8_e4m3fn:
for name, param in self.controlnet.named_parameters():
if "patch_embed" not in name:
param.data = param.data.to(torch.float8_e4m3fn)
if "patch_embed" not in name and param.data.dtype != torch.float8_e4m3fn:
param.data = param.data.to(torch.float8_e4m3fn)
else:
self.controlnet.to(self.transformer.dtype)
if getattr(self.transformer, 'fp8_matmul_enabled', False):
from .fp8_optimization import convert_fp8_linear
convert_fp8_linear(self.controlnet, torch.float16)
if not hasattr(self.controlnet, 'fp8_matmul_enabled') or not self.controlnet.fp8_matmul_enabled:
convert_fp8_linear(self.controlnet, torch.float16)
setattr(self.controlnet, "fp8_matmul_enabled", True)
control_frames = controlnet["control_frames"].to(device).to(self.controlnet.dtype).contiguous()
control_frames = torch.cat([control_frames] * 2) if do_classifier_free_guidance else control_frames

View File

@@ -522,6 +522,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
timestep_cond: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, torch.FloatTensor]] = 1.0,
):
# if self.parallel_manager.cp_size > 1:
# (
@@ -597,6 +599,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
image_rotary_emb=image_rotary_emb,
timestep=timesteps if enable_pab() else None,
)
if (controlnet_states is not None) and (i < len(controlnet_states)):
controlnet_states_block = controlnet_states[i]
controlnet_block_weight = 1.0
if isinstance(controlnet_weights, (list)) or torch.is_tensor(controlnet_weights):
controlnet_block_weight = controlnet_weights[i]
elif isinstance(controlnet_weights, (float, int)):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
#if self.parallel_manager.sp_size > 1:
# hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))