mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 01:42:16 +08:00
Support embedding models in V1 (#16188)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
4959915089
commit
799397ee4f
@ -12,7 +12,10 @@ def parse_args():
|
|||||||
parser = EngineArgs.add_cli_args(parser)
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
# Set example specific arguments
|
# Set example specific arguments
|
||||||
parser.set_defaults(
|
parser.set_defaults(
|
||||||
model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
|
model="intfloat/e5-mistral-7b-instruct",
|
||||||
|
task="embed",
|
||||||
|
enforce_eager=True,
|
||||||
|
max_model_len=1024,
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
|
|||||||
engine_args = EngineArgs(
|
engine_args = EngineArgs(
|
||||||
model="TIGER-Lab/VLM2Vec-Full",
|
model="TIGER-Lab/VLM2Vec-Full",
|
||||||
task="embed",
|
task="embed",
|
||||||
|
max_model_len=4096,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
mm_processor_kwargs={"num_crops": 4},
|
mm_processor_kwargs={"num_crops": 4},
|
||||||
limit_mm_per_prompt={"image": 1},
|
limit_mm_per_prompt={"image": 1},
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class TestSetting:
|
|||||||
# basic llama model
|
# basic llama model
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||||
model_args=[],
|
model_args=["--max-model-len", "2048"],
|
||||||
pp_size=2,
|
pp_size=2,
|
||||||
tp_size=2,
|
tp_size=2,
|
||||||
attn_backend="FLASHINFER",
|
attn_backend="FLASHINFER",
|
||||||
@ -41,7 +41,7 @@ class TestSetting:
|
|||||||
# llama model with quantization
|
# llama model with quantization
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||||
model_args=["--quantization", "gptq"],
|
model_args=["--quantization", "gptq", "--max-model-len", "2048"],
|
||||||
pp_size=1,
|
pp_size=1,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
attn_backend="FLASH_ATTN",
|
attn_backend="FLASH_ATTN",
|
||||||
@ -51,7 +51,7 @@ class TestSetting:
|
|||||||
# MoE model
|
# MoE model
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="ibm/PowerMoE-3b",
|
model="ibm/PowerMoE-3b",
|
||||||
model_args=[],
|
model_args=["--max-model-len", "2048"],
|
||||||
pp_size=1,
|
pp_size=1,
|
||||||
tp_size=2,
|
tp_size=2,
|
||||||
attn_backend="FLASH_ATTN",
|
attn_backend="FLASH_ATTN",
|
||||||
@ -61,23 +61,27 @@ class TestSetting:
|
|||||||
# embedding model
|
# embedding model
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="BAAI/bge-multilingual-gemma2",
|
model="BAAI/bge-multilingual-gemma2",
|
||||||
model_args=["--task", "embed", "--dtype", "bfloat16"],
|
model_args=[
|
||||||
|
"--task", "embed", "--dtype", "bfloat16", "--max-model-len",
|
||||||
|
"2048"
|
||||||
|
],
|
||||||
pp_size=1,
|
pp_size=1,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
attn_backend="FLASH_ATTN",
|
attn_backend="FLASH_ATTN",
|
||||||
method="encode",
|
method="encode",
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
),
|
),
|
||||||
# encoder-based embedding model (BERT)
|
# TODO: bert models are not supported in V1 yet
|
||||||
TestSetting(
|
# # encoder-based embedding model (BERT)
|
||||||
model="BAAI/bge-base-en-v1.5",
|
# TestSetting(
|
||||||
model_args=["--task", "embed"],
|
# model="BAAI/bge-base-en-v1.5",
|
||||||
pp_size=1,
|
# model_args=["--task", "embed"],
|
||||||
tp_size=1,
|
# pp_size=1,
|
||||||
attn_backend="XFORMERS",
|
# tp_size=1,
|
||||||
method="encode",
|
# attn_backend="XFORMERS",
|
||||||
fullgraph=True,
|
# method="encode",
|
||||||
),
|
# fullgraph=True,
|
||||||
|
# ),
|
||||||
# vision language model
|
# vision language model
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="microsoft/Phi-3.5-vision-instruct",
|
model="microsoft/Phi-3.5-vision-instruct",
|
||||||
|
|||||||
@ -145,6 +145,7 @@ def run_with_both_engines(request, monkeypatch):
|
|||||||
# Automatically runs tests twice, once with V1 and once without
|
# Automatically runs tests twice, once with V1 and once without
|
||||||
use_v1 = request.param
|
use_v1 = request.param
|
||||||
# Tests decorated with `@skip_v1` are only run without v1
|
# Tests decorated with `@skip_v1` are only run without v1
|
||||||
|
skip_v0 = request.node.get_closest_marker("skip_v0")
|
||||||
skip_v1 = request.node.get_closest_marker("skip_v1")
|
skip_v1 = request.node.get_closest_marker("skip_v1")
|
||||||
|
|
||||||
if use_v1:
|
if use_v1:
|
||||||
@ -152,6 +153,8 @@ def run_with_both_engines(request, monkeypatch):
|
|||||||
pytest.skip("Skipping test on vllm V1")
|
pytest.skip("Skipping test on vllm V1")
|
||||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
monkeypatch.setenv('VLLM_USE_V1', '1')
|
||||||
else:
|
else:
|
||||||
|
if skip_v0:
|
||||||
|
pytest.skip("Skipping test on vllm V0")
|
||||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import pytest
|
|||||||
from vllm import LLM, PoolingParams, PoolingRequestOutput
|
from vllm import LLM, PoolingParams, PoolingRequestOutput
|
||||||
from vllm.distributed import cleanup_dist_env_and_memory
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
|
|
||||||
|
from ...models.utils import check_embeddings_close
|
||||||
|
|
||||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||||
|
|
||||||
PROMPTS = [
|
PROMPTS = [
|
||||||
@ -27,6 +29,14 @@ TOKEN_IDS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def llm():
|
def llm():
|
||||||
# pytest caches the fixture so we use weakref.proxy to
|
# pytest caches the fixture so we use weakref.proxy to
|
||||||
@ -46,9 +56,15 @@ def llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
def assert_outputs_equal(o1: list[PoolingRequestOutput],
|
def assert_outputs_match(o1: list[PoolingRequestOutput],
|
||||||
o2: list[PoolingRequestOutput]):
|
o2: list[PoolingRequestOutput]):
|
||||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
check_embeddings_close(
|
||||||
|
embeddings_0_lst=[o.outputs.data for o in o1],
|
||||||
|
embeddings_1_lst=[o.outputs.data for o in o2],
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
tol=1e-2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
|||||||
|
|
||||||
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
|
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
|
||||||
pooling_params=pooling_params)
|
pooling_params=pooling_params)
|
||||||
assert_outputs_equal(v1_output, v2_output)
|
assert_outputs_match(v1_output, v2_output)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
|||||||
} for p in TOKEN_IDS],
|
} for p in TOKEN_IDS],
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
)
|
)
|
||||||
assert_outputs_equal(v1_output, v2_output)
|
assert_outputs_match(v1_output, v2_output)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
|
|||||||
@ -21,6 +21,14 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
|
|||||||
DTYPE = "bfloat16"
|
DTYPE = "bfloat16"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
args = [
|
args = [
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from tests.models.utils import check_embeddings_close
|
||||||
from vllm.entrypoints.openai.protocol import PoolingResponse
|
from vllm.entrypoints.openai.protocol import PoolingResponse
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
|||||||
np.frombuffer(base64.b64decode(data.data),
|
np.frombuffer(base64.b64decode(data.data),
|
||||||
dtype="float32").tolist())
|
dtype="float32").tolist())
|
||||||
|
|
||||||
assert responses_float.data[0].data == decoded_responses_base64_data[0]
|
check_embeddings_close(
|
||||||
assert responses_float.data[1].data == decoded_responses_base64_data[1]
|
embeddings_0_lst=[d.data for d in responses_float.data],
|
||||||
|
embeddings_1_lst=decoded_responses_base64_data,
|
||||||
|
name_0="float32",
|
||||||
|
name_1="base64")
|
||||||
|
|
||||||
# Default response is float32 decoded from base64 by OpenAI Client
|
# Default response is float32 decoded from base64 by OpenAI Client
|
||||||
default_response = requests.post(
|
default_response = requests.post(
|
||||||
@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
|||||||
default_response.raise_for_status()
|
default_response.raise_for_status()
|
||||||
responses_default = PoolingResponse.model_validate(default_response.json())
|
responses_default = PoolingResponse.model_validate(default_response.json())
|
||||||
|
|
||||||
assert responses_float.data[0].data == responses_default.data[0].data
|
check_embeddings_close(
|
||||||
assert responses_float.data[1].data == responses_default.data[1].data
|
# Compare the explicit float response with the default response; passing
# responses_default on both sides would make this check vacuous.
embeddings_0_lst=[d.data for d in responses_float.data],
|
||||||
|
embeddings_1_lst=[d.data for d in responses_default.data],
|
||||||
|
name_0="float32",
|
||||||
|
name_1="base64")
|
||||||
|
|||||||
@ -12,6 +12,14 @@ MODEL_NAME = "BAAI/bge-reranker-base"
|
|||||||
DTYPE = "bfloat16"
|
DTYPE = "bfloat16"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
|
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
|
||||||
|
|||||||
@ -11,6 +11,15 @@ from vllm.entrypoints.openai.protocol import ScoreResponse
|
|||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
{
|
{
|
||||||
"name": "BAAI/bge-reranker-v2-m3",
|
"name": "BAAI/bge-reranker-v2-m3",
|
||||||
|
|||||||
@ -6,6 +6,14 @@ from transformers import AutoModelForSequenceClassification
|
|||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
# TODO: enable when float32 is supported by V1
|
||||||
|
# @pytest.fixture(autouse=True)
|
||||||
|
# def v1(run_with_both_engines):
|
||||||
|
# # Simple autouse wrapper to run both engines for each test
|
||||||
|
# # This can be promoted up to conftest.py to run for every
|
||||||
|
# # test in a package
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model",
|
||||||
@ -29,7 +37,7 @@ def test_models(
|
|||||||
# switch to use ROCm CK FA backend
|
# switch to use ROCm CK FA backend
|
||||||
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
|
||||||
|
|
||||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
|
||||||
vllm_outputs = vllm_model.classify(example_prompts)
|
vllm_outputs = vllm_model.classify(example_prompts)
|
||||||
|
|
||||||
with hf_runner(model,
|
with hf_runner(model,
|
||||||
|
|||||||
@ -8,6 +8,14 @@ from vllm.platforms import current_platform
|
|||||||
from ...utils import check_embeddings_close
|
from ...utils import check_embeddings_close
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model",
|
||||||
[
|
[
|
||||||
@ -20,15 +28,27 @@ from ...utils import check_embeddings_close
|
|||||||
marks=[pytest.mark.core_model]),
|
marks=[pytest.mark.core_model]),
|
||||||
pytest.param("intfloat/e5-mistral-7b-instruct",
|
pytest.param("intfloat/e5-mistral-7b-instruct",
|
||||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||||
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
|
# the qwen models interfere with each other (see PR
|
||||||
|
# https://github.com/vllm-project/vllm/pull/18720).
|
||||||
|
# To avoid this problem, for now we skip v0 since it will be
|
||||||
|
# deprecated anyway.
|
||||||
|
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
|
||||||
|
marks=[pytest.mark.skip_v0]),
|
||||||
# [Encoder-only]
|
# [Encoder-only]
|
||||||
pytest.param("BAAI/bge-base-en-v1.5",
|
pytest.param("BAAI/bge-base-en-v1.5",
|
||||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
marks=[
|
||||||
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
|
pytest.mark.core_model, pytest.mark.cpu_model,
|
||||||
pytest.param("intfloat/multilingual-e5-small"),
|
pytest.mark.skip_v1
|
||||||
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
|
]),
|
||||||
|
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
|
||||||
|
marks=[pytest.mark.skip_v1]),
|
||||||
|
pytest.param("intfloat/multilingual-e5-small",
|
||||||
|
marks=[pytest.mark.skip_v1]),
|
||||||
|
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
||||||
|
marks=[pytest.mark.skip_v1]),
|
||||||
# [Cross-Encoder]
|
# [Cross-Encoder]
|
||||||
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
|
pytest.param("sentence-transformers/stsb-roberta-base-v2",
|
||||||
|
marks=[pytest.mark.skip_v1]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_models(
|
def test_models(
|
||||||
@ -62,7 +82,7 @@ def test_models(
|
|||||||
|
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
task="embed",
|
task="embed",
|
||||||
max_model_len=None,
|
max_model_len=512,
|
||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
vllm_outputs = vllm_model.encode(example_prompts)
|
vllm_outputs = vllm_model.encode(example_prompts)
|
||||||
|
|
||||||
|
|||||||
@ -265,8 +265,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
|
|
||||||
_EMBEDDING_EXAMPLE_MODELS = {
|
_EMBEDDING_EXAMPLE_MODELS = {
|
||||||
# [Text-only]
|
# [Text-only]
|
||||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
|
||||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
|
||||||
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
||||||
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
@ -279,16 +279,16 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||||
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
|
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True, v0_only=True),
|
||||||
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
|
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True, v0_only=True), # noqa: E501
|
||||||
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||||
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
||||||
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
|
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
|
||||||
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
|
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
|
||||||
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
|
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
|
||||||
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
|
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
|
||||||
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
|
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
|
||||||
# [Multimodal]
|
# [Multimodal]
|
||||||
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
|
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
|
||||||
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
|
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
|
||||||
@ -300,10 +300,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
|
|
||||||
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||||
# [Text-only]
|
# [Text-only]
|
||||||
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
|
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
|
||||||
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
|
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
|
||||||
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
|
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
|
||||||
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
|
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
|
||||||
}
|
}
|
||||||
|
|
||||||
_MULTIMODAL_EXAMPLE_MODELS = {
|
_MULTIMODAL_EXAMPLE_MODELS = {
|
||||||
|
|||||||
@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer,
|
|||||||
None,
|
None,
|
||||||
params,
|
params,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
0.0,
|
0.0,
|
||||||
None,
|
None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
|||||||
@ -43,6 +43,7 @@ def make_request(request_id,
|
|||||||
multi_modal_hashes=mm_hashes,
|
multi_modal_hashes=mm_hashes,
|
||||||
multi_modal_placeholders=mm_positions,
|
multi_modal_placeholders=mm_positions,
|
||||||
sampling_params=SamplingParams(max_tokens=17),
|
sampling_params=SamplingParams(max_tokens=17),
|
||||||
|
pooling_params=None,
|
||||||
eos_token_id=100,
|
eos_token_id=100,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=cache_salt,
|
cache_salt=cache_salt,
|
||||||
|
|||||||
@ -39,6 +39,7 @@ def make_request(request_id,
|
|||||||
multi_modal_placeholders=mm_positions,
|
multi_modal_placeholders=mm_positions,
|
||||||
sampling_params=SamplingParams(max_tokens=17,
|
sampling_params=SamplingParams(max_tokens=17,
|
||||||
prompt_logprobs=prompt_logprobs),
|
prompt_logprobs=prompt_logprobs),
|
||||||
|
pooling_params=None,
|
||||||
eos_token_id=100,
|
eos_token_id=100,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=cache_salt,
|
cache_salt=cache_salt,
|
||||||
|
|||||||
@ -135,6 +135,7 @@ def create_requests(num_requests: int,
|
|||||||
request_id=f"{i}",
|
request_id=f"{i}",
|
||||||
prompt_token_ids=[i] * num_tokens,
|
prompt_token_ids=[i] * num_tokens,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=None,
|
||||||
multi_modal_inputs=mm_inputs,
|
multi_modal_inputs=mm_inputs,
|
||||||
multi_modal_placeholders=mm_position,
|
multi_modal_placeholders=mm_position,
|
||||||
multi_modal_hashes=None,
|
multi_modal_hashes=None,
|
||||||
@ -283,6 +284,7 @@ def test_schedule_partial_requests():
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(output, model_runner_output)
|
scheduler.update_from_output(output, model_runner_output)
|
||||||
|
|
||||||
@ -333,6 +335,7 @@ def test_no_mm_input_chunking():
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(output, model_runner_output)
|
scheduler.update_from_output(output, model_runner_output)
|
||||||
|
|
||||||
@ -396,6 +399,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(output, model_runner_output)
|
scheduler.update_from_output(output, model_runner_output)
|
||||||
|
|
||||||
@ -420,6 +424,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(output1, model_runner_output)
|
scheduler.update_from_output(output1, model_runner_output)
|
||||||
output2 = scheduler.schedule()
|
output2 = scheduler.schedule()
|
||||||
@ -473,7 +478,8 @@ def test_stop_via_update_from_output():
|
|||||||
11]], # First request hits EOS, second continues
|
11]], # First request hits EOS, second continues
|
||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[])
|
||||||
|
|
||||||
scheduler.update_from_output(scheduler_output, model_output)
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
@ -523,7 +529,8 @@ def test_stop_via_update_from_output():
|
|||||||
[13, 14]], # First request hits stop token
|
[13, 14]], # First request hits stop token
|
||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[])
|
||||||
|
|
||||||
scheduler.update_from_output(scheduler_output, model_output)
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
@ -572,7 +579,8 @@ def test_stop_via_update_from_output():
|
|||||||
[13]], # First request exceeds max_tokens
|
[13]], # First request exceeds max_tokens
|
||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[])
|
||||||
|
|
||||||
scheduler.update_from_output(scheduler_output, model_output)
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
@ -614,7 +622,8 @@ def test_stop_via_update_from_output():
|
|||||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[])
|
||||||
|
|
||||||
scheduler.update_from_output(scheduler_output, model_output)
|
scheduler.update_from_output(scheduler_output, model_output)
|
||||||
|
|
||||||
@ -663,6 +672,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output0, model_runner_output)
|
scheduler.update_from_output(scheduler_output0, model_runner_output)
|
||||||
|
|
||||||
@ -680,6 +690,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
||||||
|
|
||||||
@ -730,6 +741,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
|||||||
spec_token_ids=spec_tokens,
|
spec_token_ids=spec_tokens,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
engine_core_outputs = scheduler.update_from_output(output,
|
engine_core_outputs = scheduler.update_from_output(output,
|
||||||
model_runner_output)
|
model_runner_output)
|
||||||
@ -769,6 +781,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
engine_core_outputs = scheduler.update_from_output(output,
|
engine_core_outputs = scheduler.update_from_output(output,
|
||||||
model_runner_output)
|
model_runner_output)
|
||||||
@ -896,6 +909,7 @@ def test_kv_connector_basic():
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure ScheduleOutput is correct.
|
# Ensure ScheduleOutput is correct.
|
||||||
@ -941,6 +955,7 @@ def test_kv_connector_basic():
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should get a local cache hit of NUM_TOKENS_PREFIX and
|
# We should get a local cache hit of NUM_TOKENS_PREFIX and
|
||||||
@ -1007,6 +1022,7 @@ def test_kv_connector_unable_to_allocate():
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Just one request should be running.
|
# Just one request should be running.
|
||||||
@ -1087,6 +1103,7 @@ def test_kv_connector_handles_preemption():
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# All can be scheduled - 1st token.
|
# All can be scheduled - 1st token.
|
||||||
@ -1181,6 +1198,7 @@ def make_output(scheduler: Scheduler):
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -39,6 +39,7 @@ def make_request() -> EngineCoreRequest:
|
|||||||
mm_hashes=None,
|
mm_hashes=None,
|
||||||
mm_placeholders=None,
|
mm_placeholders=None,
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
|
pooling_params=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
arrival_time=time.time(),
|
arrival_time=time.time(),
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
|
|||||||
@ -53,6 +53,7 @@ def make_request(
|
|||||||
mm_hashes=None,
|
mm_hashes=None,
|
||||||
mm_placeholders=None,
|
mm_placeholders=None,
|
||||||
sampling_params=params,
|
sampling_params=params,
|
||||||
|
pooling_params=None,
|
||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
arrival_time=time.time(),
|
arrival_time=time.time(),
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
|
|||||||
@ -33,6 +33,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
|
|||||||
None,
|
None,
|
||||||
params,
|
params,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
0.0,
|
0.0,
|
||||||
None,
|
None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
|||||||
@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
|||||||
output_kind=request_output_kind,
|
output_kind=request_output_kind,
|
||||||
stop=[],
|
stop=[],
|
||||||
include_stop_str_in_output=False,
|
include_stop_str_in_output=False,
|
||||||
))
|
),
|
||||||
|
pooling_params=None)
|
||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
|||||||
include_stop_str_in_output=False,
|
include_stop_str_in_output=False,
|
||||||
logprobs=num_sample_logprobs,
|
logprobs=num_sample_logprobs,
|
||||||
prompt_logprobs=num_prompt_logprobs,
|
prompt_logprobs=num_prompt_logprobs,
|
||||||
))
|
),
|
||||||
|
pooling_params=None)
|
||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool,
|
|||||||
logprobs=num_sample_logprobs,
|
logprobs=num_sample_logprobs,
|
||||||
prompt_logprobs=None,
|
prompt_logprobs=None,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
))
|
),
|
||||||
|
pooling_params=None)
|
||||||
|
|
||||||
# Add request to the detokenizer.
|
# Add request to the detokenizer.
|
||||||
output_processor.add_request(request, prompt_string)
|
output_processor.add_request(request, prompt_string)
|
||||||
@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool,
|
|||||||
include_stop_str_in_output=include_stop_str_in_output,
|
include_stop_str_in_output=include_stop_str_in_output,
|
||||||
logprobs=num_sample_logprobs,
|
logprobs=num_sample_logprobs,
|
||||||
prompt_logprobs=None,
|
prompt_logprobs=None,
|
||||||
))
|
),
|
||||||
|
pooling_params=None)
|
||||||
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors):
|
|||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
data_parallel_rank=None,
|
data_parallel_rank=None,
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
|
pooling_params=None,
|
||||||
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -150,6 +150,7 @@ def create_request(
|
|||||||
request_id=f"id-{request_id}",
|
request_id=f"id-{request_id}",
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=None,
|
||||||
multi_modal_inputs=None,
|
multi_modal_inputs=None,
|
||||||
multi_modal_placeholders=None,
|
multi_modal_placeholders=None,
|
||||||
multi_modal_hashes=None,
|
multi_modal_hashes=None,
|
||||||
@ -183,6 +184,7 @@ def create_model_runner_output(
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=None,
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -10,6 +10,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||||
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2):
|
|||||||
for a_i, b_i in zip(a.block_tables, b.block_tables):
|
for a_i, b_i in zip(a.block_tables, b.block_tables):
|
||||||
_compare_objs(a_i, b_i)
|
_compare_objs(a_i, b_i)
|
||||||
is_same = True
|
is_same = True
|
||||||
elif isinstance(a, (BlockTable, SamplingMetadata)):
|
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
|
||||||
_compare_objs(a, b)
|
_compare_objs(a, b)
|
||||||
is_same = True # if we make it here must be same
|
is_same = True # if we make it here must be same
|
||||||
elif a == b:
|
elif a == b:
|
||||||
@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int):
|
|||||||
req_id=f"req_id_{req_id_suffix}",
|
req_id=f"req_id_{req_id_suffix}",
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
sampling_params=_create_sampling_params(),
|
sampling_params=_create_sampling_params(),
|
||||||
|
pooling_params=None,
|
||||||
mm_inputs=[],
|
mm_inputs=[],
|
||||||
mm_positions=[],
|
mm_positions=[],
|
||||||
block_ids=([], ),
|
block_ids=([], ),
|
||||||
|
|||||||
@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
mm_hashes=[],
|
mm_hashes=[],
|
||||||
mm_positions=[],
|
mm_positions=[],
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
|
pooling_params=None,
|
||||||
block_ids=([0], ),
|
block_ids=([0], ),
|
||||||
num_computed_tokens=0,
|
num_computed_tokens=0,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
|
|||||||
@ -4496,11 +4496,31 @@ class VllmConfig:
|
|||||||
|
|
||||||
if self.compilation_config.full_cuda_graph and \
|
if self.compilation_config.full_cuda_graph and \
|
||||||
not self.model_config.disable_cascade_attn:
|
not self.model_config.disable_cascade_attn:
|
||||||
logger.warning_once(
|
logger.info("full_cuda_graph is not supported with "
|
||||||
"full_cuda_graph is not supported with "
|
"cascade attention. Disabling cascade attention.")
|
||||||
"cascade attention. Disabling cascade attention.")
|
|
||||||
self.model_config.disable_cascade_attn = True
|
self.model_config.disable_cascade_attn = True
|
||||||
|
|
||||||
|
disable_chunked_prefill_reasons: list[str] = []
|
||||||
|
|
||||||
|
if self.model_config and self.model_config.pooler_config:
|
||||||
|
pooling_type = self.model_config.pooler_config.pooling_type
|
||||||
|
if pooling_type is None or pooling_type.lower() != "last":
|
||||||
|
disable_chunked_prefill_reasons.append(
|
||||||
|
"Only \"last\" pooling supports chunked "
|
||||||
|
"prefill and prefix caching; disabling both.")
|
||||||
|
|
||||||
|
if disable_chunked_prefill_reasons:
|
||||||
|
for reason in disable_chunked_prefill_reasons:
|
||||||
|
logger.info(reason)
|
||||||
|
self.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
self.scheduler_config.long_prefill_token_threshold = 0
|
||||||
|
self.scheduler_config.max_num_batched_tokens = max(
|
||||||
|
self.scheduler_config.max_model_len,
|
||||||
|
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||||
|
|
||||||
|
if self.cache_config is not None:
|
||||||
|
self.cache_config.enable_prefix_caching = False
|
||||||
|
|
||||||
if (self.kv_events_config is not None
|
if (self.kv_events_config is not None
|
||||||
and self.kv_events_config.enable_kv_cache_events
|
and self.kv_events_config.enable_kv_cache_events
|
||||||
and not self.cache_config.enable_prefix_caching):
|
and not self.cache_config.enable_prefix_caching):
|
||||||
|
|||||||
@ -1041,7 +1041,7 @@ class EngineArgs:
|
|||||||
|
|
||||||
# Set default arguments for V0 or V1 Engine.
|
# Set default arguments for V0 or V1 Engine.
|
||||||
if use_v1:
|
if use_v1:
|
||||||
self._set_default_args_v1(usage_context)
|
self._set_default_args_v1(usage_context, model_config)
|
||||||
else:
|
else:
|
||||||
self._set_default_args_v0(model_config)
|
self._set_default_args_v0(model_config)
|
||||||
|
|
||||||
@ -1349,13 +1349,7 @@ class EngineArgs:
|
|||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# No Embedding Models so far.
|
# No Mamba or Encoder-Decoder so far.
|
||||||
if model_config.task not in ["generate"]:
|
|
||||||
_raise_or_fallback(feature_name=f"--task {model_config.task}",
|
|
||||||
recommend_to_remove=False)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# No Encoder-Decoder, not all Mamba so far.
|
|
||||||
if not model_config.is_v1_compatible:
|
if not model_config.is_v1_compatible:
|
||||||
_raise_or_fallback(feature_name=model_config.architectures,
|
_raise_or_fallback(feature_name=model_config.architectures,
|
||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
@ -1523,15 +1517,38 @@ class EngineArgs:
|
|||||||
if self.max_num_seqs is None:
|
if self.max_num_seqs is None:
|
||||||
self.max_num_seqs = 256
|
self.max_num_seqs = 256
|
||||||
|
|
||||||
def _set_default_args_v1(self, usage_context: UsageContext) -> None:
|
def _set_default_args_v1(self, usage_context: UsageContext,
|
||||||
|
model_config: ModelConfig) -> None:
|
||||||
"""Set Default Arguments for V1 Engine."""
|
"""Set Default Arguments for V1 Engine."""
|
||||||
|
|
||||||
# V1 always uses chunked prefills.
|
# V1 always uses chunked prefills and prefix caching
|
||||||
self.enable_chunked_prefill = True
|
# for non-pooling tasks.
|
||||||
|
# For pooling tasks the default is False
|
||||||
|
if model_config.runner_type != "pooling":
|
||||||
|
self.enable_chunked_prefill = True
|
||||||
|
if self.enable_prefix_caching is None:
|
||||||
|
self.enable_prefix_caching = True
|
||||||
|
else:
|
||||||
|
|
||||||
# V1 enables prefix caching by default.
|
pooling_type = model_config.pooler_config.pooling_type
|
||||||
if self.enable_prefix_caching is None:
|
|
||||||
self.enable_prefix_caching = True
|
# TODO: when encoder models are supported we'll have to
|
||||||
|
# check for causal attention here.
|
||||||
|
incremental_prefill_supported = (pooling_type is not None and
|
||||||
|
pooling_type.lower() == "last")
|
||||||
|
|
||||||
|
action = "Enabling" if \
|
||||||
|
incremental_prefill_supported else "Disabling"
|
||||||
|
|
||||||
|
if self.enable_chunked_prefill is None:
|
||||||
|
self.enable_chunked_prefill = incremental_prefill_supported
|
||||||
|
logger.info("(%s) chunked prefill by default", action)
|
||||||
|
if self.enable_prefix_caching is None:
|
||||||
|
self.enable_prefix_caching = incremental_prefill_supported
|
||||||
|
logger.info("(%s) prefix caching by default", action)
|
||||||
|
|
||||||
|
if not self.enable_chunked_prefill:
|
||||||
|
self.max_num_batched_tokens = model_config.max_model_len
|
||||||
|
|
||||||
# V1 should use the new scheduler by default.
|
# V1 should use the new scheduler by default.
|
||||||
# Swap it only if this arg is set to the original V0 default
|
# Swap it only if this arg is set to the original V0 default
|
||||||
|
|||||||
@ -1266,7 +1266,7 @@ class LLM:
|
|||||||
# the tokenizer for models such as
|
# the tokenizer for models such as
|
||||||
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
|
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
|
||||||
# lists of tokens to the `text` and `text_pair` kwargs
|
# lists of tokens to the `text` and `text_pair` kwargs
|
||||||
tokenizer = self.llm_engine.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
def ensure_str(prompt: SingletonPrompt):
|
def ensure_str(prompt: SingletonPrompt):
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt, dict):
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from typing import Final, Literal, Optional, Union, cast
|
|||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from typing_extensions import assert_never
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
@ -39,7 +40,8 @@ def _get_data(
|
|||||||
elif encoding_format == "base64":
|
elif encoding_format == "base64":
|
||||||
# Force to use float32 for base64 encoding
|
# Force to use float32 for base64 encoding
|
||||||
# to match the OpenAI python client behavior
|
# to match the OpenAI python client behavior
|
||||||
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
|
pt_float32 = output.data.to(dtype=torch.float32)
|
||||||
|
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
|
||||||
return base64.b64encode(pooling_bytes).decode("utf-8")
|
return base64.b64encode(pooling_bytes).decode("utf-8")
|
||||||
|
|
||||||
assert_never(encoding_format)
|
assert_never(encoding_format)
|
||||||
|
|||||||
@ -10,11 +10,15 @@ import torch.nn.functional as F
|
|||||||
from typing_extensions import assert_never
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from vllm.config import ModelConfig, PoolerConfig
|
from vllm.config import ModelConfig, PoolerConfig
|
||||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
||||||
PoolingTensors)
|
PoolingMetadata as V0PoolingMetadata)
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingTensors
|
||||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
get_cross_encoder_activation_function)
|
get_cross_encoder_activation_function)
|
||||||
|
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
||||||
|
|
||||||
|
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
||||||
|
|
||||||
|
|
||||||
class PoolingType(IntEnum):
|
class PoolingType(IntEnum):
|
||||||
@ -75,15 +79,18 @@ class SimplePooler(nn.Module):
|
|||||||
|
|
||||||
def get_prompt_lens(
|
def get_prompt_lens(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if isinstance(pooling_metadata, V1PoolingMetadata):
|
||||||
|
return pooling_metadata.prompt_lens
|
||||||
|
assert isinstance(hidden_states, torch.Tensor)
|
||||||
return PoolingTensors.from_pooling_metadata(
|
return PoolingTensors.from_pooling_metadata(
|
||||||
pooling_metadata, hidden_states.device).prompt_lens
|
pooling_metadata, hidden_states.device).prompt_lens
|
||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -93,7 +100,7 @@ class SimplePooler(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||||
@ -106,11 +113,19 @@ class CLSPool(SimplePooler):
|
|||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
result = []
|
||||||
|
for req_state, prompt_len in zip(hidden_states, prompt_lens):
|
||||||
|
assert prompt_len == req_state.shape[0], \
|
||||||
|
"partial prefill not supported with CLS pooling"
|
||||||
|
result.append(req_state[0])
|
||||||
|
return result
|
||||||
|
|
||||||
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
first_token_flat_indices = torch.zeros_like(prompt_lens)
|
||||||
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
|
||||||
return hidden_states[first_token_flat_indices]
|
return hidden_states[first_token_flat_indices]
|
||||||
@ -120,9 +135,12 @@ class LastPool(SimplePooler):
|
|||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
return [h[-1] for h in hidden_states]
|
||||||
|
|
||||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
|
||||||
@ -133,11 +151,17 @@ class AllPool(SimplePooler):
|
|||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
for req_state, prompt_len in zip(hidden_states, prompt_lens):
|
||||||
|
assert prompt_len == req_state.shape[0], \
|
||||||
|
"partial prefill not supported with ALL pooling"
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
pooled_data = list[torch.Tensor]()
|
pooled_data = list[torch.Tensor]()
|
||||||
for prompt_len in prompt_lens:
|
for prompt_len in prompt_lens:
|
||||||
@ -151,11 +175,20 @@ class MeanPool(SimplePooler):
|
|||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
result = []
|
||||||
|
for req_state, prompt_len in zip(hidden_states, prompt_lens):
|
||||||
|
assert prompt_len == req_state.shape[0], \
|
||||||
|
"partial prefill not supported with mean pooling"
|
||||||
|
result.append(torch.mean(req_state, dim=0,
|
||||||
|
dtype=torch.float32))
|
||||||
|
return result
|
||||||
|
|
||||||
# Use float32 for torch.cumsum in MeanPool,
|
# Use float32 for torch.cumsum in MeanPool,
|
||||||
# otherwise precision will be lost significantly.
|
# otherwise precision will be lost significantly.
|
||||||
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
|
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
|
||||||
@ -184,30 +217,53 @@ class StepPool(SimplePooler):
|
|||||||
self.step_tag_id = step_tag_id
|
self.step_tag_id = step_tag_id
|
||||||
self.returned_token_ids = returned_token_ids
|
self.returned_token_ids = returned_token_ids
|
||||||
|
|
||||||
|
def get_prompt_token_ids(
|
||||||
|
self,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
if isinstance(pooling_metadata, V1PoolingMetadata):
|
||||||
|
return [
|
||||||
|
pooling_metadata.prompt_token_ids[i, :num]
|
||||||
|
for i, num in enumerate(pooling_metadata.prompt_lens)
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
torch.tensor(seq_data_i.prompt_token_ids)
|
||||||
|
for seq_data_i in pooling_metadata.seq_data.values()
|
||||||
|
]
|
||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
|
||||||
|
|
||||||
|
pooled_data: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
for req_state, prompt_len in zip(hidden_states, prompt_lens):
|
||||||
|
assert prompt_len == req_state.shape[0], \
|
||||||
|
"partial prefill not supported with mean pooling"
|
||||||
|
pooled_data = hidden_states
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
|
for prompt_len in prompt_lens:
|
||||||
|
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||||
|
offset += prompt_len
|
||||||
|
pooled_data.append(pooled_data_i)
|
||||||
|
|
||||||
|
pooled_data = []
|
||||||
returned_token_ids = self.returned_token_ids
|
returned_token_ids = self.returned_token_ids
|
||||||
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
|
||||||
hidden_states = hidden_states[:, returned_token_ids]
|
|
||||||
|
|
||||||
step_tag_id = self.step_tag_id
|
step_tag_id = self.step_tag_id
|
||||||
|
|
||||||
offset = 0
|
for data, token_id in zip(pooled_data, prompt_token_ids):
|
||||||
pooled_data = list[torch.Tensor]()
|
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||||
for prompt_len, seq_data_i in zip(prompt_lens,
|
data = data[:, returned_token_ids]
|
||||||
pooling_metadata.seq_data.values()):
|
|
||||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
|
||||||
if step_tag_id is not None:
|
|
||||||
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
|
|
||||||
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
|
|
||||||
|
|
||||||
offset += prompt_len
|
if step_tag_id is not None:
|
||||||
pooled_data.append(pooled_data_i)
|
data = data[token_id == step_tag_id]
|
||||||
|
pooled_data.append(data)
|
||||||
|
|
||||||
return pooled_data
|
return pooled_data
|
||||||
|
|
||||||
@ -230,10 +286,17 @@ class PoolerHead(nn.Module):
|
|||||||
else:
|
else:
|
||||||
pooled_data = pooled_data.to(torch.float32)
|
pooled_data = pooled_data.to(torch.float32)
|
||||||
|
|
||||||
dimensions_list = [
|
if isinstance(pooling_metadata, V0PoolingMetadata):
|
||||||
pooling_param.dimensions
|
dimensions_list = [
|
||||||
for _, pooling_param in pooling_metadata.seq_groups
|
pooling_param.dimensions
|
||||||
]
|
for _, pooling_param in pooling_metadata.seq_groups
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert isinstance(pooled_data, list)
|
||||||
|
dimensions_list = [
|
||||||
|
pooling_param.dimensions
|
||||||
|
for pooling_param in pooling_metadata.pooling_params
|
||||||
|
]
|
||||||
if any(d is not None for d in dimensions_list):
|
if any(d is not None for d in dimensions_list):
|
||||||
# change the output dimension
|
# change the output dimension
|
||||||
assert len(pooled_data) == len(dimensions_list)
|
assert len(pooled_data) == len(dimensions_list)
|
||||||
@ -325,20 +388,41 @@ class ClassifierPooler(nn.Module):
|
|||||||
raise NotImplementedError(f"task={config.task!r} is not supported"
|
raise NotImplementedError(f"task={config.task!r} is not supported"
|
||||||
" with the classification pooler")
|
" with the classification pooler")
|
||||||
|
|
||||||
|
def get_prompt_lens(
|
||||||
|
self,
|
||||||
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if isinstance(pooling_metadata, V1PoolingMetadata):
|
||||||
|
return pooling_metadata.prompt_lens
|
||||||
|
assert isinstance(hidden_states, torch.Tensor)
|
||||||
|
return PoolingTensors.from_pooling_metadata(
|
||||||
|
pooling_metadata, hidden_states.device).prompt_lens
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
"""Pools sentence pair scores from the hidden_states."""
|
"""Pools sentence pair scores from the hidden_states."""
|
||||||
|
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
|
||||||
prompt_lens = PoolingTensors.from_pooling_metadata(
|
pooled_data = list[torch.Tensor]()
|
||||||
pooling_metadata, hidden_states.device).prompt_lens
|
if isinstance(hidden_states, list):
|
||||||
|
for req_state, prompt_len in zip(hidden_states, prompt_lens):
|
||||||
|
assert prompt_len == req_state.shape[0], \
|
||||||
|
"partial prefill not supported with classifier"
|
||||||
|
pooled_data = hidden_states
|
||||||
|
else:
|
||||||
|
offset = 0
|
||||||
|
for prompt_len in prompt_lens:
|
||||||
|
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
||||||
|
offset += prompt_len
|
||||||
|
pooled_data.append(pooled_data_i)
|
||||||
|
|
||||||
offset = 0
|
offset = 0
|
||||||
pooled_data_lst = []
|
pooled_data_lst = []
|
||||||
for prompt_len in prompt_lens:
|
for pooled_data_i in pooled_data:
|
||||||
pooled_data_i = hidden_states[offset:offset + prompt_len]
|
|
||||||
|
|
||||||
if self.pooler is not None:
|
if self.pooler is not None:
|
||||||
final_shape_tensor = self.pooler(pooled_data_i)
|
final_shape_tensor = self.pooler(pooled_data_i)
|
||||||
@ -346,7 +430,6 @@ class ClassifierPooler(nn.Module):
|
|||||||
final_shape_tensor = self.classifier(pooled_data_i)
|
final_shape_tensor = self.classifier(pooled_data_i)
|
||||||
|
|
||||||
pooled_data_lst.append(final_shape_tensor)
|
pooled_data_lst.append(final_shape_tensor)
|
||||||
offset += prompt_len
|
|
||||||
|
|
||||||
pooled_output = torch.stack(pooled_data_lst)
|
pooled_output = torch.stack(pooled_data_lst)
|
||||||
|
|
||||||
|
|||||||
@ -446,8 +446,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
|||||||
softmax=False)
|
softmax=False)
|
||||||
|
|
||||||
|
|
||||||
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||||
SupportsQuant):
|
SupportsCrossEncoding, SupportsQuant):
|
||||||
"""A model that uses Bert to provide embedding functionalities.
|
"""A model that uses Bert to provide embedding functionalities.
|
||||||
|
|
||||||
This class encapsulates the BertModel and provides an interface for
|
This class encapsulates the BertModel and provides an interface for
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding
|
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||||
from .utils import WeightsMapper, maybe_prefix
|
from .utils import WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
@ -270,7 +270,8 @@ class ModernBertPooler(nn.Module):
|
|||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||||
|
SupportsCrossEncoding):
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -375,7 +375,12 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
|
|||||||
) -> Optional[PoolerOutput]:
|
) -> Optional[PoolerOutput]:
|
||||||
hidden_states = self._pooler.extract_states(hidden_states,
|
hidden_states = self._pooler.extract_states(hidden_states,
|
||||||
pooling_metadata)
|
pooling_metadata)
|
||||||
logits, _ = self.score(hidden_states)
|
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
logits = [self.score(state)[0] for state in hidden_states]
|
||||||
|
else:
|
||||||
|
logits, _ = self.score(hidden_states)
|
||||||
|
|
||||||
pooled_data = self._pooler.head(logits, pooling_metadata)
|
pooled_data = self._pooler.head(logits, pooling_metadata)
|
||||||
pooled_outputs = [
|
pooled_outputs = [
|
||||||
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
|
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
|
from vllm.sampling_params import RequestOutputKind
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
@ -23,6 +25,7 @@ class PoolingParams(
|
|||||||
|
|
||||||
dimensions: Optional[int] = None
|
dimensions: Optional[int] = None
|
||||||
additional_data: Optional[Any] = None
|
additional_data: Optional[Any] = None
|
||||||
|
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||||
|
|
||||||
def clone(self) -> "PoolingParams":
|
def clone(self) -> "PoolingParams":
|
||||||
"""Returns a deep copy of the PoolingParams instance."""
|
"""Returns a deep copy of the PoolingParams instance."""
|
||||||
@ -52,3 +55,7 @@ class PoolingParams(
|
|||||||
return (f"PoolingParams("
|
return (f"PoolingParams("
|
||||||
f"dimensions={self.dimensions}, "
|
f"dimensions={self.dimensions}, "
|
||||||
f"additional_metadata={self.additional_data})")
|
f"additional_metadata={self.additional_data})")
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
|
||||||
|
"For pooling output_kind has to be FINAL_ONLY"
|
||||||
|
|||||||
@ -146,7 +146,8 @@ class KVCacheManager:
|
|||||||
# Prefix caching is disabled or
|
# Prefix caching is disabled or
|
||||||
# When the request requires prompt logprobs, we skip prefix caching.
|
# When the request requires prompt logprobs, we skip prefix caching.
|
||||||
if (not self.enable_caching
|
if (not self.enable_caching
|
||||||
or request.sampling_params.prompt_logprobs is not None):
|
or (request.sampling_params is not None
|
||||||
|
and request.sampling_params.prompt_logprobs is not None)):
|
||||||
return self.create_empty_block_list(), 0
|
return self.create_empty_block_list(), 0
|
||||||
|
|
||||||
# The block hashes for the request may already be computed
|
# The block hashes for the request may already be computed
|
||||||
|
|||||||
@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
|||||||
KVConnectorMetadata)
|
KVConnectorMetadata)
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
@ -26,7 +27,8 @@ class NewRequestData:
|
|||||||
mm_inputs: list[MultiModalKwargs]
|
mm_inputs: list[MultiModalKwargs]
|
||||||
mm_hashes: list[str]
|
mm_hashes: list[str]
|
||||||
mm_positions: list[PlaceholderRange]
|
mm_positions: list[PlaceholderRange]
|
||||||
sampling_params: SamplingParams
|
sampling_params: Optional[SamplingParams]
|
||||||
|
pooling_params: Optional[PoolingParams]
|
||||||
block_ids: tuple[list[int], ...]
|
block_ids: tuple[list[int], ...]
|
||||||
num_computed_tokens: int
|
num_computed_tokens: int
|
||||||
lora_request: Optional[LoRARequest]
|
lora_request: Optional[LoRARequest]
|
||||||
@ -44,6 +46,7 @@ class NewRequestData:
|
|||||||
mm_hashes=request.mm_hashes,
|
mm_hashes=request.mm_hashes,
|
||||||
mm_positions=request.mm_positions,
|
mm_positions=request.mm_positions,
|
||||||
sampling_params=request.sampling_params,
|
sampling_params=request.sampling_params,
|
||||||
|
pooling_params=request.pooling_params,
|
||||||
block_ids=block_ids,
|
block_ids=block_ids,
|
||||||
num_computed_tokens=request.num_computed_tokens,
|
num_computed_tokens=request.num_computed_tokens,
|
||||||
lora_request=request.lora_request,
|
lora_request=request.lora_request,
|
||||||
|
|||||||
@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
|
|||||||
< num_new_tokens):
|
< num_new_tokens):
|
||||||
num_new_tokens = (
|
num_new_tokens = (
|
||||||
self.scheduler_config.long_prefill_token_threshold)
|
self.scheduler_config.long_prefill_token_threshold)
|
||||||
|
|
||||||
|
# chunked prefill has to be enabled explicitly to allow
|
||||||
|
# pooling requests to be chunked
|
||||||
|
if not self.scheduler_config.chunked_prefill_enabled and \
|
||||||
|
num_new_tokens > token_budget:
|
||||||
|
self.waiting.popleft()
|
||||||
|
skipped_waiting_requests.appendleft(request)
|
||||||
|
continue
|
||||||
|
|
||||||
num_new_tokens = min(num_new_tokens, token_budget)
|
num_new_tokens = min(num_new_tokens, token_budget)
|
||||||
assert num_new_tokens > 0
|
assert num_new_tokens > 0
|
||||||
|
|
||||||
@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
logprobs = model_runner_output.logprobs
|
logprobs = model_runner_output.logprobs
|
||||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
|
pooler_outputs = model_runner_output.pooler_output
|
||||||
|
|
||||||
new_running: list[Request] = []
|
new_running: list[Request] = []
|
||||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||||
@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
req_index = model_runner_output.req_id_to_index[req_id]
|
req_index = model_runner_output.req_id_to_index[req_id]
|
||||||
generated_token_ids = sampled_token_ids[req_index]
|
generated_token_ids = sampled_token_ids[
|
||||||
|
req_index] if sampled_token_ids else []
|
||||||
|
|
||||||
scheduled_spec_token_ids = (
|
scheduled_spec_token_ids = (
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||||
@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
|
|||||||
del new_token_ids[num_new:] # Trim new tokens if needed.
|
del new_token_ids[num_new:] # Trim new tokens if needed.
|
||||||
break
|
break
|
||||||
|
|
||||||
|
pooler_output = None
|
||||||
|
if pooler_outputs:
|
||||||
|
pooler_output = pooler_outputs[req_index]
|
||||||
|
stopped = check_stop(request, self.max_model_len,
|
||||||
|
pooler_output)
|
||||||
|
if stopped:
|
||||||
|
kv_transfer_params = self._free_request(request)
|
||||||
|
|
||||||
# Extract sample logprobs if needed.
|
# Extract sample logprobs if needed.
|
||||||
if request.sampling_params.logprobs is not None and logprobs:
|
if request.sampling_params is not None \
|
||||||
|
and request.sampling_params.logprobs is not None and logprobs:
|
||||||
# NOTE: once we support N tokens per step (spec decode),
|
# NOTE: once we support N tokens per step (spec decode),
|
||||||
# the outer lists can be of length > 1.
|
# the outer lists can be of length > 1.
|
||||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||||
@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# Get prompt logprobs for this request.
|
# Get prompt logprobs for this request.
|
||||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||||
if new_token_ids or kv_transfer_params:
|
if new_token_ids or pooler_output is not None \
|
||||||
|
or kv_transfer_params:
|
||||||
|
|
||||||
# Add EngineCoreOutput for this Request.
|
# Add EngineCoreOutput for this Request.
|
||||||
outputs[request.client_index].append(
|
outputs[request.client_index].append(
|
||||||
@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
finish_reason=request.get_finished_reason(),
|
finish_reason=request.get_finished_reason(),
|
||||||
new_logprobs=new_logprobs,
|
new_logprobs=new_logprobs,
|
||||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||||
|
pooling_output=pooler_output,
|
||||||
stop_reason=request.stop_reason,
|
stop_reason=request.stop_reason,
|
||||||
events=request.take_events(),
|
events=request.take_events(),
|
||||||
kv_transfer_params=kv_transfer_params,
|
kv_transfer_params=kv_transfer_params,
|
||||||
|
|||||||
@ -1,15 +1,28 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
|
||||||
|
|
||||||
def check_stop(request: Request, max_model_len: int) -> bool:
|
def check_stop(request: Request,
|
||||||
|
max_model_len: int,
|
||||||
|
pooler_output: Optional[torch.Tensor] = None) -> bool:
|
||||||
if (request.num_tokens >= max_model_len
|
if (request.num_tokens >= max_model_len
|
||||||
or request.num_output_tokens >= request.max_tokens):
|
or request.num_output_tokens >= request.max_tokens):
|
||||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if request.pooling_params:
|
||||||
|
if pooler_output is not None:
|
||||||
|
request.status = RequestStatus.FINISHED_STOPPED
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
|
assert sampling_params is not None
|
||||||
last_token_id = request.output_token_ids[-1]
|
last_token_id = request.output_token_ids[-1]
|
||||||
if (not sampling_params.ignore_eos
|
if (not sampling_params.ignore_eos
|
||||||
and last_token_id == request.eos_token_id):
|
and last_token_id == request.eos_token_id):
|
||||||
|
|||||||
@ -7,10 +7,12 @@ from collections.abc import Sequence
|
|||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.metrics.stats import SchedulerStats
|
from vllm.v1.metrics.stats import SchedulerStats
|
||||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||||
@ -50,7 +52,8 @@ class EngineCoreRequest(
|
|||||||
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
|
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
|
||||||
mm_hashes: Optional[list[str]]
|
mm_hashes: Optional[list[str]]
|
||||||
mm_placeholders: Optional[list[PlaceholderRange]]
|
mm_placeholders: Optional[list[PlaceholderRange]]
|
||||||
sampling_params: SamplingParams
|
sampling_params: Optional[SamplingParams]
|
||||||
|
pooling_params: Optional[PoolingParams]
|
||||||
eos_token_id: Optional[int]
|
eos_token_id: Optional[int]
|
||||||
arrival_time: float
|
arrival_time: float
|
||||||
lora_request: Optional[LoRARequest]
|
lora_request: Optional[LoRARequest]
|
||||||
@ -104,6 +107,8 @@ class EngineCoreOutput(
|
|||||||
new_logprobs: Optional[LogprobsLists] = None
|
new_logprobs: Optional[LogprobsLists] = None
|
||||||
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
|
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
|
||||||
|
|
||||||
|
pooling_output: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
finish_reason: Optional[FinishReason] = None
|
finish_reason: Optional[FinishReason] = None
|
||||||
stop_reason: Union[int, str, None] = None
|
stop_reason: Union[int, str, None] = None
|
||||||
events: Optional[list[EngineCoreEvent]] = None
|
events: Optional[list[EngineCoreEvent]] = None
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
@ -228,8 +228,7 @@ class AsyncLLM(EngineClient):
|
|||||||
if self.errored:
|
if self.errored:
|
||||||
raise EngineDeadError()
|
raise EngineDeadError()
|
||||||
|
|
||||||
assert isinstance(params, SamplingParams), \
|
is_pooling = isinstance(params, PoolingParams)
|
||||||
"Pooling is not supported in V1"
|
|
||||||
|
|
||||||
# Create a new output collector for the request.
|
# Create a new output collector for the request.
|
||||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
queue = RequestOutputCollector(output_kind=params.output_kind)
|
||||||
@ -240,7 +239,7 @@ class AsyncLLM(EngineClient):
|
|||||||
tokenization_kwargs, trace_headers, prompt_adapter_request,
|
tokenization_kwargs, trace_headers, prompt_adapter_request,
|
||||||
priority, data_parallel_rank)
|
priority, data_parallel_rank)
|
||||||
|
|
||||||
if params.n == 1:
|
if is_pooling or params.n == 1:
|
||||||
await self._add_request(request, prompt_str, None, 0, queue)
|
await self._add_request(request, prompt_str, None, 0, queue)
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
@ -443,7 +442,7 @@ class AsyncLLM(EngineClient):
|
|||||||
stat_logger.record(scheduler_stats=scheduler_stats,
|
stat_logger.record(scheduler_stats=scheduler_stats,
|
||||||
iteration_stats=iteration_stats)
|
iteration_stats=iteration_stats)
|
||||||
|
|
||||||
def encode(
|
async def encode(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
prompt: PromptType,
|
||||||
pooling_params: PoolingParams,
|
pooling_params: PoolingParams,
|
||||||
@ -451,8 +450,75 @@ class AsyncLLM(EngineClient):
|
|||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
):
|
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||||
raise ValueError("Not Supported on V1 yet.")
|
"""
|
||||||
|
Main function called by the API server to kick off a request
|
||||||
|
* 1) Making an AsyncStream corresponding to the Request.
|
||||||
|
* 2) Processing the Input.
|
||||||
|
* 3) Adding the Request to the EngineCore (separate process).
|
||||||
|
|
||||||
|
A separate output_handler loop runs in a background AsyncIO task,
|
||||||
|
pulling outputs from EngineCore and putting them into the
|
||||||
|
per-request AsyncStream.
|
||||||
|
|
||||||
|
The caller of generate() iterates the returned AsyncGenerator,
|
||||||
|
returning the RequestOutput back to the caller.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# We start the output_handler on the first call to generate() so
|
||||||
|
# we can call __init__ before the event loop, which enables us
|
||||||
|
# to handle startup failure gracefully in the OpenAI server.
|
||||||
|
self._run_output_handler()
|
||||||
|
|
||||||
|
q = await self.add_request(
|
||||||
|
request_id,
|
||||||
|
prompt,
|
||||||
|
pooling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The output_handler task pushes items into the queue.
|
||||||
|
# This task pulls from the queue and yields to caller.
|
||||||
|
finished = False
|
||||||
|
while not finished:
|
||||||
|
# Note: drain queue without await if possible (avoids
|
||||||
|
# task switching under load which helps performance).
|
||||||
|
out = q.get_nowait() or await q.get()
|
||||||
|
assert isinstance(out, PoolingRequestOutput)
|
||||||
|
# Note: both OutputProcessor and EngineCore handle their
|
||||||
|
# own request cleanup based on finished.
|
||||||
|
finished = out.finished
|
||||||
|
yield out
|
||||||
|
|
||||||
|
# If the request is disconnected by the client, generate()
|
||||||
|
# is cancelled. So, we abort the request if we end up here.
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await self.abort(request_id)
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info("Request %s aborted.", request_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Engine is dead. Do not abort since we shut down.
|
||||||
|
except EngineDeadError:
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info("Request %s failed (engine dead).", request_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Request validation error.
|
||||||
|
except ValueError:
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info("Request %s failed (bad request).", request_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Unexpected error in the generate() task (possibly recoverable).
|
||||||
|
except Exception as e:
|
||||||
|
await self.abort(request_id)
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info("Request %s failed.", request_id)
|
||||||
|
raise EngineGenerateError() from e
|
||||||
|
|
||||||
async def get_vllm_config(self) -> VllmConfig:
|
async def get_vllm_config(self) -> VllmConfig:
|
||||||
return self.vllm_config
|
return self.vllm_config
|
||||||
|
|||||||
@ -60,7 +60,6 @@ class EngineCore:
|
|||||||
executor_class: type[Executor],
|
executor_class: type[Executor],
|
||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
executor_fail_callback: Optional[Callable] = None):
|
executor_fail_callback: Optional[Callable] = None):
|
||||||
assert vllm_config.model_config.runner_type != "pooling"
|
|
||||||
|
|
||||||
# plugins need to be loaded at the engine/scheduler level too
|
# plugins need to be loaded at the engine/scheduler level too
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
|
|||||||
@ -50,6 +50,8 @@ class IncrementalDetokenizer:
|
|||||||
request: EngineCoreRequest,
|
request: EngineCoreRequest,
|
||||||
) -> "IncrementalDetokenizer":
|
) -> "IncrementalDetokenizer":
|
||||||
|
|
||||||
|
assert request.sampling_params is not None
|
||||||
|
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
# No tokenizer => skipping detokenization.
|
# No tokenizer => skipping detokenization.
|
||||||
return IncrementalDetokenizer()
|
return IncrementalDetokenizer()
|
||||||
@ -70,6 +72,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
|||||||
|
|
||||||
# Stop strings
|
# Stop strings
|
||||||
params = request.sampling_params
|
params = request.sampling_params
|
||||||
|
assert params is not None
|
||||||
self.stop = stop = params.stop
|
self.stop = stop = params.stop
|
||||||
self.include_stop_str_in_output = params.include_stop_str_in_output
|
self.include_stop_str_in_output = params.include_stop_str_in_output
|
||||||
|
|
||||||
@ -164,6 +167,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
super().__init__(request)
|
super().__init__(request)
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
|
assert sampling_params is not None
|
||||||
|
|
||||||
self.request_id = request.request_id
|
self.request_id = request.request_id
|
||||||
self.skip_special_tokens = sampling_params.skip_special_tokens
|
self.skip_special_tokens = sampling_params.skip_special_tokens
|
||||||
@ -245,20 +249,20 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
super().__init__(request)
|
super().__init__(request)
|
||||||
|
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
params = request.sampling_params
|
||||||
|
assert params is not None
|
||||||
|
|
||||||
# Metadata for incremental detokenization.
|
# Metadata for incremental detokenization.
|
||||||
self.tokens, self.prefix_offset, self.read_offset = (
|
self.tokens, self.prefix_offset, self.read_offset = (
|
||||||
convert_prompt_ids_to_tokens(
|
convert_prompt_ids_to_tokens(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
prompt_ids=request.prompt_token_ids,
|
prompt_ids=request.prompt_token_ids,
|
||||||
skip_special_tokens=request.sampling_params.
|
skip_special_tokens=params.skip_special_tokens,
|
||||||
skip_special_tokens,
|
|
||||||
))
|
))
|
||||||
|
|
||||||
self.token_ids.extend(request.prompt_token_ids)
|
self.token_ids.extend(request.prompt_token_ids)
|
||||||
self.prompt_len = len(request.prompt_token_ids)
|
self.prompt_len = len(request.prompt_token_ids)
|
||||||
|
|
||||||
params = request.sampling_params
|
|
||||||
self.skip_special_tokens = params.skip_special_tokens
|
self.skip_special_tokens = params.skip_special_tokens
|
||||||
self.spaces_between_special_tokens = (
|
self.spaces_between_special_tokens = (
|
||||||
params.spaces_between_special_tokens)
|
params.spaces_between_special_tokens)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from vllm.inputs import PromptType
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
@ -221,7 +221,7 @@ class LLMEngine:
|
|||||||
# Add the request to EngineCore.
|
# Add the request to EngineCore.
|
||||||
self.engine_core.add_request(child_request)
|
self.engine_core.add_request(child_request)
|
||||||
|
|
||||||
def step(self) -> list[RequestOutput]:
|
def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
|
||||||
|
|
||||||
if self.should_execute_dummy_batch:
|
if self.should_execute_dummy_batch:
|
||||||
self.should_execute_dummy_batch = False
|
self.should_execute_dummy_batch = False
|
||||||
|
|||||||
@ -38,6 +38,7 @@ class LogprobsProcessor:
|
|||||||
tokenizer: Optional[AnyTokenizer],
|
tokenizer: Optional[AnyTokenizer],
|
||||||
request: EngineCoreRequest,
|
request: EngineCoreRequest,
|
||||||
) -> "LogprobsProcessor":
|
) -> "LogprobsProcessor":
|
||||||
|
assert request.sampling_params is not None
|
||||||
num_logprobs = request.sampling_params.logprobs
|
num_logprobs = request.sampling_params.logprobs
|
||||||
num_prompt_logprobs = request.sampling_params.prompt_logprobs
|
num_prompt_logprobs = request.sampling_params.prompt_logprobs
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@ -4,9 +4,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
import torch
|
||||||
|
|
||||||
|
from vllm.outputs import (CompletionOutput, PoolingOutput,
|
||||||
|
PoolingRequestOutput, RequestOutput)
|
||||||
from vllm.sampling_params import RequestOutputKind
|
from vllm.sampling_params import RequestOutputKind
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
@ -29,20 +32,22 @@ class RequestOutputCollector:
|
|||||||
|
|
||||||
def __init__(self, output_kind: RequestOutputKind):
|
def __init__(self, output_kind: RequestOutputKind):
|
||||||
self.aggregate = output_kind == RequestOutputKind.DELTA
|
self.aggregate = output_kind == RequestOutputKind.DELTA
|
||||||
self.output: Optional[Union[RequestOutput, Exception]] = None
|
self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
|
||||||
|
Exception]] = None
|
||||||
self.ready = asyncio.Event()
|
self.ready = asyncio.Event()
|
||||||
|
|
||||||
def put(self, output: Union[RequestOutput, Exception]) -> None:
|
def put(self, output: Union[RequestOutput, PoolingRequestOutput,
|
||||||
|
Exception]) -> None:
|
||||||
"""Non-blocking put operation."""
|
"""Non-blocking put operation."""
|
||||||
if self.output is None or isinstance(output, Exception):
|
if self.output is None or isinstance(output, Exception):
|
||||||
self.output = output
|
self.output = output
|
||||||
self.ready.set()
|
self.ready.set()
|
||||||
elif isinstance(self.output, RequestOutput):
|
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
|
||||||
# This ensures that request outputs with different request indexes
|
# This ensures that request outputs with different request indexes
|
||||||
# (if n > 1) do not override each other.
|
# (if n > 1) do not override each other.
|
||||||
self.output.add(output, aggregate=self.aggregate)
|
self.output.add(output, aggregate=self.aggregate)
|
||||||
|
|
||||||
async def get(self) -> RequestOutput:
|
async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
|
||||||
"""Get operation blocks on put event."""
|
"""Get operation blocks on put event."""
|
||||||
while (output := self.output) is None:
|
while (output := self.output) is None:
|
||||||
await self.ready.wait()
|
await self.ready.wait()
|
||||||
@ -52,7 +57,8 @@ class RequestOutputCollector:
|
|||||||
raise output
|
raise output
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_nowait(self) -> Optional[RequestOutput]:
|
def get_nowait(
|
||||||
|
self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
|
||||||
"""Non-blocking get operation."""
|
"""Non-blocking get operation."""
|
||||||
output = self.output
|
output = self.output
|
||||||
if output is not None:
|
if output is not None:
|
||||||
@ -66,7 +72,7 @@ class RequestOutputCollector:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class OutputProcessorOutput:
|
class OutputProcessorOutput:
|
||||||
|
|
||||||
request_outputs: list[RequestOutput]
|
request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
|
||||||
reqs_to_abort: list[str]
|
reqs_to_abort: list[str]
|
||||||
|
|
||||||
|
|
||||||
@ -81,8 +87,8 @@ class RequestState:
|
|||||||
output_kind: RequestOutputKind,
|
output_kind: RequestOutputKind,
|
||||||
prompt: Optional[str],
|
prompt: Optional[str],
|
||||||
prompt_token_ids: list[int],
|
prompt_token_ids: list[int],
|
||||||
logprobs_processor: LogprobsProcessor,
|
logprobs_processor: Optional[LogprobsProcessor],
|
||||||
detokenizer: IncrementalDetokenizer,
|
detokenizer: Optional[IncrementalDetokenizer],
|
||||||
max_tokens_param: Optional[int],
|
max_tokens_param: Optional[int],
|
||||||
arrival_time: float,
|
arrival_time: float,
|
||||||
queue: Optional[RequestOutputCollector],
|
queue: Optional[RequestOutputCollector],
|
||||||
@ -116,27 +122,39 @@ class RequestState:
|
|||||||
queue: Optional[RequestOutputCollector],
|
queue: Optional[RequestOutputCollector],
|
||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
) -> "RequestState":
|
) -> "RequestState":
|
||||||
if not request.sampling_params.detokenize:
|
|
||||||
tokenizer = None
|
if sampling_params := request.sampling_params:
|
||||||
|
if not sampling_params.detokenize:
|
||||||
|
tokenizer = None
|
||||||
|
output_kind = sampling_params.output_kind
|
||||||
|
logprobs_processor = LogprobsProcessor.from_new_request(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
detokenizer = IncrementalDetokenizer.from_new_request(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
max_tokens_param = sampling_params.max_tokens
|
||||||
|
else:
|
||||||
|
logprobs_processor = None
|
||||||
|
detokenizer = None
|
||||||
|
max_tokens_param = None
|
||||||
|
assert request.pooling_params is not None
|
||||||
|
output_kind = request.pooling_params.output_kind
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
parent_req=parent_req,
|
parent_req=parent_req,
|
||||||
request_index=request_index,
|
request_index=request_index,
|
||||||
lora_name=(request.lora_request.name
|
lora_name=(request.lora_request.name
|
||||||
if request.lora_request is not None else None),
|
if request.lora_request is not None else None),
|
||||||
output_kind=request.sampling_params.output_kind,
|
output_kind=output_kind,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_token_ids=request.prompt_token_ids,
|
prompt_token_ids=request.prompt_token_ids,
|
||||||
logprobs_processor=LogprobsProcessor.from_new_request(
|
logprobs_processor=logprobs_processor,
|
||||||
tokenizer=tokenizer,
|
detokenizer=detokenizer,
|
||||||
request=request,
|
max_tokens_param=max_tokens_param,
|
||||||
),
|
|
||||||
detokenizer=IncrementalDetokenizer.from_new_request(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
request=request,
|
|
||||||
),
|
|
||||||
max_tokens_param=(request.sampling_params.max_tokens if
|
|
||||||
request.sampling_params is not None else None),
|
|
||||||
arrival_time=request.arrival_time,
|
arrival_time=request.arrival_time,
|
||||||
queue=queue,
|
queue=queue,
|
||||||
log_stats=log_stats,
|
log_stats=log_stats,
|
||||||
@ -145,11 +163,12 @@ class RequestState:
|
|||||||
def make_request_output(
|
def make_request_output(
|
||||||
self,
|
self,
|
||||||
new_token_ids: list[int],
|
new_token_ids: list[int],
|
||||||
|
pooling_output: Optional[torch.Tensor],
|
||||||
finish_reason: Optional[FinishReason],
|
finish_reason: Optional[FinishReason],
|
||||||
stop_reason: Union[int, str, None],
|
stop_reason: Union[int, str, None],
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||||
num_cached_tokens: int = 0,
|
num_cached_tokens: int = 0,
|
||||||
) -> Optional[RequestOutput]:
|
) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
|
||||||
|
|
||||||
finished = finish_reason is not None
|
finished = finish_reason is not None
|
||||||
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
|
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
|
||||||
@ -158,15 +177,20 @@ class RequestState:
|
|||||||
# Only the final output is required in FINAL_ONLY mode.
|
# Only the final output is required in FINAL_ONLY mode.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
completion_output = self._new_completion_output(
|
|
||||||
new_token_ids, finish_reason, stop_reason)
|
|
||||||
|
|
||||||
request_id = self.request_id
|
request_id = self.request_id
|
||||||
|
if pooling_output is not None:
|
||||||
|
return self._new_request_output(
|
||||||
|
request_id, [self._new_pooling_output(pooling_output)],
|
||||||
|
finished)
|
||||||
|
|
||||||
|
output = self._new_completion_output(new_token_ids, finish_reason,
|
||||||
|
stop_reason)
|
||||||
|
|
||||||
if self.parent_req is None:
|
if self.parent_req is None:
|
||||||
outputs = [completion_output]
|
outputs = [output]
|
||||||
else:
|
else:
|
||||||
request_id, outputs, finished = self.parent_req.get_outputs(
|
request_id, outputs, finished = self.parent_req.get_outputs(
|
||||||
request_id, completion_output)
|
request_id, output)
|
||||||
if not outputs:
|
if not outputs:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -176,12 +200,21 @@ class RequestState:
|
|||||||
def _new_request_output(
|
def _new_request_output(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
outputs: list[CompletionOutput],
|
outputs: Union[list[CompletionOutput], list[PoolingOutput]],
|
||||||
finished: bool,
|
finished: bool,
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||||
num_cached_tokens: int = 0,
|
num_cached_tokens: int = 0,
|
||||||
) -> RequestOutput:
|
) -> Union[RequestOutput, PoolingRequestOutput]:
|
||||||
|
|
||||||
|
if isinstance(outputs[0], PoolingOutput):
|
||||||
|
assert len(outputs) == 1
|
||||||
|
return PoolingRequestOutput(
|
||||||
|
request_id=request_id,
|
||||||
|
outputs=outputs[0],
|
||||||
|
prompt_token_ids=self.prompt_token_ids,
|
||||||
|
finished=finished,
|
||||||
|
)
|
||||||
|
assert self.logprobs_processor is not None
|
||||||
if self.output_kind == RequestOutputKind.DELTA:
|
if self.output_kind == RequestOutputKind.DELTA:
|
||||||
# Side effect: logprobs processor forgets prompt logprobs
|
# Side effect: logprobs processor forgets prompt logprobs
|
||||||
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
|
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
|
||||||
@ -193,7 +226,7 @@ class RequestState:
|
|||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
prompt_token_ids=self.prompt_token_ids,
|
prompt_token_ids=self.prompt_token_ids,
|
||||||
prompt_logprobs=prompt_logprobs,
|
prompt_logprobs=prompt_logprobs,
|
||||||
outputs=outputs,
|
outputs=cast(list[CompletionOutput], outputs),
|
||||||
finished=finished,
|
finished=finished,
|
||||||
kv_transfer_params=kv_transfer_params,
|
kv_transfer_params=kv_transfer_params,
|
||||||
num_cached_tokens=num_cached_tokens,
|
num_cached_tokens=num_cached_tokens,
|
||||||
@ -206,6 +239,8 @@ class RequestState:
|
|||||||
stop_reason: Union[int, str, None],
|
stop_reason: Union[int, str, None],
|
||||||
) -> CompletionOutput:
|
) -> CompletionOutput:
|
||||||
|
|
||||||
|
assert self.detokenizer is not None
|
||||||
|
assert self.logprobs_processor is not None
|
||||||
finished = finish_reason is not None
|
finished = finish_reason is not None
|
||||||
delta = self.output_kind == RequestOutputKind.DELTA
|
delta = self.output_kind == RequestOutputKind.DELTA
|
||||||
|
|
||||||
@ -228,6 +263,13 @@ class RequestState:
|
|||||||
finish_reason=str(finish_reason) if finished else None,
|
finish_reason=str(finish_reason) if finished else None,
|
||||||
stop_reason=stop_reason if finished else None)
|
stop_reason=stop_reason if finished else None)
|
||||||
|
|
||||||
|
def _new_pooling_output(
|
||||||
|
self,
|
||||||
|
pooling_output: torch.Tensor,
|
||||||
|
) -> PoolingOutput:
|
||||||
|
|
||||||
|
return PoolingOutput(data=pooling_output)
|
||||||
|
|
||||||
|
|
||||||
class OutputProcessor:
|
class OutputProcessor:
|
||||||
"""Process EngineCoreOutputs into RequestOutputs."""
|
"""Process EngineCoreOutputs into RequestOutputs."""
|
||||||
@ -326,7 +368,8 @@ class OutputProcessor:
|
|||||||
within the loop below.
|
within the loop below.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_outputs: list[RequestOutput] = []
|
request_outputs: Union[list[RequestOutput],
|
||||||
|
list[PoolingRequestOutput]] = []
|
||||||
reqs_to_abort: list[str] = []
|
reqs_to_abort: list[str] = []
|
||||||
for engine_core_output in engine_core_outputs:
|
for engine_core_output in engine_core_outputs:
|
||||||
req_id = engine_core_output.request_id
|
req_id = engine_core_output.request_id
|
||||||
@ -341,25 +384,31 @@ class OutputProcessor:
|
|||||||
iteration_stats)
|
iteration_stats)
|
||||||
|
|
||||||
new_token_ids = engine_core_output.new_token_ids
|
new_token_ids = engine_core_output.new_token_ids
|
||||||
|
pooling_output = engine_core_output.pooling_output
|
||||||
finish_reason = engine_core_output.finish_reason
|
finish_reason = engine_core_output.finish_reason
|
||||||
stop_reason = engine_core_output.stop_reason
|
stop_reason = engine_core_output.stop_reason
|
||||||
kv_transfer_params = engine_core_output.kv_transfer_params
|
kv_transfer_params = engine_core_output.kv_transfer_params
|
||||||
num_cached_tokens = engine_core_output.num_cached_tokens
|
num_cached_tokens = engine_core_output.num_cached_tokens
|
||||||
req_state.is_prefilling = False
|
req_state.is_prefilling = False
|
||||||
|
|
||||||
# 2) Detokenize the token ids into text and perform stop checks.
|
if pooling_output is None:
|
||||||
stop_string = req_state.detokenizer.update(
|
assert req_state.detokenizer is not None
|
||||||
new_token_ids, finish_reason == FinishReason.STOP)
|
assert req_state.logprobs_processor is not None
|
||||||
if stop_string:
|
# 2) Detokenize the token ids into text and perform stop checks.
|
||||||
finish_reason = FinishReason.STOP
|
stop_string = req_state.detokenizer.update(
|
||||||
stop_reason = stop_string
|
new_token_ids, finish_reason == FinishReason.STOP)
|
||||||
|
if stop_string:
|
||||||
|
finish_reason = FinishReason.STOP
|
||||||
|
stop_reason = stop_string
|
||||||
|
|
||||||
# 3) Compute sample and prompt logprobs for request, if required.
|
# 3) Compute sample and prompt logprobs for request,
|
||||||
req_state.logprobs_processor.update_from_output(engine_core_output)
|
# if required.
|
||||||
|
req_state.logprobs_processor.update_from_output(
|
||||||
|
engine_core_output)
|
||||||
|
|
||||||
# 4) Create and handle RequestOutput objects.
|
# 4) Create and handle RequestOutput objects.
|
||||||
if request_output := req_state.make_request_output(
|
if request_output := req_state.make_request_output(
|
||||||
new_token_ids, finish_reason, stop_reason,
|
new_token_ids, pooling_output, finish_reason, stop_reason,
|
||||||
kv_transfer_params, num_cached_tokens):
|
kv_transfer_params, num_cached_tokens):
|
||||||
if req_state.queue is not None:
|
if req_state.queue is not None:
|
||||||
# AsyncLLM: put into queue for handling by generate().
|
# AsyncLLM: put into queue for handling by generate().
|
||||||
|
|||||||
@ -136,8 +136,8 @@ class Processor:
|
|||||||
Should raise ValueError if unsupported for API Server.
|
Should raise ValueError if unsupported for API Server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(params, SamplingParams):
|
if isinstance(params, PoolingParams):
|
||||||
raise ValueError("V1 does not yet support Pooling models.")
|
return
|
||||||
|
|
||||||
self._validate_logprobs(params)
|
self._validate_logprobs(params)
|
||||||
self._validate_sampling_params(params, lora_request)
|
self._validate_sampling_params(params, lora_request)
|
||||||
@ -263,18 +263,22 @@ class Processor:
|
|||||||
if encoder_inputs is not None:
|
if encoder_inputs is not None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
assert isinstance(params, SamplingParams)
|
sampling_params = None
|
||||||
# TODO: can we avoid cloning here in multiproc case?
|
pooling_params = None
|
||||||
sampling_params = params.clone()
|
if isinstance(params, SamplingParams):
|
||||||
# If unset max tokens, then generate up to the max_model_len.
|
# TODO: can we avoid cloning here in multiproc case?
|
||||||
if sampling_params.max_tokens is None:
|
sampling_params = params.clone()
|
||||||
sampling_params.max_tokens = (
|
# If unset max tokens, then generate up to the max_model_len.
|
||||||
self.model_config.max_model_len -
|
if sampling_params.max_tokens is None:
|
||||||
len(decoder_inputs["prompt_token_ids"]))
|
sampling_params.max_tokens = (
|
||||||
sampling_params.update_from_generation_config(
|
self.model_config.max_model_len -
|
||||||
self.generation_config_fields, eos_token_id)
|
len(decoder_inputs["prompt_token_ids"]))
|
||||||
sampling_params.update_from_tokenizer(
|
sampling_params.update_from_generation_config(
|
||||||
self.tokenizer.get_lora_tokenizer(lora_request))
|
self.generation_config_fields, eos_token_id)
|
||||||
|
sampling_params.update_from_tokenizer(
|
||||||
|
self.tokenizer.get_lora_tokenizer(lora_request))
|
||||||
|
else:
|
||||||
|
pooling_params = params.clone()
|
||||||
|
|
||||||
# Multimodal related.
|
# Multimodal related.
|
||||||
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
|
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
|
||||||
@ -331,6 +335,7 @@ class Processor:
|
|||||||
mm_hashes=sorted_mm_hashes,
|
mm_hashes=sorted_mm_hashes,
|
||||||
mm_placeholders=sorted_mm_positions,
|
mm_placeholders=sorted_mm_positions,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=pooling_params,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
arrival_time=arrival_time,
|
arrival_time=arrival_time,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
|
|||||||
@ -481,8 +481,9 @@ class PrometheusStatLogger(StatLoggerBase):
|
|||||||
finished_request.num_prompt_tokens)
|
finished_request.num_prompt_tokens)
|
||||||
self.histogram_num_generation_tokens_request.observe(
|
self.histogram_num_generation_tokens_request.observe(
|
||||||
finished_request.num_generation_tokens)
|
finished_request.num_generation_tokens)
|
||||||
self.histogram_max_tokens_request.observe(
|
if finished_request.max_tokens_param:
|
||||||
finished_request.max_tokens_param)
|
self.histogram_max_tokens_request.observe(
|
||||||
|
finished_request.max_tokens_param)
|
||||||
|
|
||||||
if self.gauge_lora_info is not None:
|
if self.gauge_lora_info is not None:
|
||||||
running_lora_adapters = \
|
running_lora_adapters = \
|
||||||
|
|||||||
@ -106,7 +106,6 @@ class IterationStats:
|
|||||||
|
|
||||||
self.num_generation_tokens += num_new_generation_tokens
|
self.num_generation_tokens += num_new_generation_tokens
|
||||||
if is_prefilling:
|
if is_prefilling:
|
||||||
assert num_new_generation_tokens > 0
|
|
||||||
self.num_prompt_tokens += prompt_len
|
self.num_prompt_tokens += prompt_len
|
||||||
|
|
||||||
first_token_latency = self._time_since(req_stats.arrival_time)
|
first_token_latency = self._time_since(req_stats.arrival_time)
|
||||||
|
|||||||
@ -101,6 +101,9 @@ class ModelRunnerOutput:
|
|||||||
# [prompt_len]
|
# [prompt_len]
|
||||||
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
|
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
|
||||||
|
|
||||||
|
# [num_reqs, hidden_size]
|
||||||
|
pooler_output: list[Optional[torch.Tensor]]
|
||||||
|
|
||||||
# [req_ids]
|
# [req_ids]
|
||||||
finished_sending: Optional[set[str]] = None
|
finished_sending: Optional[set[str]] = None
|
||||||
finished_recving: Optional[set[str]] = None
|
finished_recving: Optional[set[str]] = None
|
||||||
@ -112,5 +115,6 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
finished_sending=None,
|
finished_sending=None,
|
||||||
finished_recving=None)
|
finished_recving=None)
|
||||||
|
|||||||
0
vllm/v1/pool/__init__.py
Normal file
0
vllm/v1/pool/__init__.py
Normal file
16
vllm/v1/pool/metadata.py
Normal file
16
vllm/v1/pool/metadata.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PoolingMetadata:
|
||||||
|
"""Tensors for pooling."""
|
||||||
|
|
||||||
|
prompt_lens: torch.Tensor
|
||||||
|
prompt_token_ids: Optional[torch.Tensor]
|
||||||
|
pooling_params: list[PoolingParams]
|
||||||
@ -5,6 +5,7 @@ import enum
|
|||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||||
@ -25,7 +26,8 @@ class Request:
|
|||||||
multi_modal_inputs: Optional[list[MultiModalKwargs]],
|
multi_modal_inputs: Optional[list[MultiModalKwargs]],
|
||||||
multi_modal_hashes: Optional[list[str]],
|
multi_modal_hashes: Optional[list[str]],
|
||||||
multi_modal_placeholders: Optional[list[PlaceholderRange]],
|
multi_modal_placeholders: Optional[list[PlaceholderRange]],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: Optional[SamplingParams],
|
||||||
|
pooling_params: Optional[PoolingParams],
|
||||||
eos_token_id: Optional[int],
|
eos_token_id: Optional[int],
|
||||||
client_index: int = 0,
|
client_index: int = 0,
|
||||||
lora_request: Optional["LoRARequest"] = None,
|
lora_request: Optional["LoRARequest"] = None,
|
||||||
@ -35,18 +37,35 @@ class Request:
|
|||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.client_index = client_index
|
self.client_index = client_index
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
|
self.pooling_params = pooling_params
|
||||||
# Because of LoRA, the eos token id can be different for each request.
|
# Because of LoRA, the eos token id can be different for each request.
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.lora_request = lora_request
|
self.lora_request = lora_request
|
||||||
self.structured_output_request = structured_output_request
|
self.structured_output_request = structured_output_request
|
||||||
|
|
||||||
self.status = (RequestStatus.WAITING_FOR_FSM
|
self.status = RequestStatus.WAITING
|
||||||
if sampling_params.guided_decoding is not None else
|
if sampling_params and sampling_params.guided_decoding is not None:
|
||||||
RequestStatus.WAITING)
|
self.status = RequestStatus.WAITING_FOR_FSM
|
||||||
self.events: list[EngineCoreEvent] = []
|
self.events: list[EngineCoreEvent] = []
|
||||||
self.stop_reason: Union[int, str, None] = None
|
self.stop_reason: Union[int, str, None] = None
|
||||||
assert sampling_params.max_tokens is not None
|
|
||||||
self.max_tokens = sampling_params.max_tokens
|
# P/D: Connector-specific KV transfer parameters.
|
||||||
|
self.kv_transfer_params: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
|
if pooling_params is not None:
|
||||||
|
self.max_tokens = 1
|
||||||
|
elif sampling_params is not None:
|
||||||
|
assert sampling_params.max_tokens is not None
|
||||||
|
self.max_tokens = sampling_params.max_tokens
|
||||||
|
if sampling_params.guided_decoding is not None:
|
||||||
|
self.status = RequestStatus.WAITING_FOR_FSM
|
||||||
|
|
||||||
|
if sampling_params.extra_args is not None:
|
||||||
|
self.kv_transfer_params = \
|
||||||
|
sampling_params.extra_args.get("kv_transfer_params")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"sampling_params and pooling_params can't both be unset")
|
||||||
|
|
||||||
self.prompt_token_ids = prompt_token_ids
|
self.prompt_token_ids = prompt_token_ids
|
||||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||||
@ -63,11 +82,6 @@ class Request:
|
|||||||
self.num_encoder_inputs = len(self.mm_inputs)
|
self.num_encoder_inputs = len(self.mm_inputs)
|
||||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||||
|
|
||||||
# P/D: Connector-specific KV transfer parameters.
|
|
||||||
kv_params = (None if sampling_params.extra_args is None else
|
|
||||||
sampling_params.extra_args.get("kv_transfer_params"))
|
|
||||||
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
|
|
||||||
|
|
||||||
# Sanity check
|
# Sanity check
|
||||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||||
if self.mm_hashes:
|
if self.mm_hashes:
|
||||||
@ -98,10 +112,12 @@ class Request:
|
|||||||
multi_modal_hashes=request.mm_hashes,
|
multi_modal_hashes=request.mm_hashes,
|
||||||
multi_modal_placeholders=request.mm_placeholders,
|
multi_modal_placeholders=request.mm_placeholders,
|
||||||
sampling_params=request.sampling_params,
|
sampling_params=request.sampling_params,
|
||||||
|
pooling_params=request.pooling_params,
|
||||||
eos_token_id=request.eos_token_id,
|
eos_token_id=request.eos_token_id,
|
||||||
lora_request=request.lora_request,
|
lora_request=request.lora_request,
|
||||||
structured_output_request=StructuredOutputRequest(
|
structured_output_request=StructuredOutputRequest(
|
||||||
sampling_params=request.sampling_params),
|
sampling_params=request.sampling_params) \
|
||||||
|
if request.sampling_params else None,
|
||||||
cache_salt=request.cache_salt,
|
cache_salt=request.cache_salt,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,7 +157,8 @@ class Request:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def use_structured_output(self) -> bool:
|
def use_structured_output(self) -> bool:
|
||||||
return self.sampling_params.guided_decoding is not None
|
return self.sampling_params is not None and \
|
||||||
|
self.sampling_params.guided_decoding is not None
|
||||||
|
|
||||||
def record_event(
|
def record_event(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -62,13 +62,15 @@ class StructuredOutputManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert request.sampling_params.guided_decoding is not None
|
assert request.sampling_params is not None and \
|
||||||
|
request.sampling_params.guided_decoding is not None
|
||||||
|
|
||||||
# Initialize the backend the first time it is needed.
|
# Initialize the backend the first time it is needed.
|
||||||
#
|
#
|
||||||
# NOTE: We only support a single backend. We do NOT support different
|
# NOTE: We only support a single backend. We do NOT support different
|
||||||
# backends on a per-request basis in V1 (for now, anyway...).
|
# backends on a per-request basis in V1 (for now, anyway...).
|
||||||
if self.backend is None:
|
if self.backend is None:
|
||||||
|
assert request.sampling_params is not None
|
||||||
backend = request.sampling_params.guided_decoding.backend
|
backend = request.sampling_params.guided_decoding.backend
|
||||||
vocab_size = self.vllm_config.model_config.get_vocab_size()
|
vocab_size = self.vllm_config.model_config.get_vocab_size()
|
||||||
if backend == "xgrammar":
|
if backend == "xgrammar":
|
||||||
|
|||||||
@ -10,9 +10,11 @@ import torch
|
|||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.utils import swap_dict_values
|
from vllm.utils import swap_dict_values
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.utils import copy_slice
|
from vllm.v1.utils import copy_slice
|
||||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||||
@ -27,7 +29,8 @@ class CachedRequestState:
|
|||||||
prompt_token_ids: list[int]
|
prompt_token_ids: list[int]
|
||||||
mm_inputs: list[MultiModalKwargs]
|
mm_inputs: list[MultiModalKwargs]
|
||||||
mm_positions: list[PlaceholderRange]
|
mm_positions: list[PlaceholderRange]
|
||||||
sampling_params: SamplingParams
|
sampling_params: Optional[SamplingParams]
|
||||||
|
pooling_params: Optional[PoolingParams]
|
||||||
generator: Optional[torch.Generator]
|
generator: Optional[torch.Generator]
|
||||||
|
|
||||||
block_ids: tuple[list[int], ...]
|
block_ids: tuple[list[int], ...]
|
||||||
@ -226,6 +229,8 @@ class InputBatch:
|
|||||||
# This is updated each time the batch constituents change.
|
# This is updated each time the batch constituents change.
|
||||||
self.sampling_metadata = self._make_sampling_metadata()
|
self.sampling_metadata = self._make_sampling_metadata()
|
||||||
|
|
||||||
|
self.pooling_params: dict[str, PoolingParams] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def req_ids(self) -> list[str]:
|
def req_ids(self) -> list[str]:
|
||||||
# None elements should only be present transiently
|
# None elements should only be present transiently
|
||||||
@ -269,77 +274,83 @@ class InputBatch:
|
|||||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||||
self.block_table.add_row(request.block_ids, req_index)
|
self.block_table.add_row(request.block_ids, req_index)
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
if sampling_params := request.sampling_params:
|
||||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||||
# Avoid later division by zero.
|
# Avoid later division by zero.
|
||||||
self.temperature_cpu[req_index] = -1.0
|
self.temperature_cpu[req_index] = -1.0
|
||||||
self.greedy_reqs.add(req_id)
|
self.greedy_reqs.add(req_id)
|
||||||
else:
|
else:
|
||||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||||
self.random_reqs.add(req_id)
|
self.random_reqs.add(req_id)
|
||||||
|
|
||||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||||
if sampling_params.top_p < 1:
|
if sampling_params.top_p < 1:
|
||||||
self.top_p_reqs.add(req_id)
|
self.top_p_reqs.add(req_id)
|
||||||
top_k = sampling_params.top_k
|
top_k = sampling_params.top_k
|
||||||
if 0 < top_k < self.vocab_size:
|
if 0 < top_k < self.vocab_size:
|
||||||
self.top_k_reqs.add(req_id)
|
self.top_k_reqs.add(req_id)
|
||||||
else:
|
else:
|
||||||
top_k = self.vocab_size
|
top_k = self.vocab_size
|
||||||
self.top_k_cpu[req_index] = top_k
|
self.top_k_cpu[req_index] = top_k
|
||||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||||
self.frequency_penalties_cpu[
|
self.frequency_penalties_cpu[
|
||||||
req_index] = sampling_params.frequency_penalty
|
req_index] = sampling_params.frequency_penalty
|
||||||
if sampling_params.min_p > _SAMPLING_EPS:
|
if sampling_params.min_p > _SAMPLING_EPS:
|
||||||
self.min_p_reqs.add(req_id)
|
self.min_p_reqs.add(req_id)
|
||||||
if sampling_params.frequency_penalty != 0.0:
|
if sampling_params.frequency_penalty != 0.0:
|
||||||
self.frequency_penalties_reqs.add(req_id)
|
self.frequency_penalties_reqs.add(req_id)
|
||||||
self.presence_penalties_cpu[
|
self.presence_penalties_cpu[
|
||||||
req_index] = sampling_params.presence_penalty
|
req_index] = sampling_params.presence_penalty
|
||||||
if sampling_params.presence_penalty != 0.0:
|
if sampling_params.presence_penalty != 0.0:
|
||||||
self.presence_penalties_reqs.add(req_id)
|
self.presence_penalties_reqs.add(req_id)
|
||||||
self.repetition_penalties_cpu[
|
self.repetition_penalties_cpu[
|
||||||
req_index] = sampling_params.repetition_penalty
|
req_index] = sampling_params.repetition_penalty
|
||||||
if sampling_params.repetition_penalty != 1.0:
|
if sampling_params.repetition_penalty != 1.0:
|
||||||
self.repetition_penalties_reqs.add(req_id)
|
self.repetition_penalties_reqs.add(req_id)
|
||||||
if sampling_params.min_tokens:
|
if sampling_params.min_tokens:
|
||||||
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
self.min_tokens[req_index] = (
|
||||||
sampling_params.all_stop_token_ids)
|
sampling_params.min_tokens,
|
||||||
|
sampling_params.all_stop_token_ids)
|
||||||
|
|
||||||
# NOTE(woosuk): self.generators should not include the requests that
|
# NOTE(woosuk): self.generators should not include the requests that
|
||||||
# do not have their own generator.
|
# do not have their own generator.
|
||||||
if request.generator is not None:
|
if request.generator is not None:
|
||||||
self.generators[req_index] = request.generator
|
self.generators[req_index] = request.generator
|
||||||
|
|
||||||
if sampling_params.logprobs is not None:
|
if sampling_params.logprobs is not None:
|
||||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||||
if sampling_params.prompt_logprobs is not None:
|
if sampling_params.prompt_logprobs is not None:
|
||||||
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
|
self.num_prompt_logprobs[
|
||||||
if sampling_params.logit_bias is not None:
|
req_id] = sampling_params.prompt_logprobs
|
||||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
if sampling_params.logit_bias is not None:
|
||||||
|
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||||
|
|
||||||
if sampling_params.allowed_token_ids:
|
if sampling_params.allowed_token_ids:
|
||||||
self.has_allowed_token_ids.add(req_id)
|
self.has_allowed_token_ids.add(req_id)
|
||||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||||
# Lazy allocation for this tensor, which can be large.
|
# Lazy allocation for this tensor, which can be large.
|
||||||
|
# False means we don't fill with -inf.
|
||||||
|
self.allowed_token_ids_mask = torch.zeros(
|
||||||
|
self.max_num_reqs,
|
||||||
|
self.vocab_size,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device)
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||||
|
self.max_num_reqs,
|
||||||
|
self.vocab_size,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device="cpu")
|
||||||
|
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||||
# False means we don't fill with -inf.
|
# False means we don't fill with -inf.
|
||||||
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
|
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||||
self.vocab_size,
|
sampling_params.allowed_token_ids] = False
|
||||||
dtype=torch.bool,
|
|
||||||
device=self.device)
|
|
||||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
|
||||||
self.max_num_reqs,
|
|
||||||
self.vocab_size,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device="cpu")
|
|
||||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
|
||||||
# False means we don't fill with -inf.
|
|
||||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
|
||||||
sampling_params.allowed_token_ids] = False
|
|
||||||
|
|
||||||
if sampling_params.bad_words_token_ids:
|
if sampling_params.bad_words_token_ids:
|
||||||
self.bad_words_token_ids[
|
self.bad_words_token_ids[
|
||||||
req_index] = sampling_params.bad_words_token_ids
|
req_index] = sampling_params.bad_words_token_ids
|
||||||
|
else:
|
||||||
|
assert request.pooling_params is not None
|
||||||
|
self.pooling_params[req_id] = request.pooling_params
|
||||||
|
|
||||||
# Add request lora ID
|
# Add request lora ID
|
||||||
if request.lora_request:
|
if request.lora_request:
|
||||||
@ -392,6 +403,7 @@ class InputBatch:
|
|||||||
# False means we don't fill with -inf.
|
# False means we don't fill with -inf.
|
||||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||||
self.bad_words_token_ids.pop(req_index, None)
|
self.bad_words_token_ids.pop(req_index, None)
|
||||||
|
self.pooling_params.pop(req_id, None)
|
||||||
return req_index
|
return req_index
|
||||||
|
|
||||||
def swap_states(self, i1: int, i2: int) -> None:
|
def swap_states(self, i1: int, i2: int) -> None:
|
||||||
@ -602,6 +614,25 @@ class InputBatch:
|
|||||||
bad_words_token_ids=self.bad_words_token_ids,
|
bad_words_token_ids=self.bad_words_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pooling_metadata(self) -> PoolingMetadata:
|
||||||
|
if len(self.pooling_params) == 0:
|
||||||
|
pooling_params = []
|
||||||
|
else:
|
||||||
|
# Note, for now this assumes that all request in the batch
|
||||||
|
# are either sampling or pooling requests
|
||||||
|
assert len(self.req_ids) == len(self.pooling_params)
|
||||||
|
pooling_params = [
|
||||||
|
self.pooling_params[req_id] for req_id in self.req_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
return PoolingMetadata(
|
||||||
|
prompt_lens=torch.from_numpy(
|
||||||
|
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||||
|
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||||
|
pooling_params=pooling_params,
|
||||||
|
)
|
||||||
|
|
||||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||||
prompt_token_ids_cpu_tensor = torch.empty(
|
prompt_token_ids_cpu_tensor = torch.empty(
|
||||||
|
|||||||
@ -36,6 +36,7 @@ from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
@ -51,6 +52,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
|
|||||||
SlidingWindowSpec)
|
SlidingWindowSpec)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||||
ModelRunnerOutput)
|
ModelRunnerOutput)
|
||||||
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||||
from vllm.v1.sample.sampler import Sampler
|
from vllm.v1.sample.sampler import Sampler
|
||||||
@ -119,6 +121,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
cache_config.cache_dtype]
|
cache_config.cache_dtype]
|
||||||
|
|
||||||
self.is_multimodal_model = model_config.is_multimodal_model
|
self.is_multimodal_model = model_config.is_multimodal_model
|
||||||
|
self.is_pooling_model = model_config.pooler_config is not None
|
||||||
self.max_model_len = model_config.max_model_len
|
self.max_model_len = model_config.max_model_len
|
||||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||||
@ -394,7 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||||
req_id = new_req_data.req_id
|
req_id = new_req_data.req_id
|
||||||
sampling_params = new_req_data.sampling_params
|
sampling_params = new_req_data.sampling_params
|
||||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
pooling_params = new_req_data.pooling_params
|
||||||
|
if sampling_params and \
|
||||||
|
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||||
generator = torch.Generator(device=self.device)
|
generator = torch.Generator(device=self.device)
|
||||||
generator.manual_seed(sampling_params.seed)
|
generator.manual_seed(sampling_params.seed)
|
||||||
else:
|
else:
|
||||||
@ -406,6 +411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
mm_inputs=new_req_data.mm_inputs,
|
mm_inputs=new_req_data.mm_inputs,
|
||||||
mm_positions=new_req_data.mm_positions,
|
mm_positions=new_req_data.mm_positions,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=pooling_params,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
block_ids=new_req_data.block_ids,
|
block_ids=new_req_data.block_ids,
|
||||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||||
@ -563,7 +569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
) -> tuple[dict[str, Any], bool, torch.Tensor,
|
) -> tuple[dict[str, Any], bool, torch.Tensor,
|
||||||
Optional[SpecDecodeMetadata]]:
|
Optional[SpecDecodeMetadata], np.ndarray]:
|
||||||
"""
|
"""
|
||||||
:return: tuple[
|
:return: tuple[
|
||||||
attn_metadata: layer-to-attention_metadata mapping,
|
attn_metadata: layer-to-attention_metadata mapping,
|
||||||
@ -750,7 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||||
|
|
||||||
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
||||||
spec_decode_metadata)
|
spec_decode_metadata, num_scheduled_tokens)
|
||||||
|
|
||||||
def _compute_cascade_attn_prefix_len(
|
def _compute_cascade_attn_prefix_len(
|
||||||
self,
|
self,
|
||||||
@ -1197,6 +1203,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||||
|
|
||||||
|
def _pool(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
num_scheduled_tokens: int,
|
||||||
|
num_scheduled_tokens_np: np.ndarray,
|
||||||
|
finished_sending: Optional[set[str]],
|
||||||
|
finished_recving: Optional[set[str]],
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
assert self.input_batch.num_reqs ==\
|
||||||
|
len(self.input_batch.pooling_params), \
|
||||||
|
"Either all or none of the requests in" \
|
||||||
|
" a batch must be pooling request"
|
||||||
|
|
||||||
|
extracted_hidden_states = list(
|
||||||
|
torch.split(hidden_states[:num_scheduled_tokens],
|
||||||
|
num_scheduled_tokens_np.tolist()))
|
||||||
|
|
||||||
|
pooling_metadata = self.input_batch.pooling_metadata
|
||||||
|
|
||||||
|
raw_pooler_output = self.model.pooler(
|
||||||
|
hidden_states=extracted_hidden_states,
|
||||||
|
pooling_metadata=pooling_metadata)
|
||||||
|
|
||||||
|
pooler_output: list[Optional[torch.Tensor]] = []
|
||||||
|
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
|
||||||
|
for raw_output, seq_len, prompt_len in zip(
|
||||||
|
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
|
||||||
|
|
||||||
|
if seq_len == prompt_len:
|
||||||
|
pooler_output.append(raw_output.data.cpu())
|
||||||
|
else:
|
||||||
|
pooler_output.append(None)
|
||||||
|
|
||||||
|
return ModelRunnerOutput(
|
||||||
|
req_ids=self.input_batch.req_ids,
|
||||||
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
|
sampled_token_ids=[],
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=pooler_output,
|
||||||
|
finished_sending=finished_sending,
|
||||||
|
finished_recving=finished_recving,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -1214,7 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Prepare the decoder inputs.
|
# Prepare the decoder inputs.
|
||||||
(attn_metadata, attention_cuda_graphs, logits_indices,
|
(attn_metadata, attention_cuda_graphs, logits_indices,
|
||||||
spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
|
spec_decode_metadata,
|
||||||
|
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
if (self.use_cuda_graph
|
if (self.use_cuda_graph
|
||||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||||
@ -1284,7 +1336,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# compiled with full CUDA graphs, we have to skip them entirely.
|
# compiled with full CUDA graphs, we have to skip them entirely.
|
||||||
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
|
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
|
||||||
|
|
||||||
# Run the decoder.
|
# Run the model.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
@ -1326,6 +1378,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
all_gather_group=get_tp_group())
|
all_gather_group=get_tp_group())
|
||||||
logits = None
|
logits = None
|
||||||
else:
|
else:
|
||||||
|
if self.input_batch.pooling_params:
|
||||||
|
return self._pool(hidden_states, num_scheduled_tokens,
|
||||||
|
num_scheduled_tokens_np, finished_sending,
|
||||||
|
finished_recving)
|
||||||
|
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
if broadcast_pp_output:
|
if broadcast_pp_output:
|
||||||
@ -1541,6 +1598,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
spec_token_ids=spec_token_ids,
|
spec_token_ids=spec_token_ids,
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
|
pooler_output=[],
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
)
|
)
|
||||||
@ -1802,7 +1860,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
capture_attn_cudagraph: bool = False,
|
capture_attn_cudagraph: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
# Padding for DP
|
# Padding for DP
|
||||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||||
@ -1899,7 +1957,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.drafter.dummy_run(num_tokens)
|
self.drafter.dummy_run(num_tokens)
|
||||||
|
|
||||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||||
return hidden_states[logit_indices]
|
return hidden_states, hidden_states[logit_indices]
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _dummy_sampler_run(
|
def _dummy_sampler_run(
|
||||||
@ -1978,6 +2036,48 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _dummy_pooler_run(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
num_tokens = hidden_states.shape[0]
|
||||||
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
|
num_reqs = min(num_tokens, max_num_reqs)
|
||||||
|
min_tokens_per_req = num_tokens // num_reqs
|
||||||
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||||
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||||
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||||
|
assert len(num_scheduled_tokens_list) == num_reqs
|
||||||
|
|
||||||
|
hidden_states_list = list(
|
||||||
|
torch.split(hidden_states, num_scheduled_tokens_list))
|
||||||
|
|
||||||
|
req_num_tokens = num_tokens // num_reqs
|
||||||
|
|
||||||
|
dummy_metadata = PoolingMetadata(
|
||||||
|
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
|
||||||
|
device=self.device),
|
||||||
|
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device),
|
||||||
|
pooling_params=[PoolingParams()] * num_reqs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
|
||||||
|
pooling_metadata=dummy_metadata)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if 'out of memory' in str(e):
|
||||||
|
raise RuntimeError(
|
||||||
|
"CUDA out of memory occurred when warming up pooler with "
|
||||||
|
f"{num_reqs} dummy requests. Please try lowering "
|
||||||
|
"`max_num_seqs` or `gpu_memory_utilization` when "
|
||||||
|
"initializing the engine.") from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
return pooler_output
|
||||||
|
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
# Profile with multimodal encoder & encoder cache.
|
# Profile with multimodal encoder & encoder cache.
|
||||||
# TODO: handle encoder-decoder models once we support them.
|
# TODO: handle encoder-decoder models once we support them.
|
||||||
@ -2048,13 +2148,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Cache the dummy encoder outputs.
|
# Cache the dummy encoder outputs.
|
||||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||||
|
|
||||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
hidden_states, last_hidden_states \
|
||||||
|
= self._dummy_run(self.max_num_tokens)
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
sampler_output = self._dummy_sampler_run(hidden_states)
|
if self.is_pooling_model:
|
||||||
|
output = self._dummy_pooler_run(hidden_states)
|
||||||
|
else:
|
||||||
|
output = self._dummy_sampler_run(last_hidden_states)
|
||||||
else:
|
else:
|
||||||
sampler_output = None
|
output = None
|
||||||
self._sync_device()
|
self._sync_device()
|
||||||
del hidden_states, sampler_output
|
del hidden_states, output
|
||||||
self.encoder_cache.clear()
|
self.encoder_cache.clear()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|||||||
@ -273,9 +273,14 @@ class Worker(WorkerBase):
|
|||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
max_num_reqs = min(self.scheduler_config.max_num_seqs,
|
||||||
self.scheduler_config.max_num_batched_tokens)
|
self.scheduler_config.max_num_batched_tokens)
|
||||||
self.model_runner._dummy_sampler_run(
|
|
||||||
hidden_states=self.model_runner._dummy_run(
|
hidden_states, last_hidden_states = \
|
||||||
num_tokens=max_num_reqs))
|
self.model_runner._dummy_run(num_tokens=max_num_reqs)
|
||||||
|
if self.model_runner.is_pooling_model:
|
||||||
|
self.model_runner._dummy_pooler_run(hidden_states)
|
||||||
|
else:
|
||||||
|
self.model_runner._dummy_sampler_run(
|
||||||
|
hidden_states=last_hidden_states)
|
||||||
|
|
||||||
# Reset the seed to ensure that the random state is not affected by
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
# the model initialization and profiling.
|
# the model initialization and profiling.
|
||||||
|
|||||||
@ -231,6 +231,7 @@ class InputBatch:
|
|||||||
self.block_table.add_row(request.block_ids, req_index)
|
self.block_table.add_row(request.block_ids, req_index)
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
|
assert sampling_params is not None, "pooling requests not supported yet"
|
||||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||||
# Avoid later division by zero.
|
# Avoid later division by zero.
|
||||||
self.temperature_cpu[req_index] = -1.0
|
self.temperature_cpu[req_index] = -1.0
|
||||||
|
|||||||
@ -386,6 +386,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
req_ids_to_add: list[str] = []
|
req_ids_to_add: list[str] = []
|
||||||
# Add new requests to the cached states.
|
# Add new requests to the cached states.
|
||||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||||
|
assert new_req_data.sampling_params is not None,\
|
||||||
|
"Pooling is not supported in TPU yet"
|
||||||
req_id = new_req_data.req_id
|
req_id = new_req_data.req_id
|
||||||
sampling_params = new_req_data.sampling_params
|
sampling_params = new_req_data.sampling_params
|
||||||
|
|
||||||
@ -395,6 +397,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
mm_inputs=new_req_data.mm_inputs,
|
mm_inputs=new_req_data.mm_inputs,
|
||||||
mm_positions=new_req_data.mm_positions,
|
mm_positions=new_req_data.mm_positions,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=None,
|
||||||
generator=None,
|
generator=None,
|
||||||
block_ids=new_req_data.block_ids,
|
block_ids=new_req_data.block_ids,
|
||||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||||
@ -956,6 +959,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
spec_token_ids=None,
|
spec_token_ids=None,
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check there are no new graphs compiled - all the graphs should be
|
# Check there are no new graphs compiled - all the graphs should be
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user