From 9e488568b2005156fdb922250e0088549855d977 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 23 Oct 2024 19:23:50 +0300 Subject: [PATCH] Support 5b controlnet --- cogvideo_controlnet.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/cogvideo_controlnet.py b/cogvideo_controlnet.py index 04334e9..b360d0c 100644 --- a/cogvideo_controlnet.py +++ b/cogvideo_controlnet.py @@ -14,6 +14,8 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -42,6 +44,7 @@ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin): temporal_interpolation_scale: float = 1.0, use_rotary_positional_embeddings: bool = False, use_learned_positional_embeddings: bool = False, + out_proj_dim = None, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -109,7 +112,16 @@ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin): ] ) + self.out_projectors = None + if out_proj_dim is not None: + self.out_projectors = nn.ModuleList( + [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)] + ) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value def compress_time(self, x, num_frames): x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames) @@ -197,7 +209,11 @@ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin): temb=emb, image_rotary_emb=image_rotary_emb, ) - controlnet_hidden_states += (hidden_states,) + + if self.out_projectors is not None: + controlnet_hidden_states += (self.out_projectors[i](hidden_states),) + else: + controlnet_hidden_states += (hidden_states,) if not return_dict: return (controlnet_hidden_states,)