Support embedding models in V1 (#16188)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>

This commit is contained in:
parent 4959915089
commit 799397ee4f
@@ -12,7 +12,10 @@ def parse_args():
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
model="intfloat/e5-mistral-7b-instruct",
task="embed",
enforce_eager=True,
max_model_len=1024,
)
return parser.parse_args()
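An illustrative sketch, not taken from this diff, of how the example's new defaults might be exercised offline; it assumes the `LLM.embed` helper and that the `intfloat/e5-mistral-7b-instruct` checkpoint is available:

    from vllm import LLM

    # Mirror the defaults set above: embedding task, eager execution,
    # and a short context to keep memory usage modest.
    llm = LLM(
        model="intfloat/e5-mistral-7b-instruct",
        task="embed",
        enforce_eager=True,
        max_model_len=1024,
    )

    outputs = llm.embed(["Hello, my name is", "The capital of France is"])
    for output in outputs:
        print(len(output.outputs.embedding))  # embedding dimensionality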
@@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
engine_args = EngineArgs(
model="TIGER-Lab/VLM2Vec-Full",
task="embed",
max_model_len=4096,
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
limit_mm_per_prompt={"image": 1},
@@ -31,7 +31,7 @@ class TestSetting:
# basic llama model
TestSetting(
model="meta-llama/Llama-3.2-1B-Instruct",
model_args=[],
model_args=["--max-model-len", "2048"],
pp_size=2,
tp_size=2,
attn_backend="FLASHINFER",

@@ -41,7 +41,7 @@ class TestSetting:
# llama model with quantization
TestSetting(
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
model_args=["--quantization", "gptq"],
model_args=["--quantization", "gptq", "--max-model-len", "2048"],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",

@@ -51,7 +51,7 @@ class TestSetting:
# MoE model
TestSetting(
model="ibm/PowerMoE-3b",
model_args=[],
model_args=["--max-model-len", "2048"],
pp_size=1,
tp_size=2,
attn_backend="FLASH_ATTN",

@@ -61,23 +61,27 @@ class TestSetting:
# embedding model
TestSetting(
model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embed", "--dtype", "bfloat16"],
model_args=[
"--task", "embed", "--dtype", "bfloat16", "--max-model-len",
"2048"
],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",
method="encode",
fullgraph=True,
),
# encoder-based embedding model (BERT)
TestSetting(
model="BAAI/bge-base-en-v1.5",
model_args=["--task", "embed"],
pp_size=1,
tp_size=1,
attn_backend="XFORMERS",
method="encode",
fullgraph=True,
),
# TODO: bert models are not supported in V1 yet
# # encoder-based embedding model (BERT)
# TestSetting(
# model="BAAI/bge-base-en-v1.5",
# model_args=["--task", "embed"],
# pp_size=1,
# tp_size=1,
# attn_backend="XFORMERS",
# method="encode",
# fullgraph=True,
# ),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
@@ -145,6 +145,7 @@ def run_with_both_engines(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v0 = request.node.get_closest_marker("skip_v0")
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:

@@ -152,6 +153,8 @@ def run_with_both_engines(request, monkeypatch):
pytest.skip("Skipping test on vllm V1")
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
if skip_v0:
pytest.skip("Skipping test on vllm V0")
monkeypatch.setenv('VLLM_USE_V1', '0')

yield
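A sketch of how a test module opts into this dual-engine run (illustrative; it simply mirrors the fixture and markers shown above):

    import pytest

    @pytest.fixture(autouse=True)
    def v1(run_with_both_engines):
        # Runs every test in this module twice: once with VLLM_USE_V1=1
        # and once with VLLM_USE_V1=0, as driven by the conftest fixture.
        pass

    @pytest.mark.skip_v1
    def test_only_runs_on_v0():
        ...

    @pytest.mark.skip_v0
    def test_only_runs_on_v1():
        ...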
@@ -8,6 +8,8 @@ import pytest
from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory

from ...models.utils import check_embeddings_close

MODEL_NAME = "intfloat/multilingual-e5-small"

PROMPTS = [

@@ -27,6 +29,14 @@ TOKEN_IDS = [
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to

@@ -46,9 +56,15 @@ def llm():
cleanup_dist_env_and_memory()


def assert_outputs_equal(o1: list[PoolingRequestOutput],
def assert_outputs_match(o1: list[PoolingRequestOutput],
o2: list[PoolingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
check_embeddings_close(
embeddings_0_lst=[o.outputs.data for o in o1],
embeddings_1_lst=[o.outputs.data for o in o2],
name_0="hf",
name_1="vllm",
tol=1e-2,
)


@pytest.mark.skip_global_cleanup

@@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,

v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
assert_outputs_match(v1_output, v2_output)


@pytest.mark.skip_global_cleanup

@@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
} for p in TOKEN_IDS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
assert_outputs_match(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@@ -21,6 +21,14 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
DTYPE = "bfloat16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def server():
args = [
@@ -7,6 +7,7 @@ import numpy as np
import pytest
import requests

from tests.models.utils import check_embeddings_close
from vllm.entrypoints.openai.protocol import PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
np.frombuffer(base64.b64decode(data.data),
dtype="float32").tolist())

assert responses_float.data[0].data == decoded_responses_base64_data[0]
assert responses_float.data[1].data == decoded_responses_base64_data[1]
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_float.data],
embeddings_1_lst=decoded_responses_base64_data,
name_0="float32",
name_1="base64")

# Default response is float32 decoded from base64 by OpenAI Client
default_response = requests.post(

@@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
default_response.raise_for_status()
responses_default = PoolingResponse.model_validate(default_response.json())

assert responses_float.data[0].data == responses_default.data[0].data
assert responses_float.data[1].data == responses_default.data[1].data
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_default.data],
embeddings_1_lst=[d.data for d in responses_default.data],
name_0="float32",
name_1="base64")
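For reference, a standalone sketch of the base64 decoding step this test relies on; the endpoint path and request payload here are assumptions for illustration, not taken from the diff:

    import base64
    import numpy as np
    import requests

    # When encoding_format="base64" is requested, each embedding comes back
    # as a base64 string of raw float32 bytes.
    resp = requests.post(
        "http://localhost:8000/pooling",
        json={
            "model": "intfloat/multilingual-e5-small",
            "input": ["Hello world"],
            "encoding_format": "base64",
        },
    )
    resp.raise_for_status()
    for item in resp.json()["data"]:
        vector = np.frombuffer(base64.b64decode(item["data"]), dtype="float32")
        print(vector.shape)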
@@ -12,6 +12,14 @@ MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
@@ -11,6 +11,15 @@ from vllm.entrypoints.openai.protocol import ScoreResponse

from ...utils import RemoteOpenAIServer


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",
@@ -6,6 +6,14 @@ from transformers import AutoModelForSequenceClassification

from vllm.platforms import current_platform

# TODO: enable when float32 is supported by V1
# @pytest.fixture(autouse=True)
# def v1(run_with_both_engines):
# # Simple autouse wrapper to run both engines for each test
# # This can be promoted up to conftest.py to run for every
# # test in a package
# pass


@pytest.mark.parametrize(
"model",

@@ -29,7 +37,7 @@ def test_models(
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)

with hf_runner(model,
@@ -8,6 +8,14 @@ from vllm.platforms import current_platform
from ...utils import check_embeddings_close


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize(
"model",
[

@@ -20,15 +28,27 @@ from ...utils import check_embeddings_close
marks=[pytest.mark.core_model]),
pytest.param("intfloat/e5-mistral-7b-instruct",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
# the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
# deprecated anyway.
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0]),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
marks=[
pytest.mark.core_model, pytest.mark.cpu_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
],
)
def test_models(

@@ -62,7 +82,7 @@ def test_models(

with vllm_runner(model,
task="embed",
max_model_len=None,
max_model_len=512,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
@@ -265,8 +265,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {

_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),

@@ -279,16 +279,16 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
trust_remote_code=True, v0_only=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True),
trust_remote_code=True, v0_only=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",

@@ -300,10 +300,10 @@ _EMBEDDING_EXAMPLE_MODELS = {

_CROSS_ENCODER_EXAMPLE_MODELS = {
# [Text-only]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
}

_MULTIMODAL_EXAMPLE_MODELS = {
@@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer,
None,
params,
None,
None,
0.0,
None,
cache_salt=None,

@@ -43,6 +43,7 @@ def make_request(request_id,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,

@@ -39,6 +39,7 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,

@@ -135,6 +135,7 @@ def create_requests(num_requests: int,
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,

@@ -283,6 +284,7 @@ def test_schedule_partial_requests():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)

@@ -333,6 +335,7 @@ def test_no_mm_input_chunking():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)

@@ -396,6 +399,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)

@@ -420,6 +424,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()

@@ -473,7 +478,8 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])

scheduler.update_from_output(scheduler_output, model_output)

@@ -523,7 +529,8 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])

scheduler.update_from_output(scheduler_output, model_output)

@@ -572,7 +579,8 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])

scheduler.update_from_output(scheduler_output, model_output)

@@ -614,7 +622,8 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])

scheduler.update_from_output(scheduler_output, model_output)

@@ -663,6 +672,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output)

@@ -680,6 +690,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output)

@@ -730,6 +741,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)

@@ -769,6 +781,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)

@@ -896,6 +909,7 @@ def test_kv_connector_basic():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

# Ensure ScheduleOutput is correct.

@@ -941,6 +955,7 @@ def test_kv_connector_basic():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

# We should get a local cache hit of NUM_TOKENS_PREFIX and

@@ -1007,6 +1022,7 @@ def test_kv_connector_unable_to_allocate():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

# Just one request should be running.

@@ -1087,6 +1103,7 @@ def test_kv_connector_handles_preemption():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

# All can be scheduled - 1st token.

@@ -1181,6 +1198,7 @@ def make_output(scheduler: Scheduler):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)


@@ -39,6 +39,7 @@ def make_request() -> EngineCoreRequest:
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,

@@ -53,6 +53,7 @@ def make_request(
mm_hashes=None,
mm_placeholders=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,

@@ -33,6 +33,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
None,
params,
None,
None,
0.0,
None,
cache_salt=None,

@@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

@@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output=False,
logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

@@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
))
),
pooling_params=None)

# Add request to the detokenizer.
output_processor.add_request(request, prompt_string)

@@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

@@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors):
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
pooling_params=None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]


@@ -150,6 +150,7 @@ def create_request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,

@@ -183,6 +184,7 @@ def create_model_runner_output(
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending,
finished_recving=finished_recving,
)

@@ -10,6 +10,7 @@ import torch

from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:

@@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int):
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[],
mm_positions=[],
block_ids=([], ),

@@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,
@@ -4496,11 +4496,31 @@ class VllmConfig:

if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
logger.info("full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True

disable_chunked_prefill_reasons: list[str] = []

if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")

if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.long_prefill_token_threshold = 0
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)

if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False

if (self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events
and not self.cache_config.enable_prefix_caching):
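The effect of the new check can be summarized with a small sketch (illustrative only, not the actual config code): only last-token pooling can produce a correct result when the prompt is processed incrementally, so everything else forces chunked prefill and prefix caching off.

    from typing import Optional

    def supports_incremental_prefill(pooling_type: Optional[str]) -> bool:
        # Mirrors the condition above: anything other than "LAST" pooling
        # needs the full prompt in one pass, so chunked prefill and prefix
        # caching are disabled for it.
        return pooling_type is not None and pooling_type.lower() == "last"

    assert supports_incremental_prefill("LAST")
    assert not supports_incremental_prefill("CLS")
    assert not supports_incremental_prefill(None)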
@@ -1041,7 +1041,7 @@ class EngineArgs:

# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context)
self._set_default_args_v1(usage_context, model_config)
else:
self._set_default_args_v0(model_config)

@@ -1349,13 +1349,7 @@ class EngineArgs:
recommend_to_remove=False)
return False

# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
recommend_to_remove=False)
return False

# No Encoder-Decoder, not all Mamba so far.
# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
recommend_to_remove=False)

@@ -1523,15 +1517,38 @@ class EngineArgs:
if self.max_num_seqs is None:
self.max_num_seqs = 256

def _set_default_args_v1(self, usage_context: UsageContext) -> None:
def _set_default_args_v1(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""

# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
# V1 always uses chunked prefills and prefix caching
# for non-pooling tasks.
# For pooling tasks the default is False
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:

# V1 enables prefix caching by default.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
pooling_type = model_config.pooler_config.pooling_type

# TODO: when encoder models are supported we'll have to
# check for causal attention here.
incremental_prefill_supported = (pooling_type is not None and
pooling_type.lower() == "last")

action = "Enabling" if \
incremental_prefill_supported else "Disabling"

if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = incremental_prefill_supported
logger.info("(%s) prefix caching by default", action)

if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len

# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
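A sketch of how these defaults resolve for pooling models; the helper below is hypothetical and only mirrors the branch above, assuming no other overrides:

    from typing import Optional, Tuple

    def v1_pooling_defaults(pooling_type: Optional[str],
                            chunked_prefill: Optional[bool],
                            prefix_caching: Optional[bool]) -> Tuple[bool, bool]:
        # Pooling models only default these features on when the pooler can
        # work incrementally (last-token pooling); explicit CLI values win.
        supported = pooling_type is not None and pooling_type.lower() == "last"
        if chunked_prefill is None:
            chunked_prefill = supported
        if prefix_caching is None:
            prefix_caching = supported
        return chunked_prefill, prefix_caching

    assert v1_pooling_defaults("LAST", None, None) == (True, True)
    assert v1_pooling_defaults("CLS", None, None) == (False, False)
    assert v1_pooling_defaults("CLS", True, None) == (True, False)  # explicit opt-in wins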
@@ -1266,7 +1266,7 @@ class LLM:
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer()
tokenizer = self.get_tokenizer()

def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):

@@ -9,6 +9,7 @@ from typing import Final, Literal, Optional, Union, cast

import jinja2
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never

@@ -39,7 +40,8 @@ def _get_data(
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
pt_float32 = output.data.to(dtype=torch.float32)
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8")

assert_never(encoding_format)
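The extra cast is needed because NumPy has no bfloat16 dtype, so a bfloat16 pooling tensor cannot be converted directly; a minimal standalone sketch of the same step (stand-in tensor, not the serving code):

    import numpy as np
    import torch

    output_data = torch.randn(8, dtype=torch.bfloat16)  # stand-in for output.data

    # np.array(output_data, dtype="float32") raises for bfloat16 tensors, so
    # convert on the torch side first, then serialize as raw float32 bytes.
    pooling_bytes = np.array(output_data.to(dtype=torch.float32),
                             dtype="float32").tobytes()
    assert len(pooling_bytes) == 8 * 4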
@@ -10,11 +10,15 @@ import torch.nn.functional as F
from typing_extensions import assert_never

from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]


class PoolingType(IntEnum):

@@ -75,15 +79,18 @@ class SimplePooler(nn.Module):

def get_prompt_lens(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens

def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError

@@ -93,7 +100,7 @@ class SimplePooler(nn.Module):

def forward(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)

@@ -106,11 +113,19 @@ class CLSPool(SimplePooler):

def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with CLS pooling"
result.append(req_state[0])
return result

first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices]

@@ -120,9 +135,12 @@ class LastPool(SimplePooler):

def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
if isinstance(hidden_states, list):
return [h[-1] for h in hidden_states]

prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
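In V1 the pooler may receive per-request hidden states as a list instead of one flattened tensor; a minimal sketch of the two LastPool paths with illustrative shapes:

    import torch

    prompt_lens = torch.tensor([3, 2])
    flat = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)  # 5 tokens, hidden=4
    per_request = [flat[:3], flat[3:]]

    # List input (V1): take the last hidden state of each request directly.
    from_list = [h[-1] for h in per_request]

    # Flattened input (V0): index the last token of each prompt.
    last_idx = torch.cumsum(prompt_lens, dim=0) - 1
    from_flat = flat[last_idx]

    assert torch.equal(torch.stack(from_list), from_flat)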
@@ -133,11 +151,17 @@ class AllPool(SimplePooler):

def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with ALL pooling"
return hidden_states

offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens:

@@ -151,11 +175,20 @@ class MeanPool(SimplePooler):

def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
result.append(torch.mean(req_state, dim=0,
dtype=torch.float32))
return result

# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

@@ -184,30 +217,53 @@ class StepPool(SimplePooler):
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids

def get_prompt_token_ids(
self,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]

def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)

pooled_data: list[torch.Tensor] = []

if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)

pooled_data = []
returned_token_ids = self.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
hidden_states = hidden_states[:, returned_token_ids]

step_tag_id = self.step_tag_id

offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len, seq_data_i in zip(prompt_lens,
pooling_metadata.seq_data.values()):
pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
for data, token_id in zip(pooled_data, prompt_token_ids):
if returned_token_ids is not None and len(returned_token_ids) > 0:
data = data[:, returned_token_ids]

offset += prompt_len
pooled_data.append(pooled_data_i)
if step_tag_id is not None:
data = data[token_id == step_tag_id]
pooled_data.append(data)

return pooled_data

@@ -230,10 +286,17 @@ class PoolerHead(nn.Module):
else:
pooled_data = pooled_data.to(torch.float32)

dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
if isinstance(pooling_metadata, V0PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
else:
assert isinstance(pooled_data, list)
dimensions_list = [
pooling_param.dimensions
for pooling_param in pooling_metadata.pooling_params
]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)

@@ -325,20 +388,41 @@ class ClassifierPooler(nn.Module):
raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler")

def get_prompt_lens(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens

def forward(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)

prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
pooled_data = list[torch.Tensor]()
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with classifier"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)

offset = 0
pooled_data_lst = []
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
for pooled_data_i in pooled_data:

if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i)

@@ -346,7 +430,6 @@ class ClassifierPooler(nn.Module):
final_shape_tensor = self.classifier(pooled_data_i)

pooled_data_lst.append(final_shape_tensor)
offset += prompt_len

pooled_output = torch.stack(pooled_data_lst)
@@ -446,8 +446,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
softmax=False)


class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for

@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix


@@ -270,7 +270,8 @@ class ModernBertPooler(nn.Module):
return pooled_output


class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

@@ -375,7 +375,12 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
) -> Optional[PoolerOutput]:
hidden_states = self._pooler.extract_states(hidden_states,
pooling_metadata)
logits, _ = self.score(hidden_states)

if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)

pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
@@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Optional

import msgspec

from vllm.sampling_params import RequestOutputKind

if TYPE_CHECKING:
from vllm.config import ModelConfig

@@ -23,6 +25,7 @@ class PoolingParams(

dimensions: Optional[int] = None
additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""

@@ -52,3 +55,7 @@ class PoolingParams(
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})")

def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
"For pooling output_kind has to be FINAL_ONLY"
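A small sketch of constructing PoolingParams under the new constraint (illustrative usage only):

    from vllm import PoolingParams
    from vllm.sampling_params import RequestOutputKind

    params = PoolingParams(dimensions=256)
    # Pooling requests produce a single final result, so any output_kind
    # other than FINAL_ONLY is rejected by __post_init__.
    assert params.output_kind == RequestOutputKind.FINAL_ONLY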
@@ -146,7 +146,8 @@ class KVCacheManager:
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None):
or (request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0

# The block hashes for the request may already be computed
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request

@@ -26,7 +27,8 @@ class NewRequestData:
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]

@@ -44,6 +46,7 @@ class NewRequestData:
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
@@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)

# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue

num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
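The scheduling rule introduced here can be stated as a small sketch (hypothetical helper, mirroring the branch above): when chunked prefill is off, a request whose remaining prompt does not fit the step's token budget is skipped and retried later rather than being split.

    def can_schedule_prefill(chunked_prefill_enabled: bool,
                             num_new_tokens: int,
                             token_budget: int) -> bool:
        # Without chunked prefill, the whole remaining prompt must fit into
        # the current token budget; otherwise the request waits.
        if not chunked_prefill_enabled and num_new_tokens > token_budget:
            return False
        return True

    assert can_schedule_prefill(True, 4096, 2048)       # chunked: schedule a slice
    assert not can_schedule_prefill(False, 4096, 2048)  # must wait for room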
@@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output

new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)

@@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
continue

req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []

scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))

@@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
del new_token_ids[num_new:] # Trim new tokens if needed.
break

pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len,
pooler_output)
if stopped:
kv_transfer_params = self._free_request(request)

# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None and logprobs:
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)

@@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or kv_transfer_params:
if new_token_ids or pooler_output is not None \
or kv_transfer_params:

# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(

@@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
@@ -1,15 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch

from vllm.v1.request import Request, RequestStatus


def check_stop(request: Request, max_model_len: int) -> bool:
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True

if request.pooling_params:
if pooler_output is not None:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False

sampling_params = request.sampling_params
assert sampling_params is not None
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
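For pooling requests the stop condition reduces to "the pooler produced an output"; a tiny sketch of that branch in isolation (hypothetical helper, not the real function):

    from typing import Optional

    import torch

    def pooling_request_finished(pooler_output: Optional[torch.Tensor]) -> bool:
        # A pooling request has no sampling stop conditions; it finishes
        # exactly when the model runner returns its pooled tensor, at which
        # point the request status becomes FINISHED_STOPPED.
        return pooler_output is not None

    assert not pooling_request_finished(None)
    assert pooling_request_finished(torch.zeros(4))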
@@ -7,10 +7,12 @@ from collections.abc import Sequence
from typing import Any, Optional, Union

import msgspec
import torch

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

@@ -50,7 +52,8 @@ class EngineCoreRequest(
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]

@@ -104,6 +107,8 @@ class EngineCoreOutput(
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

pooling_output: Optional[torch.Tensor] = None

finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None
@@ -17,7 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams

@@ -228,8 +228,7 @@ class AsyncLLM(EngineClient):
if self.errored:
raise EngineDeadError()

assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"
is_pooling = isinstance(params, PoolingParams)

# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)

@@ -240,7 +239,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank)

if params.n == 1:
if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
return queue

@@ -443,7 +442,7 @@ class AsyncLLM(EngineClient):
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)

def encode(
async def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,

@@ -451,8 +450,75 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
):
raise ValueError("Not Supported on V1 yet.")
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the EngineCore (separate process).

A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.

The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
"""

try:
# We start the output_handler on the first call to generate() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()

q = await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)

# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
assert isinstance(out, PoolingRequestOutput)
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out

# If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise

# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise

# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise

# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e

async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
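A sketch of how a caller consumes the new async encode path (illustrative; it assumes an already-constructed AsyncLLM and that the parameter after pooling_params is the request id, as in the EngineClient interface):

    from vllm.pooling_params import PoolingParams

    async def embed_one(async_llm, text: str):
        # encode() is an async generator; a pooling request yields a single
        # finished PoolingRequestOutput.
        async for out in async_llm.encode(text,
                                          PoolingParams(),
                                          request_id="embed-0"):
            if out.finished:
                return out.outputs.data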
@ -60,7 +60,6 @@ class EngineCore:
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
executor_fail_callback: Optional[Callable] = None):
|
||||
assert vllm_config.model_config.runner_type != "pooling"
|
||||
|
||||
# plugins need to be loaded at the engine/scheduler level too
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
@ -50,6 +50,8 @@ class IncrementalDetokenizer:
|
||||
request: EngineCoreRequest,
|
||||
) -> "IncrementalDetokenizer":
|
||||
|
||||
assert request.sampling_params is not None
|
||||
|
||||
if tokenizer is None:
|
||||
# No tokenizer => skipping detokenization.
|
||||
return IncrementalDetokenizer()
|
||||
@ -70,6 +72,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
||||
|
||||
# Stop strings
|
||||
params = request.sampling_params
|
||||
assert params is not None
|
||||
self.stop = stop = params.stop
|
||||
self.include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
@ -164,6 +167,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
super().__init__(request)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None
|
||||
|
||||
self.request_id = request.request_id
|
||||
self.skip_special_tokens = sampling_params.skip_special_tokens
|
||||
@ -245,20 +249,20 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
super().__init__(request)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
params = request.sampling_params
|
||||
assert params is not None
|
||||
|
||||
# Metadata for incremental detokenization.
|
||||
self.tokens, self.prefix_offset, self.read_offset = (
|
||||
convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt_ids=request.prompt_token_ids,
|
||||
skip_special_tokens=request.sampling_params.
|
||||
skip_special_tokens,
|
||||
skip_special_tokens=params.skip_special_tokens,
|
||||
))
|
||||
|
||||
self.token_ids.extend(request.prompt_token_ids)
|
||||
self.prompt_len = len(request.prompt_token_ids)
|
||||
|
||||
params = request.sampling_params
|
||||
self.skip_special_tokens = params.skip_special_tokens
|
||||
self.spaces_between_special_tokens = (
|
||||
params.spaces_between_special_tokens)
|
||||
|
||||
@ -15,7 +15,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@ -221,7 +221,7 @@ class LLMEngine:
# Add the request to EngineCore.
self.engine_core.add_request(child_request)

def step(self) -> list[RequestOutput]:
def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:

if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
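Since step() can now return either kind of output, callers that previously consumed RequestOutput unconditionally need a type branch; a small sketch (the helper name is invented here, not part of the diff):

from typing import Union

from vllm.outputs import PoolingRequestOutput, RequestOutput


def summarize(out: Union[RequestOutput, PoolingRequestOutput]) -> str:
    # Pooling results carry a single tensor; completions carry text candidates.
    if isinstance(out, PoolingRequestOutput):
        return f"{out.request_id}: embedding of dim {out.outputs.data.shape[-1]}"
    return f"{out.request_id}: {out.outputs[0].text!r}"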
@ -38,6 +38,7 @@ class LogprobsProcessor:
|
||||
tokenizer: Optional[AnyTokenizer],
|
||||
request: EngineCoreRequest,
|
||||
) -> "LogprobsProcessor":
|
||||
assert request.sampling_params is not None
|
||||
num_logprobs = request.sampling_params.logprobs
|
||||
num_prompt_logprobs = request.sampling_params.prompt_logprobs
|
||||
return cls(
|
||||
|
||||
@ -4,9 +4,12 @@
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
import torch
|
||||
|
||||
from vllm.outputs import (CompletionOutput, PoolingOutput,
|
||||
PoolingRequestOutput, RequestOutput)
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
@ -29,20 +32,22 @@ class RequestOutputCollector:
|
||||
|
||||
def __init__(self, output_kind: RequestOutputKind):
|
||||
self.aggregate = output_kind == RequestOutputKind.DELTA
|
||||
self.output: Optional[Union[RequestOutput, Exception]] = None
|
||||
self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
|
||||
Exception]] = None
|
||||
self.ready = asyncio.Event()
|
||||
|
||||
def put(self, output: Union[RequestOutput, Exception]) -> None:
|
||||
def put(self, output: Union[RequestOutput, PoolingRequestOutput,
|
||||
Exception]) -> None:
|
||||
"""Non-blocking put operation."""
|
||||
if self.output is None or isinstance(output, Exception):
|
||||
self.output = output
|
||||
self.ready.set()
|
||||
elif isinstance(self.output, RequestOutput):
|
||||
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
|
||||
# This ensures that request outputs with different request indexes
|
||||
# (if n > 1) do not override each other.
|
||||
self.output.add(output, aggregate=self.aggregate)
|
||||
|
||||
async def get(self) -> RequestOutput:
|
||||
async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
|
||||
"""Get operation blocks on put event."""
|
||||
while (output := self.output) is None:
|
||||
await self.ready.wait()
|
||||
@ -52,7 +57,8 @@ class RequestOutputCollector:
|
||||
raise output
|
||||
return output
|
||||
|
||||
def get_nowait(self) -> Optional[RequestOutput]:
|
||||
def get_nowait(
|
||||
self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
"""Non-blocking get operation."""
|
||||
output = self.output
|
||||
if output is not None:
|
||||
@ -66,7 +72,7 @@ class RequestOutputCollector:
|
||||
@dataclass
|
||||
class OutputProcessorOutput:
|
||||
|
||||
request_outputs: list[RequestOutput]
|
||||
request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
|
||||
reqs_to_abort: list[str]
|
||||
|
||||
|
||||
@ -81,8 +87,8 @@ class RequestState:
|
||||
output_kind: RequestOutputKind,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: list[int],
|
||||
logprobs_processor: LogprobsProcessor,
|
||||
detokenizer: IncrementalDetokenizer,
|
||||
logprobs_processor: Optional[LogprobsProcessor],
|
||||
detokenizer: Optional[IncrementalDetokenizer],
|
||||
max_tokens_param: Optional[int],
|
||||
arrival_time: float,
|
||||
queue: Optional[RequestOutputCollector],
|
||||
@ -116,27 +122,39 @@ class RequestState:
|
||||
queue: Optional[RequestOutputCollector],
|
||||
log_stats: bool,
|
||||
) -> "RequestState":
|
||||
if not request.sampling_params.detokenize:
|
||||
tokenizer = None
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if not sampling_params.detokenize:
|
||||
tokenizer = None
|
||||
output_kind = sampling_params.output_kind
|
||||
logprobs_processor = LogprobsProcessor.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
)
|
||||
detokenizer = IncrementalDetokenizer.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
)
|
||||
max_tokens_param = sampling_params.max_tokens
|
||||
else:
|
||||
logprobs_processor = None
|
||||
detokenizer = None
|
||||
max_tokens_param = None
|
||||
assert request.pooling_params is not None
|
||||
output_kind = request.pooling_params.output_kind
|
||||
|
||||
return cls(
|
||||
request_id=request.request_id,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
lora_name=(request.lora_request.name
|
||||
if request.lora_request is not None else None),
|
||||
output_kind=request.sampling_params.output_kind,
|
||||
output_kind=output_kind,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
logprobs_processor=LogprobsProcessor.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
),
|
||||
detokenizer=IncrementalDetokenizer.from_new_request(
|
||||
tokenizer=tokenizer,
|
||||
request=request,
|
||||
),
|
||||
max_tokens_param=(request.sampling_params.max_tokens if
|
||||
request.sampling_params is not None else None),
|
||||
logprobs_processor=logprobs_processor,
|
||||
detokenizer=detokenizer,
|
||||
max_tokens_param=max_tokens_param,
|
||||
arrival_time=request.arrival_time,
|
||||
queue=queue,
|
||||
log_stats=log_stats,
|
||||
@ -145,11 +163,12 @@ class RequestState:
|
||||
def make_request_output(
|
||||
self,
|
||||
new_token_ids: list[int],
|
||||
pooling_output: Optional[torch.Tensor],
|
||||
finish_reason: Optional[FinishReason],
|
||||
stop_reason: Union[int, str, None],
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||
num_cached_tokens: int = 0,
|
||||
) -> Optional[RequestOutput]:
|
||||
) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
|
||||
finished = finish_reason is not None
|
||||
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
|
||||
@ -158,15 +177,20 @@ class RequestState:
|
||||
# Only the final output is required in FINAL_ONLY mode.
|
||||
return None
|
||||
|
||||
completion_output = self._new_completion_output(
|
||||
new_token_ids, finish_reason, stop_reason)
|
||||
|
||||
request_id = self.request_id
|
||||
if pooling_output is not None:
|
||||
return self._new_request_output(
|
||||
request_id, [self._new_pooling_output(pooling_output)],
|
||||
finished)
|
||||
|
||||
output = self._new_completion_output(new_token_ids, finish_reason,
|
||||
stop_reason)
|
||||
|
||||
if self.parent_req is None:
|
||||
outputs = [completion_output]
|
||||
outputs = [output]
|
||||
else:
|
||||
request_id, outputs, finished = self.parent_req.get_outputs(
|
||||
request_id, completion_output)
|
||||
request_id, output)
|
||||
if not outputs:
|
||||
return None
|
||||
|
||||
@ -176,12 +200,21 @@ class RequestState:
|
||||
def _new_request_output(
|
||||
self,
|
||||
request_id: str,
|
||||
outputs: list[CompletionOutput],
|
||||
outputs: Union[list[CompletionOutput], list[PoolingOutput]],
|
||||
finished: bool,
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||
num_cached_tokens: int = 0,
|
||||
) -> RequestOutput:
|
||||
) -> Union[RequestOutput, PoolingRequestOutput]:
|
||||
|
||||
if isinstance(outputs[0], PoolingOutput):
|
||||
assert len(outputs) == 1
|
||||
return PoolingRequestOutput(
|
||||
request_id=request_id,
|
||||
outputs=outputs[0],
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
finished=finished,
|
||||
)
|
||||
assert self.logprobs_processor is not None
|
||||
if self.output_kind == RequestOutputKind.DELTA:
|
||||
# Side effect: logprobs processor forgets prompt logprobs
|
||||
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
|
||||
@ -193,7 +226,7 @@ class RequestState:
|
||||
prompt=self.prompt,
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
outputs=outputs,
|
||||
outputs=cast(list[CompletionOutput], outputs),
|
||||
finished=finished,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=num_cached_tokens,
|
||||
@ -206,6 +239,8 @@ class RequestState:
|
||||
stop_reason: Union[int, str, None],
|
||||
) -> CompletionOutput:
|
||||
|
||||
assert self.detokenizer is not None
|
||||
assert self.logprobs_processor is not None
|
||||
finished = finish_reason is not None
|
||||
delta = self.output_kind == RequestOutputKind.DELTA
|
||||
|
||||
@ -228,6 +263,13 @@ class RequestState:
|
||||
finish_reason=str(finish_reason) if finished else None,
|
||||
stop_reason=stop_reason if finished else None)
|
||||
|
||||
def _new_pooling_output(
|
||||
self,
|
||||
pooling_output: torch.Tensor,
|
||||
) -> PoolingOutput:
|
||||
|
||||
return PoolingOutput(data=pooling_output)
|
||||
|
||||
|
||||
class OutputProcessor:
|
||||
"""Process EngineCoreOutputs into RequestOutputs."""
|
||||
@ -326,7 +368,8 @@ class OutputProcessor:
|
||||
within the loop below.
|
||||
"""
|
||||
|
||||
request_outputs: list[RequestOutput] = []
|
||||
request_outputs: Union[list[RequestOutput],
|
||||
list[PoolingRequestOutput]] = []
|
||||
reqs_to_abort: list[str] = []
|
||||
for engine_core_output in engine_core_outputs:
|
||||
req_id = engine_core_output.request_id
|
||||
@ -341,25 +384,31 @@ class OutputProcessor:
|
||||
iteration_stats)
|
||||
|
||||
new_token_ids = engine_core_output.new_token_ids
|
||||
pooling_output = engine_core_output.pooling_output
|
||||
finish_reason = engine_core_output.finish_reason
|
||||
stop_reason = engine_core_output.stop_reason
|
||||
kv_transfer_params = engine_core_output.kv_transfer_params
|
||||
num_cached_tokens = engine_core_output.num_cached_tokens
|
||||
req_state.is_prefilling = False
|
||||
|
||||
# 2) Detokenize the token ids into text and perform stop checks.
|
||||
stop_string = req_state.detokenizer.update(
|
||||
new_token_ids, finish_reason == FinishReason.STOP)
|
||||
if stop_string:
|
||||
finish_reason = FinishReason.STOP
|
||||
stop_reason = stop_string
|
||||
if pooling_output is None:
|
||||
assert req_state.detokenizer is not None
|
||||
assert req_state.logprobs_processor is not None
|
||||
# 2) Detokenize the token ids into text and perform stop checks.
|
||||
stop_string = req_state.detokenizer.update(
|
||||
new_token_ids, finish_reason == FinishReason.STOP)
|
||||
if stop_string:
|
||||
finish_reason = FinishReason.STOP
|
||||
stop_reason = stop_string
|
||||
|
||||
# 3) Compute sample and prompt logprobs for request, if required.
|
||||
req_state.logprobs_processor.update_from_output(engine_core_output)
|
||||
# 3) Compute sample and prompt logprobs for request,
|
||||
# if required.
|
||||
req_state.logprobs_processor.update_from_output(
|
||||
engine_core_output)
|
||||
|
||||
# 4) Create and handle RequestOutput objects.
|
||||
if request_output := req_state.make_request_output(
|
||||
new_token_ids, finish_reason, stop_reason,
|
||||
new_token_ids, pooling_output, finish_reason, stop_reason,
|
||||
kv_transfer_params, num_cached_tokens):
|
||||
if req_state.queue is not None:
|
||||
# AsyncLLM: put into queue for handling by generate().
|
||||
|
||||
@ -136,8 +136,8 @@ class Processor:
Should raise ValueError if unsupported for API Server.
"""

if not isinstance(params, SamplingParams):
raise ValueError("V1 does not yet support Pooling models.")
if isinstance(params, PoolingParams):
return

self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
@ -263,18 +263,22 @@ class Processor:
if encoder_inputs is not None:
raise NotImplementedError

assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()

# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
@ -331,6 +335,7 @@ class Processor:
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
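The effect of the new branch, reduced to a standalone sketch (the helper name and arguments are invented for illustration; the real logic lives in Processor.process_inputs):

from typing import Optional, Union

from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


def split_params(
    params: Union[SamplingParams, PoolingParams],
    max_model_len: int,
    num_prompt_tokens: int,
) -> tuple[Optional[SamplingParams], Optional[PoolingParams]]:
    # Exactly one of the two ends up set on the EngineCoreRequest.
    if isinstance(params, SamplingParams):
        sampling = params.clone()
        if sampling.max_tokens is None:
            # Default: generate up to the model's context limit.
            sampling.max_tokens = max_model_len - num_prompt_tokens
        return sampling, None
    return None, params.clone()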
@ -481,8 +481,9 @@ class PrometheusStatLogger(StatLoggerBase):
finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe(
finished_request.num_generation_tokens)
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if finished_request.max_tokens_param:
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)

if self.gauge_lora_info is not None:
running_lora_adapters = \
@ -106,7 +106,6 @@ class IterationStats:

self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len

first_token_latency = self._time_since(req_stats.arrival_time)
@ -101,6 +101,9 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]

# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
@ -112,5 +115,6 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
finished_sending=None,
finished_recving=None)
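pooler_output is positionally aligned with req_ids; a consumer-side sketch (illustrative only, the function is not part of vLLM) of reading that alignment:

import torch

from vllm.v1.outputs import ModelRunnerOutput


def collect_embeddings(output: ModelRunnerOutput) -> dict[str, torch.Tensor]:
    # Entries are None for requests whose prompts have not finished
    # prefilling yet, so they are skipped here.
    return {
        req_id: tensor
        for req_id, tensor in zip(output.req_ids, output.pooler_output)
        if tensor is not None
    }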
0 vllm/v1/pool/__init__.py Normal file
16 vllm/v1/pool/metadata.py Normal file
@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.pooling_params import PoolingParams


@dataclass
class PoolingMetadata:
"""Tensors for pooling."""

prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams]
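A hedged sketch of how this metadata is used: prompt_lens splits the runner's flattened hidden states back into per-request chunks before the model's pooler runs (the values and the mean-pooling below are illustrative, not the pooler vLLM actually selects):

import torch

from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata

# Two prompts of lengths 5 and 3, flattened into one [8, hidden] tensor.
hidden = torch.randn(8, 16)
meta = PoolingMetadata(
    prompt_lens=torch.tensor([5, 3]),
    prompt_token_ids=None,
    pooling_params=[PoolingParams(), PoolingParams()],
)

# Split back into per-request chunks and mean-pool each one.
chunks = torch.split(hidden, meta.prompt_lens.tolist())
embeddings = [chunk.mean(dim=0) for chunk in chunks]
assert len(embeddings) == len(meta.pooling_params)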
@ -5,6 +5,7 @@ import enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
@ -25,7 +26,8 @@ class Request:
|
||||
multi_modal_inputs: Optional[list[MultiModalKwargs]],
|
||||
multi_modal_hashes: Optional[list[str]],
|
||||
multi_modal_placeholders: Optional[list[PlaceholderRange]],
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: Optional[SamplingParams],
|
||||
pooling_params: Optional[PoolingParams],
|
||||
eos_token_id: Optional[int],
|
||||
client_index: int = 0,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
@ -35,18 +37,35 @@ class Request:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
self.sampling_params = sampling_params
|
||||
self.pooling_params = pooling_params
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
|
||||
self.status = (RequestStatus.WAITING_FOR_FSM
|
||||
if sampling_params.guided_decoding is not None else
|
||||
RequestStatus.WAITING)
|
||||
self.status = RequestStatus.WAITING
|
||||
if sampling_params and sampling_params.guided_decoding is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
self.events: list[EngineCoreEvent] = []
|
||||
self.stop_reason: Union[int, str, None] = None
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
if pooling_params is not None:
|
||||
self.max_tokens = 1
|
||||
elif sampling_params is not None:
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
if sampling_params.guided_decoding is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
|
||||
if sampling_params.extra_args is not None:
|
||||
self.kv_transfer_params = \
|
||||
sampling_params.extra_args.get("kv_transfer_params")
|
||||
else:
|
||||
raise ValueError(
|
||||
"sampling_params and pooling_params can't both be unset")
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
@ -63,11 +82,6 @@ class Request:
|
||||
self.num_encoder_inputs = len(self.mm_inputs)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
kv_params = (None if sampling_params.extra_args is None else
|
||||
sampling_params.extra_args.get("kv_transfer_params"))
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
|
||||
|
||||
# Sanity check
|
||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||
if self.mm_hashes:
|
||||
@ -98,10 +112,12 @@ class Request:
|
||||
multi_modal_hashes=request.mm_hashes,
|
||||
multi_modal_placeholders=request.mm_placeholders,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
eos_token_id=request.eos_token_id,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params),
|
||||
sampling_params=request.sampling_params) \
|
||||
if request.sampling_params else None,
|
||||
cache_salt=request.cache_salt,
|
||||
)
|
||||
|
||||
@ -141,7 +157,8 @@ class Request:
|
||||
|
||||
@property
|
||||
def use_structured_output(self) -> bool:
|
||||
return self.sampling_params.guided_decoding is not None
|
||||
return self.sampling_params is not None and \
|
||||
self.sampling_params.guided_decoding is not None
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
|
||||
@ -62,13 +62,15 @@ class StructuredOutputManager:
return

if TYPE_CHECKING:
assert request.sampling_params.guided_decoding is not None
assert request.sampling_params is not None and \
request.sampling_params.guided_decoding is not None

# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
assert request.sampling_params is not None
backend = request.sampling_params.guided_decoding.backend
vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar":
@ -10,9 +10,11 @@ import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
@ -27,7 +29,8 @@ class CachedRequestState:
|
||||
prompt_token_ids: list[int]
|
||||
mm_inputs: list[MultiModalKwargs]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: SamplingParams
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
generator: Optional[torch.Generator]
|
||||
|
||||
block_ids: tuple[list[int], ...]
|
||||
@ -226,6 +229,8 @@ class InputBatch:
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
self.pooling_params: dict[str, PoolingParams] = {}
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
@ -269,77 +274,83 @@ class InputBatch:
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
if sampling_params := request.sampling_params:
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (
|
||||
sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[
|
||||
req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
else:
|
||||
assert request.pooling_params is not None
|
||||
self.pooling_params[req_id] = request.pooling_params
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
@ -392,6 +403,7 @@ class InputBatch:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
@ -602,6 +614,25 @@ class InputBatch:
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
)
|
||||
|
||||
@property
|
||||
def pooling_metadata(self) -> PoolingMetadata:
|
||||
if len(self.pooling_params) == 0:
|
||||
pooling_params = []
|
||||
else:
|
||||
# Note, for now this assumes that all request in the batch
|
||||
# are either sampling or pooling requests
|
||||
assert len(self.req_ids) == len(self.pooling_params)
|
||||
pooling_params = [
|
||||
self.pooling_params[req_id] for req_id in self.req_ids
|
||||
]
|
||||
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
|
||||
@ -36,6 +36,7 @@ from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
@ -51,6 +52,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
|
||||
SlidingWindowSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
@ -119,6 +121,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
cache_config.cache_dtype]
|
||||
|
||||
self.is_multimodal_model = model_config.is_multimodal_model
|
||||
self.is_pooling_model = model_config.pooler_config is not None
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
@ -394,7 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
sampling_params = new_req_data.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
pooling_params = new_req_data.pooling_params
|
||||
if sampling_params and \
|
||||
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
else:
|
||||
@ -406,6 +411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
@ -563,7 +569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> tuple[dict[str, Any], bool, torch.Tensor,
|
||||
Optional[SpecDecodeMetadata]]:
|
||||
Optional[SpecDecodeMetadata], np.ndarray]:
|
||||
"""
|
||||
:return: tuple[
|
||||
attn_metadata: layer-to-attention_metadata mapping,
|
||||
@ -750,7 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
|
||||
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
||||
spec_decode_metadata)
|
||||
spec_decode_metadata, num_scheduled_tokens)
|
||||
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
@ -1197,6 +1203,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int32)
|
||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||
|
||||
def _pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
finished_sending: Optional[set[str]],
|
||||
finished_recving: Optional[set[str]],
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
"Either all or none of the requests in" \
|
||||
" a batch must be pooling request"
|
||||
|
||||
extracted_hidden_states = list(
|
||||
torch.split(hidden_states[:num_scheduled_tokens],
|
||||
num_scheduled_tokens_np.tolist()))
|
||||
|
||||
pooling_metadata = self.input_batch.pooling_metadata
|
||||
|
||||
raw_pooler_output = self.model.pooler(
|
||||
hidden_states=extracted_hidden_states,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
pooler_output: list[Optional[torch.Tensor]] = []
|
||||
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
|
||||
|
||||
if seq_len == prompt_len:
|
||||
pooler_output.append(raw_output.data.cpu())
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -1214,7 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
(attn_metadata, attention_cuda_graphs, logits_indices,
|
||||
spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
|
||||
spec_decode_metadata,
|
||||
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.use_cuda_graph
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
@ -1284,7 +1336,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# compiled with full CUDA graphs, we have to skip them entirely.
|
||||
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
|
||||
|
||||
# Run the decoder.
|
||||
# Run the model.
|
||||
# Use persistent buffers for CUDA graphs.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
@ -1326,6 +1378,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
all_gather_group=get_tp_group())
|
||||
logits = None
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving)
|
||||
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
@ -1541,6 +1598,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
)
|
||||
@ -1802,7 +1860,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_attn_cudagraph: bool = False,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||
@ -1899,7 +1957,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.drafter.dummy_run(num_tokens)
|
||||
|
||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||
return hidden_states[logit_indices]
|
||||
return hidden_states, hidden_states[logit_indices]
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_sampler_run(
|
||||
@ -1978,6 +2036,48 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_pooler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
num_tokens = hidden_states.shape[0]
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_reqs = min(num_tokens, max_num_reqs)
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
|
||||
hidden_states_list = list(
|
||||
torch.split(hidden_states, num_scheduled_tokens_list))
|
||||
|
||||
req_num_tokens = num_tokens // num_reqs
|
||||
|
||||
dummy_metadata = PoolingMetadata(
|
||||
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
|
||||
device=self.device),
|
||||
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
|
||||
dtype=torch.int32,
|
||||
device=self.device),
|
||||
pooling_params=[PoolingParams()] * num_reqs)
|
||||
|
||||
try:
|
||||
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
|
||||
pooling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up pooler with "
|
||||
f"{num_reqs} dummy requests. Please try lowering "
|
||||
"`max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
return pooler_output
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# Profile with multimodal encoder & encoder cache.
|
||||
# TODO: handle encoder-decoder models once we support them.
|
||||
@ -2048,13 +2148,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Cache the dummy encoder outputs.
|
||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
hidden_states, last_hidden_states \
|
||||
= self._dummy_run(self.max_num_tokens)
|
||||
if get_pp_group().is_last_rank:
|
||||
sampler_output = self._dummy_sampler_run(hidden_states)
|
||||
if self.is_pooling_model:
|
||||
output = self._dummy_pooler_run(hidden_states)
|
||||
else:
|
||||
output = self._dummy_sampler_run(last_hidden_states)
|
||||
else:
|
||||
sampler_output = None
|
||||
output = None
|
||||
self._sync_device()
|
||||
del hidden_states, sampler_output
|
||||
del hidden_states, output
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@ -273,9 +273,14 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run(
num_tokens=max_num_reqs))

hidden_states, last_hidden_states = \
self.model_runner._dummy_run(num_tokens=max_num_reqs)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
self.model_runner._dummy_sampler_run(
hidden_states=last_hidden_states)

# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
@ -231,6 +231,7 @@ class InputBatch:
self.block_table.add_row(request.block_ids, req_index)

sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0

@ -386,6 +386,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.sampling_params is not None,\
"Pooling is not supported in TPU yet"
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params

@ -395,6 +397,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
@ -956,6 +959,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
)

# Check there are no new graphs compiled - all the graphs should be