mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Properly load the newbie diffusion model. (#11172)
There is still one of the text encoders missing and I didn't actually test it.
This commit is contained in:
parent
329480da5a
commit
56fa7dbe38
@ -377,6 +377,7 @@ class NextDiT(nn.Module):
|
|||||||
z_image_modulation=False,
|
z_image_modulation=False,
|
||||||
time_scale=1.0,
|
time_scale=1.0,
|
||||||
pad_tokens_multiple=None,
|
pad_tokens_multiple=None,
|
||||||
|
clip_text_dim=None,
|
||||||
image_model=None,
|
image_model=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -447,6 +448,31 @@ class NextDiT(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clip_text_pooled_proj = None
|
||||||
|
|
||||||
|
if clip_text_dim is not None:
|
||||||
|
self.clip_text_dim = clip_text_dim
|
||||||
|
self.clip_text_pooled_proj = nn.Sequential(
|
||||||
|
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
clip_text_dim,
|
||||||
|
clip_text_dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.time_text_embed = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
min(dim, 1024) + clip_text_dim,
|
||||||
|
min(dim, 1024),
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
JointTransformerBlock(
|
JointTransformerBlock(
|
||||||
@ -585,6 +611,15 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||||
|
|
||||||
|
if self.clip_text_pooled_proj is not None:
|
||||||
|
pooled = kwargs.get("clip_text_pooled", None)
|
||||||
|
if pooled is not None:
|
||||||
|
pooled = self.clip_text_pooled_proj(pooled)
|
||||||
|
else:
|
||||||
|
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||||
|
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
x_is_tensor = isinstance(x, torch.Tensor)
|
x_is_tensor = isinstance(x, torch.Tensor)
|
||||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||||
|
|||||||
@ -1110,6 +1110,10 @@ class Lumina2(BaseModel):
|
|||||||
if 'num_tokens' not in out:
|
if 'num_tokens' not in out:
|
||||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||||
|
|
||||||
|
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
||||||
|
if clip_text_pooled is not None:
|
||||||
|
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class WAN21(BaseModel):
|
class WAN21(BaseModel):
|
||||||
|
|||||||
@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["axes_lens"] = [300, 512, 512]
|
dit_config["axes_lens"] = [300, 512, 512]
|
||||||
dit_config["rope_theta"] = 10000.0
|
dit_config["rope_theta"] = 10000.0
|
||||||
dit_config["ffn_dim_multiplier"] = 4.0
|
dit_config["ffn_dim_multiplier"] = 4.0
|
||||||
|
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||||
|
if ctd_weight is not None:
|
||||||
|
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||||
elif dit_config["dim"] == 3840: # Z image
|
elif dit_config["dim"] == 3840: # Z image
|
||||||
dit_config["n_heads"] = 30
|
dit_config["n_heads"] = 30
|
||||||
dit_config["n_kv_heads"] = 30
|
dit_config["n_kv_heads"] = 30
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user