[CI] Enable test_initialization to run on V1 (#16736)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 1645b60196
commit 0ddf88e16e
@@ -472,10 +472,7 @@ steps:
     - pytest -v -s models/test_registry.py
     - pytest -v -s models/test_utils.py
     - pytest -v -s models/test_vision.py
-    # V1 Test: https://github.com/vllm-project/vllm/issues/14531
-    - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
-    - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
-    - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
+    - pytest -v -s models/test_initialization.py

 - label: Language Models Test (Standard)
   mirror_hardwares: [amdexperimental]
@@ -55,9 +55,18 @@ class _HfExamplesInfo:
     trust_remote_code: bool = False
     """The ``trust_remote_code`` level required to load the model."""
 
+    v0_only: bool = False
+    """The model is only available with the vLLM V0 engine."""
+
     hf_overrides: dict[str, Any] = field(default_factory=dict)
     """The ``hf_overrides`` required to load the model."""
 
+    max_model_len: Optional[int] = None
+    """
+    The maximum model length to use for this model. Some models default to a
+    length that is too large to fit into memory in CI.
+    """
+
     def check_transformers_version(
         self,
         *,
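For context, the two new knobs sit alongside the existing registry fields roughly as follows. This is a trimmed sketch, not the real class: only the fields visible in the hunk are kept, plus `default`, which later hunks reference via `model_info.default`; the sample entry at the end is illustrative.

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class _HfExamplesInfoSketch:
    default: str                      # HF model id used for the test
    trust_remote_code: bool = False
    v0_only: bool = False             # run this model with VLLM_USE_V1=0
    hf_overrides: dict[str, Any] = field(default_factory=dict)
    max_model_len: Optional[int] = None  # cap context so CI fits in memory


# Sample entry combining the new fields (values mirror a later hunk):
SAMPLE = _HfExamplesInfoSketch("meta-llama/Llama-4-Scout-17B-16E-Instruct",
                               max_model_len=10240)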
@@ -215,10 +224,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                         trust_remote_code=True),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
-    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
+    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
-                                            trust_remote_code=True),
+                                            trust_remote_code=True,
+                                            v0_only=True),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
@@ -234,7 +244,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                             is_available_online=False),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
                                                 is_available_online=False),
-    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
+    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
+                                           v0_only=True),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
     "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -303,7 +314,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),  # noqa: E501
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b",  # noqa: E501
-                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}),  # noqa: E501
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"},  # noqa: E501
+                                                     v0_only=True),
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny",  # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"},  # noqa: E501
@@ -328,9 +340,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                              {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}),  # noqa: E501
     "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct",  # noqa: E501
                                                       extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
-                                                      trust_remote_code=True),
+                                                      trust_remote_code=True,
+                                                      v0_only=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                                      min_transformers_version="4.51"),
+                                                      min_transformers_version="4.51",
+                                                      max_model_len=10240),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                      extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                              "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
@@ -349,7 +363,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                          extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
                                          trust_remote_code=True),
     "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
-                                                           trust_remote_code=True),
+                                                           trust_remote_code=True,
+                                                           v0_only=True),
     "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # noqa: E501
                                                         extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}),  # noqa: E501
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -372,7 +387,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
-                                                       tokenizer_mode="mistral"),
+                                                       tokenizer_mode="mistral",
+                                                       v0_only=True),
     "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
                                                       extras={"chat": "Qwen/Qwen-VL-Chat"},  # noqa: E501
                                                       trust_remote_code=True,
@@ -15,12 +15,12 @@ from .registry import HF_EXAMPLE_MODELS
 
 
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
-def test_can_initialize(model_arch):
+def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
-    # Avoid OOM
+    # Avoid OOM and reduce initialization time by only using 1 layer
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
 
@@ -34,6 +34,12 @@ def test_can_initialize(model_arch):
             "num_local_experts": 2,
         })
 
+        if hasattr(hf_config, "vision_config"):
+            hf_config.vision_config.update({
+                "num_layers": 1,
+                "num_hidden_layers": 1,
+            })
+
         return hf_config
 
     # Avoid calling model.forward()
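Pulled out of the test, the override callable boils down to the pattern below; a standalone sketch assuming only `transformers` (the helper name `shrink_for_ci` is illustrative, while the config keys mirror the hunk and its surrounding context).

from transformers import PretrainedConfig


def shrink_for_ci(hf_config: PretrainedConfig) -> PretrainedConfig:
    # Single text layer (and a tiny expert count for MoE configs).
    hf_config.update({
        "num_hidden_layers": 1,
        "num_local_experts": 2,
    })
    # Multimodal configs carry a nested vision tower config; shrink it too.
    if hasattr(hf_config, "vision_config"):
        hf_config.vision_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
        })
    return hf_config

A callable like this is what gets passed as `hf_overrides=` when the test constructs `LLM(...)` further down.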
@@ -46,7 +52,7 @@ def test_can_initialize(model_arch):
         scheduler_kv_cache_config = get_kv_cache_config(
             vllm_config,
             kv_cache_specs[0],
-            20 * GiB_bytes,
+            10 * GiB_bytes,
         )
 
         # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
@@ -55,7 +61,9 @@ def test_can_initialize(model_arch):
     with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                        _initialize_kv_caches_v0),
           patch.object(V1EngineCore, "_initialize_kv_caches",
-                       _initialize_kv_caches_v1)):
+                       _initialize_kv_caches_v1), monkeypatch.context() as m):
+        if model_info.v0_only:
+            m.setenv("VLLM_USE_V1", "0")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,
@@ -65,6 +73,7 @@ def test_can_initialize(model_arch):
                 "num_speculative_tokens": 1,
             } if model_info.speculative_model else None,
             trust_remote_code=model_info.trust_remote_code,
+            max_model_len=model_info.max_model_len,
             load_format="dummy",
             hf_overrides=hf_overrides,
         )
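Isolated from vLLM, the engine-selection gate added above works like this; a minimal pytest sketch where `FakeInfo` stands in for the registry entry, and the env var is cleared first so the assertion is deterministic.

import os
from dataclasses import dataclass

import pytest


@dataclass
class FakeInfo:
    v0_only: bool = False


@pytest.mark.parametrize("info", [FakeInfo(v0_only=True), FakeInfo()])
def test_engine_env_gate(info: FakeInfo, monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.delenv("VLLM_USE_V1", raising=False)
        if info.v0_only:
            # Force the V0 engine only for models tagged v0_only.
            m.setenv("VLLM_USE_V1", "0")
        expected = "0" if info.v0_only else None
        assert os.environ.get("VLLM_USE_V1") == expected
    # Leaving the context restores the original environment automatically.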
@@ -28,7 +28,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -182,25 +182,20 @@ class Grok1Attention(nn.Module):
             quant_config=quant_config,
             logits_soft_cap=attn_logits_soft_cap,
             prefix=f"{prefix}.attn")
+        self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
+                                       1.0) if self.config else 1.0
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
-
-        # Apply attention output multiplier if specified in config
-        attn_multiplier = getattr(self.config, "attn_output_multiplier",
-                                  None) if self.config else None
-        if attn_multiplier is not None:
-            output = output * attn_multiplier
+        output *= self.attn_multiplier
         return output
 
 
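The Grok-1 hunk above also folds the output-multiplier handling into `__init__`: the config lookup happens once, and `forward()` always multiplies (by 1.0 when the config does not define `attn_output_multiplier`). A generic sketch of that pattern, with `ToyAttention` as an illustrative stand-in rather than the real module:

import torch
from torch import nn


class ToyAttention(nn.Module):

    def __init__(self, hidden_size: int, config=None):
        super().__init__()
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        # Defaulting to 1.0 lets forward() skip the per-call None check.
        self.attn_multiplier = getattr(config, "attn_output_multiplier",
                                       1.0) if config else 1.0

    def forward(self, attn_output: torch.Tensor) -> torch.Tensor:
        output = self.o_proj(attn_output)
        output *= self.attn_multiplier
        return output


# Usage: with no config the multiplier is 1.0, i.e. a no-op scaling.
y = ToyAttention(hidden_size=8)(torch.randn(2, 8))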
@@ -261,8 +256,6 @@ class Grok1DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -276,8 +269,6 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
 
         # Post attention normalization
@@ -341,8 +332,6 @@ class Grok1Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,9 +348,7 @@ class Grok1Model(nn.Module):
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i - self.start_layer],
-                                            attn_metadata, residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -529,13 +516,10 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
 
 
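The remaining Grok-1 hunks all make the same signature change: `kv_caches` and `attn_metadata` are no longer threaded through every `forward()`, because the attention layer now resolves that per-step state itself. The sketch below illustrates the general idea with a plain contextvar; it is not vLLM's actual mechanism, and every name in it is illustrative.

from contextlib import contextmanager
from contextvars import ContextVar

_STEP_STATE: ContextVar[dict] = ContextVar("step_state", default={})


@contextmanager
def forward_step(**state):
    # Publish per-step state (kv cache handles, attention metadata, ...)
    # once, instead of passing it through every module's forward().
    token = _STEP_STATE.set(state)
    try:
        yield
    finally:
        _STEP_STATE.reset(token)


class ToyAttention:

    def __call__(self, q, k, v):
        state = _STEP_STATE.get()
        kv_cache = state.get("kv_cache")  # looked up here, not passed in
        return q  # placeholder for the real attention computation


# Callers set the state once per step; layers no longer take extra args.
with forward_step(kv_cache=object(), attn_metadata=object()):
    out = ToyAttention()(q=1.0, k=2.0, v=3.0)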
@@ -2794,14 +2794,17 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
 
 # Only relevant for models using ALiBi (e.g, MPT)
 def check_use_alibi(model_config: ModelConfig) -> bool:
-    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
+    cfg = model_config.hf_text_config
+    return (getattr(cfg, "alibi", False)  # Falcon
             or ("BloomForCausalLM" in getattr(model_config.hf_config,
                                               "architectures", []))  # Bloom
-            or getattr(model_config.hf_text_config, "position_encoding_type",
-                       "") == "alibi"  # codellm_1b_alibi
-            or
-            (hasattr(model_config.hf_text_config, "attn_config")  # MPT
-             and model_config.hf_text_config.attn_config.get("alibi", False)))
+            or getattr(cfg, "position_encoding_type", "") ==
+            "alibi"  # codellm_1b_alibi
+            or (hasattr(cfg, "attn_config")  # MPT
+                and ((isinstance(cfg.attn_config, dict)
+                      and cfg.attn_config.get("alibi", False)) or
+                     (not isinstance(cfg.attn_config, dict)
+                      and getattr(cfg.attn_config, "alibi", False)))))
 
 
 def sha256(input) -> int:
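The rewritten MPT branch now has to accept `attn_config` either as a plain dict or as a config object, which is what the nested `isinstance` check handles. A standalone illustration, with `SimpleNamespace` objects standing in for HF config classes:

from types import SimpleNamespace


def attn_config_uses_alibi(attn_config) -> bool:
    if isinstance(attn_config, dict):
        return attn_config.get("alibi", False)
    return getattr(attn_config, "alibi", False)


assert attn_config_uses_alibi({"alibi": True}) is True
assert attn_config_uses_alibi(SimpleNamespace(alibi=True)) is True
assert attn_config_uses_alibi(SimpleNamespace()) is False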