mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 22:14:34 +08:00
Move block_idx to transformer_options
This commit is contained in:
parent
25063f25cc
commit
6bfce54652
@ -178,8 +178,7 @@ class WanAttentionBlock(nn.Module):
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={},
|
||||
block_idx=None):
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.ffn_dim = ffn_dim
|
||||
@ -188,7 +187,6 @@ class WanAttentionBlock(nn.Module):
|
||||
self.qk_norm = qk_norm
|
||||
self.cross_attn_norm = cross_attn_norm
|
||||
self.eps = eps
|
||||
self.block_idx = block_idx
|
||||
|
||||
# layers
|
||||
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
@ -248,7 +246,7 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
if "cross_attn" in patches:
|
||||
for p in patches["cross_attn"]:
|
||||
x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
|
||||
x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
@ -271,7 +269,6 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
self.block_id = block_id
|
||||
self.block_idx = None
|
||||
if block_id == 0:
|
||||
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
@ -496,7 +493,7 @@ class WanModel(torch.nn.Module):
|
||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||
self.blocks = nn.ModuleList([
|
||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
|
||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
@ -579,6 +576,7 @@ class WanModel(torch.nn.Module):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_idx"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -775,6 +773,7 @@ class VaceWanModel(WanModel):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_idx"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user