mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Model] Support Qwen2 embeddings and use tags to select model tests (#10184)
This commit is contained in: parent 2885ba0e24, commit b40cf6402e
@@ -27,9 +27,9 @@ function cpu_tests() {
     decord einops librosa peft Pillow sentence-transformers soundfile \
     transformers_stream_generator matplotlib datamodel_code_generator
   pip install torchvision --index-url https://download.pytorch.org/whl/cpu
-  pytest -v -s tests/models/embedding/language
-  pytest -v -s tests/models/encoder_decoder/language
-  pytest -v -s tests/models/decoder_only/language/test_models.py
+  pytest -v -s tests/models/decoder_only/language -m cpu_model
+  pytest -v -s tests/models/embedding/language -m cpu_model
+  pytest -v -s tests/models/encoder_decoder/language -m cpu_model
   pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
   pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"

@@ -38,9 +38,9 @@ function cpu_tests() {
     decord einops librosa peft Pillow sentence-transformers soundfile \
     transformers_stream_generator matplotlib datamodel_code_generator
   pip install torchvision --index-url https://download.pytorch.org/whl/cpu
-  pytest -v -s tests/models/embedding/language
-  pytest -v -s tests/models/encoder_decoder/language
-  pytest -v -s tests/models/decoder_only/language/test_models.py
+  pytest -v -s tests/models/decoder_only/language -m cpu_model
+  pytest -v -s tests/models/embedding/language -m cpu_model
+  pytest -v -s tests/models/encoder_decoder/language -m cpu_model
   pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
   pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"
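The switch to `-m cpu_model` above selects tests by pytest marker instead of by file path, so the CPU job picks up exactly the model tests tagged as CPU-capable. A hedged illustration of the mechanism (hypothetical test names; it assumes the cpu_model marker is declared in the suite's pytest configuration):

import pytest

@pytest.mark.cpu_model
def test_small_model_on_cpu():
    # collected by `pytest -m cpu_model`
    assert True

def test_gpu_only_model():
    # deselected when filtering with `-m cpu_model`
    assert True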
@@ -323,62 +323,60 @@ steps:
   - pytest -v -s models/test_registry.py
   - pytest -v -s models/test_initialization.py

-- label: Decoder-only Language Models Test (Standard) # 18min
+- label: Language Models Test (Standard) # 42min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
+  - tests/models/embedding/language
+  - tests/models/encoder_decoder/language
   commands:
-    - pytest -v -s models/decoder_only/language -m core_model
-    - pytest -v -s models/decoder_only/language -m quant_model
+    - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
+    - pytest -v -s models/embedding/language -m core_model
+    - pytest -v -s models/embedding/vision_language -m core_model

-- label: Decoder-only Language Models Test (Extended) # 46min
+- label: Language Models Test (Extended) # 50min
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/language
+  - tests/models/embedding/language
+  - tests/models/encoder_decoder/language
   commands:
     - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
+    - pytest -v -s models/embedding/language -m 'not core_model'
+    - pytest -v -s models/embedding/vision_language -m 'not core_model'

-- label: Decoder-only Multi-Modal Models Test (Standard) # 22min
+- label: Multi-Modal Models Test (Standard) # 26min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
+  - tests/models/embedding/vision_language
+  - tests/models/encoder_decoder/vision_language
   commands:
-    - pytest -v -s models/decoder_only/audio_language -m core_model
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
-    # No tests under this group for now
-    # - pytest -v -s models/decoder_only/audio_language -m quant_model
-    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m quant_model
+    - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
+    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
+    - pytest -v -s models/encoder_decoder/language -m core_model
+    - pytest -v -s models/encoder_decoder/vision_language -m core_model

-- label: Decoder-only Multi-Modal Models Test (Extended) # 1h10m
+- label: Multi-Modal Models Test (Extended) # 1h15m
   nightly: true
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
+  - tests/models/embedding/vision_language
+  - tests/models/encoder_decoder/vision_language
   commands:
     - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
     # HACK - run phi3v tests separately to sidestep this transformers bug
     # https://github.com/huggingface/transformers/issues/34307
     - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
     - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+    - pytest -v -s models/encoder_decoder/language -m 'not core_model'
+    - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'

-- label: Other Models Test # 20min
-  #mirror_hardwares: [amd]
-  source_file_dependencies:
-  - vllm/
-  - tests/models/embedding/language
-  - tests/models/embedding/vision_language
-  - tests/models/encoder_decoder/language
-  - tests/models/encoder_decoder/vision_language
-  commands:
-    - pytest -v -s models/embedding/language
-    - pytest -v -s models/embedding/vision_language
-    - pytest -v -s models/encoder_decoder/language
-    - pytest -v -s models/encoder_decoder/vision_language
-
 # This test is used only in PR development phase to test individual models and should never run on main
 - label: Custom Models Test
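The renamed jobs rely on the same tag mechanism: the standard job selects `-m 'core_model or quant_model'` (or `-m core_model` for the embedding suites) and the nightly extended job selects the complement, so every test is picked up by exactly one of the two expressions. A hedged sketch of how a parametrized case lands in one job or the other (marker and model names mirror the pipeline and diff above; the test body is elided):

import pytest

@pytest.mark.parametrize(
    "model",
    [
        # matched by -m 'core_model or quant_model' -> standard job
        pytest.param("Qwen/Qwen2.5-0.5B-Instruct", marks=pytest.mark.core_model),
        # matched by -m 'not core_model and not quant_model' -> extended job
        pytest.param("bigcode/starcoder2-3b"),
    ],
)
def test_models(model):
    ...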
@@ -330,11 +330,16 @@ Text Embedding
     - :code:`BAAI/bge-multilingual-gemma2`, etc.
     -
     - ✅︎
-  * - :code:`MistralModel`
-    - Mistral-based
+  * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc.
+    - Llama-based
     - :code:`intfloat/e5-mistral-7b-instruct`, etc.
     - ✅︎
     - ✅︎
+  * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
+    - Qwen2-based
+    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
+    - ✅︎
+    - ✅︎

 .. important::
   Some model architectures support both generation and embedding tasks.

@@ -355,7 +360,7 @@ Reward Modeling
   * - :code:`Qwen2ForRewardModel`
     - Qwen2-based
     - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
-    -
+    - ✅︎
     - ✅︎

 .. note::

@@ -376,7 +381,7 @@ Classification
   * - :code:`Qwen2ForSequenceClassification`
     - Qwen2-based
     - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
-    -
+    - ✅︎
     - ✅︎

 .. note::
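For reference, the Qwen2 embedding support documented above can be exercised offline roughly as follows. This is a hedged sketch rather than part of the change: it assumes the LLM entrypoint accepts task="embedding" (just as the test runner later in this diff passes it) and that encode() returns one pooled embedding per prompt.

from vllm import LLM

# Hedged sketch: load one of the newly listed Qwen2 embedding checkpoints.
llm = LLM(model="ssmits/Qwen2-7B-Instruct-embed-base", task="embedding")
outputs = llm.encode(["Write a short story about a robot."])
print(len(outputs[0].outputs.embedding))  # length of the pooled embedding vector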
@@ -33,6 +33,10 @@ def test_models(
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

@@ -293,17 +297,3 @@ def test_jamba_distributed_produces_identical_generation(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_model_print(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        # This test is for verifying whether the model's extra_repr
-        # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
@@ -51,6 +51,10 @@ def test_models(
 
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

@@ -279,17 +283,3 @@ def test_state_cleanup(
     except ValueError:
         pytest.fail("Mamba inner state wasn't cleaned up between states, "
                     "could be related to finished_requests_ids")
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_model_print(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        # This test is for verifying whether the model's extra_repr
-        # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
@@ -4,37 +4,52 @@ Run `pytest tests/models/test_models.py`.
 """
 import pytest
 
-from vllm.platforms import current_platform
-
 from ...utils import check_logprobs_close
 
-MODELS = [
-    "facebook/opt-125m",  # opt
-    "openai-community/gpt2",  # gpt2
-    # "Milos/slovak-gpt-j-405M",  # gptj
-    # "bigcode/tiny_starcoder_py",  # gpt_bigcode
-    # "EleutherAI/pythia-70m",  # gpt_neox
-    "bigscience/bloom-560m",  # bloom - testing alibi slopes
-    "microsoft/phi-2",  # phi
-    # "stabilityai/stablelm-3b-4e1t",  # stablelm
-    # "bigcode/starcoder2-3b",  # starcoder2
-    "google/gemma-1.1-2b-it",  # gemma
-    "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
-    "meta-llama/Llama-3.2-1B-Instruct",  # llama
-]
-
-if not current_platform.is_cpu():
-    MODELS += [
-        # fused_moe which not supported on CPU
-        "openbmb/MiniCPM3-4B",
-    ]
-
-target_dtype = "half"
-
-
-@pytest.mark.core_model
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", [target_dtype])
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param(
+            "bigscience/bloom-560m",  # bloom - testing alibi slopes
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "openai-community/gpt2",  # gpt2
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param("Milos/slovak-gpt-j-405M"),  # gptj
+        pytest.param("bigcode/tiny_starcoder_py"),  # gpt_bigcode
+        pytest.param("EleutherAI/pythia-70m"),  # gpt_neox
+        pytest.param(
+            "google/gemma-1.1-2b-it",  # gemma
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "meta-llama/Llama-3.2-1B-Instruct",  # llama
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "openbmb/MiniCPM3-4B",
+            # fused_moe not supported on CPU
+            marks=[pytest.mark.core_model],
+        ),
+        pytest.param(
+            "facebook/opt-125m",  # opt
+            marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+        ),
+        pytest.param(
+            "microsoft/phi-2",  # phi
+            marks=[pytest.mark.core_model],
+        ),
+        pytest.param(
+            "Qwen/Qwen2.5-0.5B-Instruct",  # qwen2
+            marks=[pytest.mark.core_model],
+        ),
+        pytest.param("stabilityai/stablelm-3b-4e1t"),  # stablelm
+        pytest.param("bigcode/starcoder2-3b"),  # starcoder2
+    ])
+@pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_models(
@@ -9,10 +9,14 @@ import pytest
 import torch
 from transformers import AutoModelForSequenceClassification
 
-CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]
 
-
-@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("jason9693/Qwen2.5-1.5B-apeach",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+    ],
+)
 @pytest.mark.parametrize("dtype", ["float"])
 def test_classification_models(
     hf_runner,

@@ -23,31 +27,19 @@ def test_classification_models(
 ) -> None:
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
 
     with hf_runner(model,
                    dtype=dtype,
                    auto_cls=AutoModelForSequenceClassification) as hf_model:
         hf_outputs = hf_model.classify(example_prompts)
 
-    print(hf_outputs, vllm_outputs)
-
     # check logits difference
     for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
         hf_output = torch.tensor(hf_output)
         vllm_output = torch.tensor(vllm_output)
 
         assert torch.allclose(hf_output, vllm_output, 1e-3)
-
-
-@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-def test_classification_model_print(
-    vllm_runner,
-    model: str,
-    dtype: str,
-) -> None:
-    with vllm_runner(model, dtype=dtype) as vllm_model:
-        # This test is for verifying whether the model's extra_repr
-        # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
@@ -4,25 +4,25 @@ Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 import pytest
 
-from vllm.utils import current_platform
-
 from ..utils import check_embeddings_close
 
-# Model, Guard
-MODELS = [
-    "intfloat/e5-mistral-7b-instruct",
-    "BAAI/bge-base-en-v1.5",
-    "BAAI/bge-multilingual-gemma2",
-    "intfloat/multilingual-e5-large",
-]
-
-ENCODER_ONLY = [
-    "BAAI/bge-base-en-v1.5",
-    "intfloat/multilingual-e5-large",
-]
-
-@pytest.mark.parametrize("model", MODELS)
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        # [Encoder-only]
+        pytest.param("BAAI/bge-base-en-v1.5",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+        pytest.param("intfloat/multilingual-e5-large"),
+        # [Encoder-decoder]
+        pytest.param("intfloat/e5-mistral-7b-instruct",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+        pytest.param("BAAI/bge-multilingual-gemma2",
+                     marks=[pytest.mark.core_model]),
+        pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+    ],
+)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models(
     hf_runner,

@@ -31,9 +31,6 @@ def test_models(
     model,
     dtype: str,
 ) -> None:
-    if model not in ENCODER_ONLY and current_platform.is_cpu():
-        pytest.skip("Skip large embedding models test on CPU.")
-
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:

@@ -46,8 +43,13 @@ def test_models(
                    is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
 
-    with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model:
+    with vllm_runner(model, task="embedding", dtype=dtype,
+                     max_model_len=None) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
 
     check_embeddings_close(
         embeddings_0_lst=hf_outputs,
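The final comparison goes through check_embeddings_close from ..utils. As an illustrative stand-in (not the actual helper), such a check can be thought of as a cosine-similarity comparison between the HF and vLLM vectors:

import torch
import torch.nn.functional as F

def embeddings_close(emb_a, emb_b, min_cosine: float = 0.99) -> bool:
    # Illustrative only: treat both embeddings as 1-D float vectors and
    # require their cosine similarity to be close to 1.
    a = torch.tensor(emb_a, dtype=torch.float32)
    b = torch.tensor(emb_b, dtype=torch.float32)
    return F.cosine_similarity(a, b, dim=0).item() >= min_cosine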
@@ -88,6 +88,7 @@ def _run_test(
 @pytest.mark.skipif(transformers.__version__.startswith("4.46"),
                     reason="Model broken with changes in transformers 4.46")
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_text(

@@ -112,6 +113,7 @@ def test_models_text(
 
 
 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_image(
@@ -74,6 +74,7 @@ def _run_test(
 )
 
 
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_text(

@@ -98,6 +99,7 @@ def test_models_text(
 
 
 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_image(
@@ -14,8 +14,6 @@ from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
 from ....utils import multi_gpu_test
 from ...utils import check_logprobs_close
 
-MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
-
 
 def vllm_to_hf_output(
     vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],

@@ -170,7 +168,14 @@ def run_test(
 )
 
 
-@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("facebook/bart-base",
+                     marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+        pytest.param("facebook/bart-large-cnn"),
+    ],
+)
 @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
@@ -233,6 +233,7 @@ def clear_cache():
 
 
 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "sizes",

@@ -278,6 +279,7 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
 
 
 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])

@@ -326,6 +328,7 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
 
 
 @large_gpu_test(min_gb=48)
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [128])
@@ -129,9 +129,13 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
+    "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
+    "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
     "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
+    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),  # noqa: E501
+    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
@@ -77,8 +77,8 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
 
 
 def test_hf_registry_coverage():
-    untested_archs = (HF_EXAMPLE_MODELS.get_supported_archs() -
-                      set(ModelRegistry.get_supported_archs()))
+    untested_archs = (ModelRegistry.get_supported_archs() -
+                      HF_EXAMPLE_MODELS.get_supported_archs())
 
     assert not untested_archs, (
         "Please add the following architectures to "
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -44,8 +45,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
@@ -247,6 +249,18 @@ class Qwen2Model(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        # TODO (@robertgshaw2): see if this can be moved out
+        if (cache_config.sliding_window is not None
+                and hasattr(config, "max_window_layers")):
+            raise ValueError("Sliding window for some but all layers is not "
+                             "supported. This model uses sliding window "
+                             "but `max_window_layers` = {} is less than "
+                             "`num_hidden_layers` = {}. Please open an issue "
+                             "to discuss this feature.".format(
+                                 config.max_window_layers,
+                                 config.num_hidden_layers,
+                             ))
+
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -405,20 +419,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
-        # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
-            raise ValueError("Sliding window for some but all layers is not "
-                             "supported. This model uses sliding window "
-                             "but `max_window_layers` = {} is less than "
-                             "`num_hidden_layers` = {}. Please open an issue "
-                             "to discuss this feature.".format(
-                                 config.max_window_layers,
-                                 config.num_hidden_layers,
-                             ))
+        pooler_config = vllm_config.model_config.pooler_config
 
         self.config = config
         self.lora_config = lora_config
@@ -438,6 +441,15 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
+
+        # The same model class supports both language generation and embedding
+        # because the architecture name is the same
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
+
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -475,6 +487,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
@@ -482,3 +501,70 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             if self.config.tie_word_embeddings else None),
         )
         loader.load_weights(weights)
+
+
+class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+        pooler_config = vllm_config.model_config.pooler_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix=maybe_prefix(prefix, "model"))
+
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.MEAN,
+            normalize=True,
+            softmax=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["lm_head."])
+        loader.load_weights(weights)
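Note the two pooling defaults configured above: Qwen2ForCausalLM uses PoolingType.LAST while the new Qwen2EmbeddingModel uses PoolingType.MEAN, both with normalize=True. A simplified, illustrative sketch of what those modes compute for a single sequence (the real logic lives in vllm.model_executor.layers.pooler):

import torch
import torch.nn.functional as F

def pool(hidden_states: torch.Tensor, pooling_type: str) -> torch.Tensor:
    # hidden_states: [seq_len, hidden_size] for one sequence (illustrative).
    if pooling_type == "LAST":
        pooled = hidden_states[-1]          # last token's hidden state
    elif pooling_type == "MEAN":
        pooled = hidden_states.mean(dim=0)  # average over the sequence
    else:
        raise ValueError(f"unsupported pooling type: {pooling_type}")
    return F.normalize(pooled, dim=-1)      # normalize=True in the configs above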
@@ -17,10 +17,11 @@ from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import AutoWeightsLoader, maybe_prefix
 
 
-class Qwen2ForSequenceClassification(nn.Module):
+class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

@@ -46,21 +47,9 @@ class Qwen2ForSequenceClassification(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         pooler_config = vllm_config.model_config.pooler_config
-        # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
-            raise ValueError("Sliding window for some but all layers is not "
-                             "supported. This model uses sliding window "
-                             "but `max_window_layers` = {} is less than "
-                             "`num_hidden_layers` = {}. Please open an issue "
-                             "to discuss this feature.".format(
-                                 config.max_window_layers,
-                                 config.num_hidden_layers,
-                             ))
 
         self.config = config
         self.lora_config = lora_config
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, maybe_prefix
 

@@ -32,7 +32,7 @@ class ReLU(nn.Module):
         return self.activation(input)
 
 
-class Qwen2ForRewardModel(nn.Module, SupportsPP):
+class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

@@ -58,21 +58,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         pooler_config = vllm_config.model_config.pooler_config
-        # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
-            raise ValueError("Sliding window for some but all layers is not "
-                             "supported. This model uses sliding window "
-                             "but `max_window_layers` = {} is less than "
-                             "`num_hidden_layers` = {}. Please open an issue "
-                             "to discuss this feature.".format(
-                                 config.max_window_layers,
-                                 config.num_hidden_layers,
-                             ))
 
         self.config = config
         self.lora_config = lora_config
@@ -11,7 +11,8 @@ import tempfile
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import cloudpickle
 import torch.nn as nn

@@ -110,6 +111,8 @@ _EMBEDDING_MODELS = {
     },
     "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
+    "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
+    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
     "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
     # [Multimodal]

@@ -301,8 +304,8 @@ class _ModelRegistry:
     # Keyed by model_arch
     models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
 
-    def get_supported_archs(self) -> List[str]:
-        return list(self.models.keys())
+    def get_supported_archs(self) -> AbstractSet[str]:
+        return self.models.keys()
 
     def register_model(
         self,
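With these registry entries, a checkpoint whose config.json reports Qwen2Model or Qwen2ForCausalLM resolves to an embedding-capable implementation when run as an embedding task. A quick, hedged way to confirm the architectures are registered (this assumes ModelRegistry is re-exported from the top-level vllm package):

from vllm import ModelRegistry

# get_supported_archs() now returns a set-like view (see the diff above),
# so membership tests work directly.
assert "Qwen2Model" in ModelRegistry.get_supported_archs()
assert "Qwen2ForCausalLM" in ModelRegistry.get_supported_archs()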