import torch
from torch import nn

from .model import JointTransformerBlock


class ZImageControlTransformerBlock(JointTransformerBlock):
    # Control variant of JointTransformerBlock. Block 0 projects the incoming
    # control signal and adds the hidden states; every block emits a projected
    # skip connection alongside the running control stream.
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
        operation_settings=None,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            # Fuse the control stream with the main hidden states once, at the first block.
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        return c_skip, c


class ZImage_Control(torch.nn.Module):
    # Control network: patch-embeds the control latent, refines it with a small
    # stack of noise-refiner blocks, and exposes per-layer control blocks whose
    # skip outputs are injected into the main transformer.
    def __init__(
        self,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.additional_in_dim = 0
        self.control_in_dim = 16
        n_refiner_layers = 2
        self.n_control_layers = 6

        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                    operation_settings=operation_settings,
                )
                for i in range(self.n_control_layers)
            ]
        )

        # Patch embedder for the control latent, keyed by "{patch_size}-{f_patch_size}".
        all_x_embedder = {}
        patch_size = 2
        f_patch_size = 1
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.control_noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    z_image_modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        # cap_feats is accepted for interface parity but not used here.
        patch_size = 2
        f_patch_size = 1
        pH = pW = patch_size
        B, C, H, W = control_context.shape
        # Patchify: (B, C, H, W) -> (B, (H/pH)*(W/pW), pH*pW*C) tokens, then embed to dim.
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

        x_attn_mask = None
        for layer in self.control_noise_refiner:
            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

        return control_context

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        # Run a single control block; returns (skip, updated control stream).
        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
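
# --- Standalone sketch (not part of the original module): illustrates the patchify
# step performed in ZImage_Control.forward before the x_embedder projection.
# The shapes below are assumed toy values; only the channel count (16) matches
# self.control_in_dim. With patch_size=2, each 2x2 spatial patch becomes one token
# of size 2*2*16 = 64, which the embedder then maps to `dim`.
if __name__ == "__main__":
    B, C, H, W = 1, 16, 8, 8
    pH = pW = 2
    ctx = torch.randn(B, C, H, W)
    # (B, C, H, W) -> (B, (H//pH)*(W//pW), pH*pW*C)
    tokens = ctx.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
    print(tokens.shape)  # torch.Size([1, 16, 64])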