Support 5b controlnet

This commit is contained in:
kijai 2024-10-23 19:23:50 +03:00
parent ccddf0b271
commit 9e488568b2

View File

@ -14,6 +14,8 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@ -42,6 +44,7 @@ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
out_proj_dim = None,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -109,7 +112,16 @@ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
]
)
self.out_projectors = None
if out_proj_dim is not None:
self.out_projectors = nn.ModuleList(
[nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def compress_time(self, x, num_frames):
x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
@ -197,7 +209,11 @@ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
temb=emb,
image_rotary_emb=image_rotary_emb,
)
controlnet_hidden_states += (hidden_states,)
if self.out_projectors is not None:
controlnet_hidden_states += (self.out_projectors[i](hidden_states),)
else:
controlnet_hidden_states += (hidden_states,)
if not return_dict:
return (controlnet_hidden_states,)