mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-14 00:14:31 +08:00
Move block_idx to transformer_options
This commit is contained in:
parent
25063f25cc
commit
6bfce54652
@ -178,8 +178,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
window_size=(-1, -1),
|
window_size=(-1, -1),
|
||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=False,
|
cross_attn_norm=False,
|
||||||
eps=1e-6, operation_settings={},
|
eps=1e-6, operation_settings={}):
|
||||||
block_idx=None):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.ffn_dim = ffn_dim
|
self.ffn_dim = ffn_dim
|
||||||
@ -188,7 +187,6 @@ class WanAttentionBlock(nn.Module):
|
|||||||
self.qk_norm = qk_norm
|
self.qk_norm = qk_norm
|
||||||
self.cross_attn_norm = cross_attn_norm
|
self.cross_attn_norm = cross_attn_norm
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.block_idx = block_idx
|
|
||||||
|
|
||||||
# layers
|
# layers
|
||||||
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
@ -248,7 +246,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
if "cross_attn" in patches:
|
if "cross_attn" in patches:
|
||||||
for p in patches["cross_attn"]:
|
for p in patches["cross_attn"]:
|
||||||
x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
|
x = x + p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
|
||||||
|
|
||||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||||
@ -271,7 +269,6 @@ class VaceWanAttentionBlock(WanAttentionBlock):
|
|||||||
):
|
):
|
||||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||||
self.block_id = block_id
|
self.block_id = block_id
|
||||||
self.block_idx = None
|
|
||||||
if block_id == 0:
|
if block_id == 0:
|
||||||
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
@ -496,7 +493,7 @@ class WanModel(torch.nn.Module):
|
|||||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
|
||||||
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
|
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
|
||||||
for i in range(num_layers)
|
for i in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -579,6 +576,7 @@ class WanModel(torch.nn.Module):
|
|||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_idx"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -775,6 +773,7 @@ class VaceWanModel(WanModel):
|
|||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_idx"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user