Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 22:14:34 +08:00)
Hack to make zimage work in fp16. (#11057)

parent 33d6aec3b7
commit daaceac769
@@ -22,6 +22,10 @@ def modulate(x, scale):
 #             Core NextDiT Model             #
 #############################################################################
 
+def clamp_fp16(x):
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+    return x
 
 class JointAttention(nn.Module):
     """Multi-head attention module."""
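The new clamp_fp16 helper is the core of the fix: float16's largest finite value is 65504, so any intermediate activation past that becomes inf and later turns into NaN. A minimal standalone sketch (not part of the commit) of how the helper behaves:

import torch

def clamp_fp16(x):
    # Map NaN to 0 and clamp +/-inf to the largest finite fp16 value (65504).
    if x.dtype == torch.float16:
        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
    return x

a = torch.tensor([300.0], dtype=torch.float16)
print(a * a)              # tensor([inf], dtype=torch.float16): 90000 overflows fp16
print(clamp_fp16(a * a))  # tensor([65504.], dtype=torch.float16)

# Other dtypes pass through untouched, so bf16/fp32 runs are unaffected.
b = torch.tensor([9e4], dtype=torch.bfloat16)
print(clamp_fp16(b))      # unchanged bf16 tensor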
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
 
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
+        return clamp_fp16(F.silu(x1) * x3)
 
     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
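The gated-SiLU product is a typical overflow site: silu(x1) is roughly x1 for large inputs, so two moderately large activations multiply straight past the fp16 range. A quick illustration, reusing the clamp_fp16 sketch above:

import torch
import torch.nn.functional as F

x1 = torch.tensor([400.0], dtype=torch.float16)
x3 = torch.tensor([400.0], dtype=torch.float16)

# silu(400) ~= 400, and 400 * 400 = 160000 > 65504, so the product overflows.
print(F.silu(x1) * x3)              # tensor([inf], dtype=torch.float16)
print(clamp_fp16(F.silu(x1) * x3))  # tensor([65504.], dtype=torch.float16)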
@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
 
             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                self.feed_forward(
+                clamp_fp16(self.feed_forward(
                     modulate(self.ffn_norm1(x), scale_mlp),
-                )
+                ))
             )
         else:
             assert adaln_input is None
             x = x + self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + self.ffn_norm2(
                 self.feed_forward(
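Note where the clamp sits: inside the residual branch, before attention_norm2/ffn_norm2 (RMSNorm-style layers in this model, going by the surrounding code). That placement matters because a single inf entering a normalization becomes NaN and then sticks in the residual stream. A small illustration under that RMS-norm assumption:

import torch

attn_out = torch.randn(4, dtype=torch.float16)
attn_out[0] = float("inf")  # one overflowed activation

# RMS-style normalization divides by an inf scale: inf/inf = nan,
# finite/inf = 0, so the hidden state is corrupted.
rms = attn_out / torch.sqrt((attn_out * attn_out).mean())
print(rms)  # tensor([nan, 0., 0., 0.], dtype=torch.float16)

# NaN survives the residual add and poisons every later layer.
x = torch.zeros(4, dtype=torch.float16)
print(x + rms)  # tensor([nan, 0., 0., 0.], dtype=torch.float16)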
@@ -1027,6 +1027,8 @@ class ZImage(Lumina2):
 
     memory_usage_factor = 1.7
 
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
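With the activations clamped, the ZImage model definition can now advertise torch.float16 as a usable inference dtype. In ComfyUI, supported_inference_dtypes is consulted when picking a dtype for the target device; a rough sketch of that kind of selection logic (the helper below is hypothetical, not ComfyUI's actual function):

import torch

def pick_inference_dtype(supported, device_has_bf16):
    # Hypothetical: return the first supported dtype the device can run.
    for dtype in supported:
        if dtype == torch.bfloat16 and not device_has_bf16:
            continue  # e.g. older GPUs without bf16: fall through to fp16
        return dtype
    return torch.float32

supported = [torch.bfloat16, torch.float16, torch.float32]
print(pick_inference_dtype(supported, device_has_bf16=False))  # torch.float16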