[Model] Support Qwen2 embeddings and use tags to select model tests (#10184)

Cyrus Leung authored on 2024-11-15 12:23:09 +08:00; committed by GitHub
parent 2885ba0e24
commit b40cf6402e
19 changed files with 252 additions and 178 deletions

View File

@@ -27,9 +27,9 @@ function cpu_tests() {
      decord einops librosa peft Pillow sentence-transformers soundfile \
      transformers_stream_generator matplotlib datamodel_code_generator
    pip install torchvision --index-url https://download.pytorch.org/whl/cpu
-   pytest -v -s tests/models/embedding/language
-   pytest -v -s tests/models/encoder_decoder/language
-   pytest -v -s tests/models/decoder_only/language/test_models.py
+   pytest -v -s tests/models/decoder_only/language -m cpu_model
+   pytest -v -s tests/models/embedding/language -m cpu_model
+   pytest -v -s tests/models/encoder_decoder/language -m cpu_model
    pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
    pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"

View File

@@ -38,9 +38,9 @@ function cpu_tests() {
      decord einops librosa peft Pillow sentence-transformers soundfile \
      transformers_stream_generator matplotlib datamodel_code_generator
    pip install torchvision --index-url https://download.pytorch.org/whl/cpu
-   pytest -v -s tests/models/embedding/language
-   pytest -v -s tests/models/encoder_decoder/language
-   pytest -v -s tests/models/decoder_only/language/test_models.py
+   pytest -v -s tests/models/decoder_only/language -m cpu_model
+   pytest -v -s tests/models/embedding/language -m cpu_model
+   pytest -v -s tests/models/encoder_decoder/language -m cpu_model
    pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
    pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"

View File

@@ -323,62 +323,60 @@ steps:
   - pytest -v -s models/test_registry.py
   - pytest -v -s models/test_initialization.py

-- label: Decoder-only Language Models Test (Standard) # 18min
+- label: Language Models Test (Standard) # 42min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
+  - tests/models/embedding/language
+  - tests/models/encoder_decoder/language
   commands:
-    - pytest -v -s models/decoder_only/language -m core_model
-    - pytest -v -s models/decoder_only/language -m quant_model
+    - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
+    - pytest -v -s models/embedding/language -m core_model
+    - pytest -v -s models/embedding/vision_language -m core_model

-- label: Decoder-only Language Models Test (Extended) # 46min
+- label: Language Models Test (Extended) # 50min
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
+  - tests/models/embedding/language
+  - tests/models/encoder_decoder/language
   commands:
     - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
+    - pytest -v -s models/embedding/language -m 'not core_model'
+    - pytest -v -s models/embedding/vision_language -m 'not core_model'

-- label: Decoder-only Multi-Modal Models Test (Standard) # 22min
+- label: Multi-Modal Models Test (Standard) # 26min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
+  - tests/models/embedding/vision_language
+  - tests/models/encoder_decoder/vision_language
   commands:
-    - pytest -v -s models/decoder_only/audio_language -m core_model
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
-    # No tests under this group for now
-    # - pytest -v -s models/decoder_only/audio_language -m quant_model
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
+    - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
+    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
+    - pytest -v -s models/encoder_decoder/language -m core_model
+    - pytest -v -s models/encoder_decoder/vision_language -m core_model

-- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m
+- label: Multi-Modal Models Test (Extended) # 1h15m
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
+  - tests/models/embedding/vision_language
+  - tests/models/encoder_decoder/vision_language
   commands:
     - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
     # HACK - run phi3v tests separately to sidestep this transformers bug
     # https://github.com/huggingface/transformers/issues/34307
     - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
     - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+    - pytest -v -s models/encoder_decoder/language -m 'not core_model'
+    - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

-- label: Other Models Test # 20min
-  #mirror_hardwares: [amd]
-  source_file_dependencies:
-  - vllm/
-  - tests/models/embedding/language
-  - tests/models/embedding/vision_language
-  - tests/models/encoder_decoder/language
-  - tests/models/encoder_decoder/vision_language
-  commands:
-    - pytest -v -s models/embedding/language
-    - pytest -v -s models/embedding/vision_language
-    - pytest -v -s models/encoder_decoder/language
-    - pytest -v -s models/encoder_decoder/vision_language
-
 # This test is used only in PR development phase to test individual models and should never run on main
 - label: Custom Models Test

View File

@@ -330,11 +330,16 @@ Text Embedding
     - :code:`BAAI/bge-multilingual-gemma2`, etc.
     -
     - ✅︎
-  * - :code:`MistralModel`
-    - Mistral-based
+  * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc.
+    - Llama-based
     - :code:`intfloat/e5-mistral-7b-instruct`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
+    - Qwen2-based
+    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
+    - ✅︎
+    - ✅︎

 .. important::
    Some model architectures support both generation and embedding tasks.
@@ -355,7 +360,7 @@ Reward Modeling
   * - :code:`Qwen2ForRewardModel`
     - Qwen2-based
     - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
-    -
+    - ✅︎
     - ✅︎

 .. note::
@@ -376,7 +381,7 @@ Classification
   * - :code:`Qwen2ForSequenceClassification`
     - Qwen2-based
     - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
-    -
+    - ✅︎
     - ✅︎

 .. note::

View File

@@ -33,6 +33,10 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)

     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

@@ -293,17 +297,3 @@ def test_jamba_distributed_produces_identical_generation(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_model_print(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        # This test is for verifying whether the model's extra_repr
-        # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)

View File

@@ -51,6 +51,10 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)

     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

@@ -279,17 +283,3 @@ def test_state_cleanup(
         except ValueError:
             pytest.fail("Mamba inner state wasn't cleaned up between states, "
                         "could be related to finished_requests_ids")
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_model_print(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        # This test is for verifying whether the model's extra_repr
-        # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)

View File

@@ -4,37 +4,52 @@ Run `pytest tests/models/test_models.py`.
 """
 import pytest

-from vllm.platforms import current_platform
-
 from ...utils import check_logprobs_close

-MODELS = [
-    "facebook/opt-125m",  # opt
-    "openai-community/gpt2",  # gpt2
-    # "Milos/slovak-gpt-j-405M",  # gptj
-    # "bigcode/tiny_starcoder_py",  # gpt_bigcode
-    # "EleutherAI/pythia-70m",  # gpt_neox
-    "bigscience/bloom-560m",  # bloom - testing alibi slopes
-    "microsoft/phi-2",  # phi
-    # "stabilityai/stablelm-3b-4e1t",  # stablelm
-    # "bigcode/starcoder2-3b",  # starcoder2
-    "google/gemma-1.1-2b-it",  # gemma
-    "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
-    "meta-llama/Llama-3.2-1B-Instruct",  # llama
-]
-
-if not current_platform.is_cpu():
-    MODELS += [
-        # fused_moe which not supported on CPU
-        "openbmb/MiniCPM3-4B",
-    ]
-
-target_dtype = "half"
-
-
-@pytest.mark.core_model
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", [target_dtype])
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param(
+            "bigscience/bloom-560m",  # bloom - testing alibi slopes
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "openai-community/gpt2",  # gpt2
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param("Milos/slovak-gpt-j-405M"),  # gptj
+        pytest.param("bigcode/tiny_starcoder_py"),  # gpt_bigcode
+        pytest.param("EleutherAI/pythia-70m"),  # gpt_neox
+        pytest.param(
+            "google/gemma-1.1-2b-it",  # gemma
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "meta-llama/Llama-3.2-1B-Instruct",  # llama
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "openbmb/MiniCPM3-4B",
+            # fused_moe not supported on CPU
+            marks=[pytest.mark.core_model],
+        ),
+        pytest.param(
+            "facebook/opt-125m",  # opt
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "microsoft/phi-2",  # phi
+            marks=[pytest.mark.core_model],
+        ),
+        pytest.param(
+            "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
+            marks=[pytest.mark.core_model],
+        ),
+        pytest.param("stabilityai/stablelm-3b-4e1t"),  # stablelm
+        pytest.param("bigcode/starcoder2-3b"),  # starcoder2
+    ])
+@pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(
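
The marker scheme above is plain pytest: `pytest.param(..., marks=[...])` attaches markers per test case, and the `-m` expressions in the CI commands select subsets. A small self-contained sketch with hypothetical model names (markers would normally be registered in the project's pytest config so `--strict-markers` accepts them):

    import pytest

    @pytest.mark.parametrize(
        "model",
        [
            pytest.param("model-a", marks=[pytest.mark.core_model]),
            pytest.param("model-b",
                         marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
            pytest.param("model-c"),  # unmarked: only runs without a -m filter
        ])
    def test_dummy(model: str) -> None:
        assert isinstance(model, str)

    # Selection then mirrors the CI commands in this commit:
    #   pytest -m core_model                            -> model-a, model-b
    #   pytest -m cpu_model                             -> model-b
    #   pytest -m 'core_model or quant_model'           -> model-a, model-b
    #   pytest -m 'not core_model and not quant_model'  -> model-c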

View File

@@ -9,10 +9,14 @@ import pytest
 import torch
 from transformers import AutoModelForSequenceClassification

-CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]
-
-@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("jason9693/Qwen2.5-1.5B-apeach",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+    ],
+)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_classification_models(
     hf_runner,
@@ -23,31 +27,19 @@ def test_classification_models(
 ) -> None:
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)

     with hf_runner(model,
                    dtype=dtype,
                    auto_cls=AutoModelForSequenceClassification) as hf_model:
         hf_outputs = hf_model.classify(example_prompts)

-    print(hf_outputs, vllm_outputs)
-
     # check logits difference
     for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
         hf_output = torch.tensor(hf_output)
         vllm_output = torch.tensor(vllm_output)

         assert torch.allclose(hf_output, vllm_output, 1e-3)
-
-
-@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_classification_model_print(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        # This test is for verifying whether the model's extra_repr
-        # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)

View File

@@ -4,25 +4,25 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 import pytest

-from vllm.utils import current_platform
-
 from ..utils import check_embeddings_close

-# Model, Guard
-MODELS = [
-    "intfloat/e5-mistral-7b-instruct",
-    "BAAI/bge-base-en-v1.5",
-    "BAAI/bge-multilingual-gemma2",
-    "intfloat/multilingual-e5-large",
-]
-
-ENCODER_ONLY = [
-    "BAAI/bge-base-en-v1.5",
-    "intfloat/multilingual-e5-large",
-]
-
-
-@pytest.mark.parametrize("model", MODELS)
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        # [Encoder-only]
+        pytest.param("BAAI/bge-base-en-v1.5",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+        pytest.param("intfloat/multilingual-e5-large"),
+        # [Encoder-decoder]
+        pytest.param("intfloat/e5-mistral-7b-instruct",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+        pytest.param("BAAI/bge-multilingual-gemma2",
+                     marks=[pytest.mark.core_model]),
+        pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+    ],
+)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models(
     hf_runner,
@@ -31,9 +31,6 @@ def test_models(
     model,
     dtype: str,
 ) -> None:
-    if model not in ENCODER_ONLY and current_platform.is_cpu():
-        pytest.skip("Skip large embedding models test on CPU.")
-
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -46,8 +43,13 @@ def test_models(
                    is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)

-    with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model:
+    with vllm_runner(model, task="embedding", dtype=dtype,
+                     max_model_len=None) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)

     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
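
`check_embeddings_close` compares the HF and vLLM embeddings pairwise; a rough sketch of that kind of check, assuming a cosine-similarity criterion (the real helper lives in `tests/models/embedding/utils.py` and may use a different metric or tolerance):

    from typing import List

    import torch
    import torch.nn.functional as F

    def check_embeddings_close_sketch(embeddings_0: List[List[float]],
                                      embeddings_1: List[List[float]],
                                      tol: float = 1e-2) -> None:
        # Compare pairs by cosine similarity rather than element-wise, which
        # tolerates small numeric differences between the two runners.
        for emb_0, emb_1 in zip(embeddings_0, embeddings_1):
            sim = F.cosine_similarity(torch.tensor(emb_0),
                                      torch.tensor(emb_1),
                                      dim=0)
            assert sim >= 1 - tol, f"cosine similarity too low: {sim:.4f}"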

View File

@@ -88,6 +88,7 @@ def _run_test(

 @pytest.mark.skipif(transformers.__version__.startswith("4.46"),
                     reason="Model broken with changes in transformers 4.46")
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_text(
@@ -112,6 +113,7 @@ def test_models_text(

 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_image(

View File

@@ -74,6 +74,7 @@ def _run_test(
 )

+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_text(
@@ -98,6 +99,7 @@ def test_models_text(

 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_image(

View File

@@ -14,8 +14,6 @@ from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
 from ....utils import multi_gpu_test
 from ...utils import check_logprobs_close

-MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
-

 def vllm_to_hf_output(
     vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
@@ -170,7 +168,14 @@ def run_test(
     )

-@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("facebook/bart-base",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+        pytest.param("facebook/bart-large-cnn"),
+    ],
+)
 @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])

View File

@@ -233,6 +233,7 @@ def clear_cache():

 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "sizes",
@@ -278,6 +279,7 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,

 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
@@ -326,6 +328,7 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,

 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])

View File

@@ -129,9 +129,13 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
+    "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
+    "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
     "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
+    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),  # noqa: E501
+    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",

View File

@@ -77,8 +77,8 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):

 def test_hf_registry_coverage():
-    untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() -
-                      set(ModelRegistry.get_supported_archs()))
+    untested_archs = (ModelRegistry.get_supported_archs() -
+                      HF_EXAMPLE_MODELS.get_supported_archs())

     assert not untested_archs, (
         "Please add the following architectures to "

View File

@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -44,8 +45,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput

 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@@ -247,6 +249,18 @@ class Qwen2Model(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config

+        # TODO (@robertgshaw2): see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError("Sliding window for some but all layers is not "
+                             "supported. This model uses sliding window "
+                             "but `max_window_layers` = {} is less than "
+                             "`num_hidden_layers` = {}. Please open an issue "
+                             "to discuss this feature.".format(
+                                 config.max_window_layers,
+                                 config.num_hidden_layers,
+                             ))
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -405,20 +419,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
-
-        # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
-            raise ValueError("Sliding window for some but all layers is not "
-                             "supported. This model uses sliding window "
-                             "but `max_window_layers` = {} is less than "
-                             "`num_hidden_layers` = {}. Please open an issue "
-                             "to discuss this feature.".format(
-                                 config.max_window_layers,
-                                 config.num_hidden_layers,
-                             ))
+        pooler_config = vllm_config.model_config.pooler_config

         self.config = config
         self.lora_config = lora_config
@@ -438,6 +441,15 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()

+        # The same model class supports both language generation and embedding
+        # because the architecture name is the same
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -475,6 +487,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
@@ -482,3 +501,70 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                           if self.config.tie_word_embeddings else None),
         )
         loader.load_weights(weights)
+
+
+class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        pooler_config = vllm_config.model_config.pooler_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.MEAN,
+            normalize=True,
+            softmax=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["lm_head."])
+        loader.load_weights(weights)
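
The two pooler defaults above differ only in how per-token hidden states are reduced: `Qwen2ForCausalLM` uses last-token pooling (`PoolingType.LAST`), while `Qwen2EmbeddingModel` uses mean pooling (`PoolingType.MEAN`), both with normalization enabled. A standalone illustration of the two reductions (not vLLM's `Pooler`, which also handles per-sequence boundaries and the pooler config):

    import torch
    import torch.nn.functional as F

    def pool_last(hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size] for a single sequence.
        # LAST: the embedding is the final token's hidden state.
        return F.normalize(hidden_states[-1], p=2, dim=-1)

    def pool_mean(hidden_states: torch.Tensor) -> torch.Tensor:
        # MEAN: the embedding is the average over all tokens.
        return F.normalize(hidden_states.mean(dim=0), p=2, dim=-1)

    seq = torch.randn(7, 3584)  # 7 tokens at Qwen2-7B's hidden size
    print(pool_last(seq).shape, pool_mean(seq).shape)  # both torch.Size([3584])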

View File

@@ -17,10 +17,11 @@ from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput

+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import AutoWeightsLoader, maybe_prefix


-class Qwen2ForSequenceClassification(nn.Module):
+class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -46,21 +47,9 @@ class Qwen2ForSequenceClassification(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         pooler_config = vllm_config.model_config.pooler_config

-        # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
-            raise ValueError("Sliding window for some but all layers is not "
-                             "supported. This model uses sliding window "
-                             "but `max_window_layers` = {} is less than "
-                             "`num_hidden_layers` = {}. Please open an issue "
-                             "to discuss this feature.".format(
-                                 config.max_window_layers,
-                                 config.num_hidden_layers,
-                             ))
-
         self.config = config
         self.lora_config = lora_config

View File

@@ -16,7 +16,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput

-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, maybe_prefix

@@ -32,7 +32,7 @@ class ReLU(nn.Module):
         return self.activation(input)


-class Qwen2ForRewardModel(nn.Module, SupportsPP):
+class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -58,21 +58,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         pooler_config = vllm_config.model_config.pooler_config

-        # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
-            raise ValueError("Sliding window for some but all layers is not "
-                             "supported. This model uses sliding window "
-                             "but `max_window_layers` = {} is less than "
-                             "`num_hidden_layers` = {}. Please open an issue "
-                             "to discuss this feature.".format(
-                                 config.max_window_layers,
-                                 config.num_hidden_layers,
-                             ))
-
         self.config = config
         self.lora_config = lora_config

View File

@@ -11,7 +11,8 @@ import tempfile
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)

 import cloudpickle
 import torch.nn as nn
@@ -110,6 +111,8 @@ _EMBEDDING_MODELS = {
     },
     "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
+    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
+    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
     # [Multimodal]
@@ -301,8 +304,8 @@ class _ModelRegistry:
     # Keyed by model_arch
     models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

-    def get_supported_archs(self) -> List[str]:
-        return list(self.models.keys())
+    def get_supported_archs(self) -> AbstractSet[str]:
+        return self.models.keys()

     def register_model(
         self,
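
Returning the raw `dict_keys` view (typed as `AbstractSet[str]`) is what lets `test_hf_registry_coverage` take a set difference between the vLLM registry and the HF example registry without extra conversions. A small sketch of the property being relied on:

    from typing import AbstractSet

    def get_supported_archs(models: dict) -> AbstractSet[str]:
        # dict.keys() is a set-like view, so callers get set algebra
        # (difference, intersection) for free.
        return models.keys()

    registry = {"Qwen2ForCausalLM": object(), "Qwen2Model": object()}
    examples = {"Qwen2ForCausalLM": object()}

    # Architectures registered in vLLM but missing an example entry:
    print(get_supported_archs(registry) - get_supported_archs(examples))
    # -> {'Qwen2Model'}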