diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a72f8cc47..2e8ef0687 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -40,7 +40,8 @@ class ChromaParams:
     out_dim: int
     hidden_dim: int
     n_layers: int
-
+    txt_ids_dims: list
+    vec_in_dim: int
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 2472ab79c..60f2bdae2 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
+class YakMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+    if yak_mlp:
+        return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+    if mlp_silu_act:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+            SiLUActivation(),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+        )
+    else:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        if mlp_silu_act:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        if mlp_silu_act:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
         self.flipped_img_txt = flipped_img_txt
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
         modulation=True,
         mlp_silu_act=False,
         bias=True,
+        yak_mlp=False,
         dtype=None,
         device=None,
         operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
 
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        self.yak_mlp = yak_mlp
         if mlp_silu_act:
             self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
             self.mlp_act = SiLUActivation()
         else:
             self.mlp_act = nn.GELU(approximate="tanh")
 
+        if self.yak_mlp:
+            self.mlp_hidden_dim_first *= 2
+            self.mlp_act = nn.SiLU()
+
         # qkv and mlp_in
         self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
@@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        mlp = self.mlp_act(mlp)
+        if self.yak_mlp:
+            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+        else:
+            mlp = self.mlp_act(mlp)
         output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index d5674dea6..f40c2a7a9 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -15,7 +15,8 @@ from .layers import (
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
-    Modulation
+    Modulation,
+    RMSNorm
 )
 
 @dataclass
@@ -34,11 +35,14 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    txt_ids_dims: list
     global_modulation: bool = False
     mlp_silu_act: bool = False
     ops_bias: bool = True
     default_ref_method: str = "offset"
     ref_index_scale: float = 1.0
+    yak_mlp: bool = False
+    txt_norm: bool = False
 
 
 class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
         )
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
 
+        if params.txt_norm:
+            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+        else:
+            self.txt_norm = None
+
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
                     modulation=params.global_modulation is False,
                     mlp_silu_act=params.mlp_silu_act,
                     proj_bias=params.ops_bias,
+                    yak_mlp=params.yak_mlp,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
 
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -150,6 +160,8 @@ class Flux(nn.Module):
                 y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
             vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
 
+        if self.txt_norm is not None:
+            txt = self.txt_norm(txt)
         txt = self.txt_in(txt)
 
         vec_orig = vec
@@ -332,8 +344,9 @@ class Flux(nn.Module):
 
         txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
-        if len(self.params.axes_dim) == 4: # Flux 2
-            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+        if len(self.params.txt_ids_dims) > 0:
+            for i in self.params.txt_ids_dims:
+                txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 7afe4a798..7d0517e61 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["theta"] = 2000
             dit_config["out_channels"] = 128
             dit_config["global_modulation"] = True
-            dit_config["vec_in_dim"] = None
             dit_config["mlp_silu_act"] = True
             dit_config["qkv_bias"] = False
             dit_config["ops_bias"] = False
             dit_config["default_ref_method"] = "index"
             dit_config["ref_index_scale"] = 10.0
+            dit_config["txt_ids_dims"] = [3]
             patch_size = 1
         else:
             dit_config["image_model"] = "flux"
@@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["theta"] = 10000
             dit_config["out_channels"] = 16
             dit_config["qkv_bias"] = True
+            dit_config["txt_ids_dims"] = []
             patch_size = 2
 
         dit_config["in_channels"] = 16
@@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
         if vec_in_key in state_dict_keys:
             dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+        else:
+            dit_config["vec_in_dim"] = None
 
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["nerf_embedder_dtype"] = torch.float32
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+            dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+            dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+            if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+                dit_config["txt_ids_dims"] = [1, 2]
+
         return dit_config
 
     if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
diff --git a/comfy/sd.py b/comfy/sd.py
index 9eeb0c45a..f9e5efab5 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
 import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis
 
 import comfy.model_patcher
 import comfy.lora
@@ -956,6 +957,7 @@ class CLIPType(Enum):
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
     HUNYUAN_VIDEO_15 = 20
+    OVIS = 21
 
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -987,6 +989,7 @@ class TEModel(Enum):
     MISTRAL3_24B = 14
     MISTRAL3_24B_PRUNED_FLUX2 = 15
     QWEN3_4B = 16
+    QWEN3_2B = 17
 
 
 def detect_te_model(sd):
@@ -1020,9 +1023,12 @@ def detect_te_model(sd):
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
-        if 'model.layers.0.self_attn.q_norm.weight' in sd:
-            return TEModel.QWEN3_4B
         weight = sd['model.layers.0.post_attention_layernorm.weight']
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            if weight.shape[0] == 2560:
+                return TEModel.QWEN3_4B
+            elif weight.shape[0] == 2048:
+                return TEModel.QWEN3_2B
         if weight.shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
                 return TEModel.MISTRAL3_24B
@@ -1150,6 +1156,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             elif te_model == TEModel.QWEN3_4B:
                 clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+            elif te_model == TEModel.QWEN3_2B:
+                clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
             else:
                 # clip_l
                 if clip_type == CLIPType.SD3:
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index cd4b5f76c..0d07ac8c6 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
     rope_scale = None
     final_norm: bool = True
 
+@dataclass
+class Ovis25_2BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    intermediate_size: int = 6144
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_7BVLI_Config:
     vocab_size: int = 152064
@@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
+class Ovis25_2B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Ovis25_2BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py
new file mode 100644
index 000000000..81c9bd51c
--- /dev/null
+++ b/comfy/text_encoders/ovis.py
@@ -0,0 +1,69 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+import numbers
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class OvisTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
+        self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+class Ovis25_2BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+
+
+class OvisTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs, template_end=-1):
+        out, pooled = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen3_2b"][0]
+        count_im_start = 0
+        if template_end == -1:
+            for i, v in enumerate(tok_pairs):
+                elem = v[0]
+                if not torch.is_tensor(elem):
+                    if isinstance(elem, numbers.Integral):
+                        if elem == 4004 and count_im_start < 1:
+                            template_end = i
+                            count_im_start += 1
+
+        if out.shape[1] > (template_end + 1):
+            if tok_pairs[template_end + 1][0] == 25:
+                template_end += 1
+
+        out = out[:, template_end:]
+        return out, pooled, {}
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+    class OvisTEModel_(OvisTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return OvisTEModel_
diff --git a/nodes.py b/nodes.py
index 495dec806..d5e5dc228 100644
--- a/nodes.py
+++ b/nodes.py
@@ -939,7 +939,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
                               },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),