mirror of https://git.datalinker.icu/vllm-project/vllm.git
[CI/Build] Fix registry tests (#21934)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 5c765aec65
commit 004203e953
tests/models/registry.py

@@ -170,8 +170,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                              min_transformers_version="4.54"),
     "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
                                              min_transformers_version="4.54"),
-    "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
-    "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), # noqa: E501
+    "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
+                                         trust_remote_code=True),
+    "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B",
+                                          min_transformers_version="4.54"),
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
@@ -199,8 +201,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                          trust_remote_code=True),
     "HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct",
                                                trust_remote_code=True),
+    # TODO: Remove is_available_online once their config.json is fixed
     "HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124",
-                                                trust_remote_code=True),
+                                                trust_remote_code=True,
+                                                is_available_online=False),
     "HCXVisionForCausalLM": _HfExamplesInfo(
         "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
         trust_remote_code=True),
@@ -275,7 +279,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
-    "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
+    "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
+                                        trust_remote_code=True),
     "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
                                             trust_remote_code=True),
     "TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407",
@@ -449,7 +454,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                         max_model_len=4096),
     "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
     "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
-    "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
+    "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
+                                           trust_remote_code=True),
     "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
                                      trust_remote_code=True),
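Note (illustrative, not part of the diff): each entry above pairs a vLLM architecture name with an example Hugging Face repo plus loading hints such as trust_remote_code, min_transformers_version, and is_available_online. A minimal stand-in sketch of how the trust_remote_code hint would typically be propagated when fetching a config; ExampleInfo is hypothetical, and only AutoConfig is the real transformers API:

# Illustrative stand-in for how a registry entry's loading hints might be used
# (ExampleInfo is hypothetical; only AutoConfig comes from transformers).
from dataclasses import dataclass
from typing import Optional

from transformers import AutoConfig


@dataclass
class ExampleInfo:  # stand-in for the real _HfExamplesInfo
    default: str
    trust_remote_code: bool = False
    min_transformers_version: Optional[str] = None


info = ExampleInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True)

# Repos with custom code only load when trust_remote_code is propagated
# (this fetches the config from the Hugging Face Hub).
config = AutoConfig.from_pretrained(info.default,
                                    trust_remote_code=info.trust_remote_code)
print(type(config).__name__)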
vllm/model_executor/models/mpt.py

@@ -8,7 +8,7 @@ from typing import Optional, Union

 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig
+from transformers import MptConfig

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
@@ -50,7 +50,7 @@ class MPTAttention(nn.Module):

     def __init__(
         self,
-        config: PretrainedConfig,
+        config: MptConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -59,15 +59,15 @@ class MPTAttention(nn.Module):
         self.d_model = config.d_model
         self.total_num_heads = config.n_heads
         self.head_dim = self.d_model // self.total_num_heads
-        self.clip_qkv = config.attn_config["clip_qkv"]
-        self.qk_ln = config.attn_config["qk_ln"]
-        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.qk_ln = config.attn_config.qk_ln
+        self.alibi_bias_max = config.attn_config.alibi_bias_max
         if "kv_n_heads" in config.attn_config:
-            self.total_num_kv_heads = config.attn_config['kv_n_heads']
+            self.total_num_kv_heads = config.attn_config.kv_n_heads
         else:
             self.total_num_kv_heads = self.total_num_heads
-        assert not config.attn_config["prefix_lm"]
-        assert config.attn_config["alibi"]
+        assert not config.attn_config.prefix_lm
+        assert config.attn_config.alibi

         # pylint: disable=invalid-name
         self.Wqkv = QKVParallelLinear(
@@ -144,7 +144,7 @@ class MPTMLP(nn.Module):

     def __init__(
         self,
-        config: PretrainedConfig,
+        config: MptConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -176,7 +176,7 @@ class MPTBlock(nn.Module):

     def __init__(
         self,
-        config: PretrainedConfig,
+        config: MptConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
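Note (illustrative, not part of the diff): the mpt.py change above follows from typing the config as transformers' MptConfig, whose attn_config is a typed sub-config object rather than a plain dict, so its fields are read as attributes. A minimal sketch under that assumption:

# Illustrative only -- not part of this commit. Assumes a transformers release
# that ships MptConfig (whose attn_config is a typed MptAttentionConfig).
from transformers import MptConfig

config = MptConfig()  # default MPT configuration

# With a typed sub-config, fields are read as attributes:
clip_qkv = config.attn_config.clip_qkv
use_alibi = config.attn_config.alibi

# Dict-style lookups, as the old vLLM code used, no longer apply here:
# config.attn_config["clip_qkv"]  # would raise TypeError
print(clip_qkv, use_alibi)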
vllm/model_executor/models/telechat2.py

@@ -37,9 +37,20 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
 class TeleChat2Model(LlamaModel):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        hf_config = vllm_config.model_config.hf_config
+
+        vllm_config.model_config.hf_config.attribute_map = {
+            "num_hidden_layers": "n_layer",
+            "num_attention_heads": "n_head",
+            "intermediate_size": "ffn_hidden_size",
+            "rms_norm_eps": "layer_norm_epsilon"
+        }
+        vllm_config.model_config.hf_config.hidden_act = "silu"
+
         # 1. Initialize the LlamaModel with bias
-        vllm_config.model_config.hf_config.bias = True
-        vllm_config.model_config.hf_config.mlp_bias = True
+        hf_config.bias = True
+        hf_config.mlp_bias = True
+
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         # 2. Remove the bias from the qkv_proj and gate_up_proj based on config
         # Telechat2's gate_up_proj and qkv_proj don't have bias
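Note (illustrative, not part of the diff): the attribute_map assignment above relies on transformers' PretrainedConfig consulting attribute_map when resolving attribute names, so Llama-style names such as num_hidden_layers are served from TeleChat2's native fields like n_layer. A minimal sketch using only stock transformers:

# Illustrative sketch of PretrainedConfig.attribute_map (not part of this commit).
from transformers import PretrainedConfig

cfg = PretrainedConfig(n_layer=30, n_head=32)

# Redirect Llama-style attribute names to TeleChat2's native field names.
cfg.attribute_map = {
    "num_hidden_layers": "n_layer",
    "num_attention_heads": "n_head",
}

print(cfg.num_hidden_layers)    # 30 -- resolved through attribute_map to n_layer
print(cfg.num_attention_heads)  # 32 -- resolved to n_head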
vllm/transformers_utils/config.py

@@ -34,8 +34,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
                                              KimiVLConfig, MedusaConfig,
                                              MllamaConfig, MLPSpeculatorConfig,
                                              Nemotron_Nano_VL_Config,
-                                             NemotronConfig, RWConfig,
-                                             UltravoxConfig)
+                                             NemotronConfig, NVLM_D_Config,
+                                             RWConfig, UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.configs.mistral import adapt_config_dict
 from vllm.transformers_utils.utils import check_gguf_file
@@ -81,6 +81,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     "medusa": MedusaConfig,
     "eagle": EAGLEConfig,
     "nemotron": NemotronConfig,
+    "NVLM_D": NVLM_D_Config,
     "ultravox": UltravoxConfig,
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
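Note (illustrative, not part of the diff): registering "NVLM_D" in _CONFIG_REGISTRY lets vLLM resolve that model_type to its bundled NVLM_D_Config instead of depending on the checkpoint's remote code. A simplified sketch of such a lookup; config_class_for is a hypothetical helper, not the actual vLLM function:

# Simplified sketch of a model_type -> config-class lookup
# (hypothetical helper, not the real vLLM code path).
from transformers import PretrainedConfig

from vllm.transformers_utils.configs import NVLM_D_Config

_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
    "NVLM_D": NVLM_D_Config,
}


def config_class_for(model_type: str) -> type[PretrainedConfig]:
    # Unregistered types fall back to the generic PretrainedConfig.
    return _CONFIG_REGISTRY.get(model_type, PretrainedConfig)


assert config_class_for("NVLM_D") is NVLM_D_Config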
vllm/transformers_utils/configs/__init__.py

@@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
+from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

 __all__ = [
@@ -39,5 +40,6 @@ __all__ = [
     "NemotronConfig",
     "NemotronHConfig",
     "Nemotron_Nano_VL_Config",
+    "NVLM_D_Config",
     "UltravoxConfig",
 ]
vllm/transformers_utils/configs/nvlm_d.py (new file, +31)

@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py
+# --------------------------------------------------------
+# NVLM-D
+# Copyright (c) 2024 NVIDIA
+# Licensed under Apache 2.0 License [see LICENSE for details]
+# --------------------------------------------------------
+from transformers import Qwen2Config
+from transformers.configuration_utils import PretrainedConfig
+
+
+class NVLM_D_Config(PretrainedConfig):
+    model_type = 'NVLM_D'
+    is_composition = True
+
+    def __init__(self, vision_config=None, llm_config=None, **kwargs):
+        super().__init__(**kwargs)
+
+        # Handle vision_config initialization
+        if vision_config is None:
+            vision_config = {}
+
+        # Handle llm_config initialization
+        if llm_config is None:
+            llm_config = {}
+
+        self.vision_config = PretrainedConfig(**vision_config)
+        self.text_config = Qwen2Config(**llm_config)
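Note (illustrative, not part of the diff): a quick usage sketch of the new NVLM_D_Config; the field values below are made-up examples. As defined above, the llm sub-config is wrapped in a Qwen2Config and the vision sub-config in a plain PretrainedConfig:

# Illustrative usage of NVLM_D_Config (field values are hypothetical examples).
from vllm.transformers_utils.configs import NVLM_D_Config

cfg = NVLM_D_Config(
    vision_config={"hidden_size": 1024},
    llm_config={"hidden_size": 2048, "num_hidden_layers": 4},
)

print(cfg.model_type)                 # 'NVLM_D'
print(cfg.text_config.hidden_size)    # 2048, stored in a Qwen2Config
print(cfg.vision_config.hidden_size)  # 1024, stored in a PretrainedConfig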