Move block_idx to transformer_options

Author: kijai
Date: 2025-11-03 20:53:06 +02:00
Parent: 25063f25cc
Commit: 6bfce54652


@@ -178,8 +178,7 @@ class WanAttentionBlock(nn.Module):
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6, operation_settings={},
-                 block_idx=None):
+                 eps=1e-6, operation_settings={}):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -188,7 +187,6 @@ class WanAttentionBlock(nn.Module):
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
-        self.block_idx = block_idx
         # layers
         self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -248,7 +246,7 @@ class WanAttentionBlock(nn.Module):
         if "cross_attn" in patches:
             for p in patches["cross_attn"]:
-                x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
+                x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
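
Note for patch authors: with the hunk above, a "cross_attn" patch no longer receives a top-level "block_idx" key; the index now has to be read from the transformer_options dict that is still passed along. A minimal sketch of a patch adjusted to the new layout, assuming the patch should only act on a few chosen blocks (the function name, the target_blocks set, and the 0.1 scale are illustrative, not part of this commit):

import torch

def my_cross_attn_patch(kwargs):
    # Called as p({...}) in the loop above; the returned tensor is added to x.
    x = kwargs["x"]
    transformer_options = kwargs["transformer_options"]
    block_idx = transformer_options.get("block_idx")  # set per block since this commit
    target_blocks = {0, 5, 10}  # hypothetical: only touch these blocks
    if block_idx not in target_blocks:
        return torch.zeros_like(x)  # adding zeros leaves x unchanged
    return 0.1 * x  # placeholder adjustment for the selected blocks
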
@@ -271,7 +269,6 @@ class VaceWanAttentionBlock(WanAttentionBlock):
                  ):
         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
         self.block_id = block_id
-        self.block_idx = None
         if block_id == 0:
             self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -496,7 +493,7 @@ class WanModel(torch.nn.Module):
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
-                                  window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
+                                  window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
             for i in range(num_layers)
         ])
@@ -579,6 +576,7 @@ class WanModel(torch.nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
+            transformer_options["block_idx"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -775,6 +773,7 @@ class VaceWanModel(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
+            transformer_options["block_idx"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
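
Both forward loops now write the running block index into transformer_options right before each block is called, so a registered "cross_attn" patch can branch on it without the blocks storing their own index. A hedged usage sketch, assuming a ComfyUI ModelPatcher and its generic set_model_patch helper (my_cross_attn_patch is the hypothetical function sketched earlier; cloning first is the usual pattern so the patch does not leak into the shared model):

# model_patcher is assumed to be a comfy.model_patcher.ModelPatcher instance
m = model_patcher.clone()
m.set_model_patch(my_cross_attn_patch, "cross_attn")  # stored under transformer_options["patches"]["cross_attn"]
# During sampling, WanModel / VaceWanModel set transformer_options["block_idx"] = i
# before each block, which is what the patch reads.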