diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index c0efc455a..efbaecc70 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -178,8 +178,7 @@ class WanAttentionBlock(nn.Module):
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6, operation_settings={},
-                 block_idx=None):
+                 eps=1e-6, operation_settings={}):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -188,7 +187,6 @@ class WanAttentionBlock(nn.Module):
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
-        self.block_idx = block_idx

         # layers
         self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -248,7 +246,7 @@ class WanAttentionBlock(nn.Module):

         if "cross_attn" in patches:
             for p in patches["cross_attn"]:
-                x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
+                x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})

         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
@@ -271,7 +269,6 @@ class VaceWanAttentionBlock(WanAttentionBlock):
     ):
         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
         self.block_id = block_id
-        self.block_idx = None
         if block_id == 0:
             self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -496,7 +493,7 @@ class WanModel(torch.nn.Module):
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
+                              window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
             for i in range(num_layers)
         ])

@@ -579,6 +576,7 @@ class WanModel(torch.nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
+            transformer_options["block_idx"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -775,6 +773,7 @@ class VaceWanModel(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
+            transformer_options["block_idx"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
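
Net effect for patch authors: the per-block index is no longer passed to "cross_attn" patches via the argument dict (and is no longer a constructor argument on WanAttentionBlock); instead the block loop publishes it on transformer_options["block_idx"] before each block runs. A minimal sketch of a patch reading it the new way; the function name, the cutoff value, and the assumption that the patch is registered under transformer_options["patches"]["cross_attn"] are illustrative, not part of this diff:

# Hypothetical cross_attn patch; assumes it is registered under
# transformer_options["patches"]["cross_attn"], so WanAttentionBlock.forward
# calls it as p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}).
def example_cross_attn_patch(args):
    x = args["x"]
    opts = args["transformer_options"]
    block_idx = opts.get("block_idx", 0)  # previously: args["block_idx"]
    # The caller adds the return value back onto x, so return a residual.
    if block_idx >= 20:  # arbitrary cutoff for illustration
        return 0.05 * x
    return 0.0 * x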