[Bugfix] Fix SHM cache initialization (#26427)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-09 17:48:04 +08:00 committed by GitHub
parent dc7976dd9f
commit 4bdf7ac593
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 357 additions and 417 deletions

View File

@ -113,15 +113,17 @@ def mock_serving_setup():
mock_engine.generate.reset_mock() mock_engine.generate.reset_mock()
mock_engine.add_lora.reset_mock() mock_engine.add_lora.reset_mock()
mock_model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
models = OpenAIServingModels( models = OpenAIServingModels(
engine_client=mock_engine, engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS, base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
) )
serving_completion = OpenAIServingCompletion( serving_completion = OpenAIServingCompletion(
mock_engine, mock_model_config, models, request_logger=None mock_engine, models, request_logger=None
) )
serving_completion._process_inputs = AsyncMock( serving_completion._process_inputs = AsyncMock(

View File

@ -245,17 +245,13 @@ class MockModelConfig:
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
def _build_serving_chat( def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
engine: AsyncLLM, model_config: MockModelConfig
) -> OpenAIServingChat:
models = OpenAIServingModels( models = OpenAIServingModels(
engine_client=engine, engine_client=engine,
base_model_paths=BASE_MODEL_PATHS, base_model_paths=BASE_MODEL_PATHS,
model_config=model_config,
) )
serving_chat = OpenAIServingChat( serving_chat = OpenAIServingChat(
engine, engine,
model_config,
models, models,
response_role="assistant", response_role="assistant",
chat_template=CHAT_TEMPLATE, chat_template=CHAT_TEMPLATE,
@ -280,18 +276,17 @@ def _build_serving_chat(
@dataclass @dataclass
class MockEngine: class MockEngine:
async def get_model_config(self): model_config: MockModelConfig = field(default_factory=MockModelConfig)
return MockModelConfig() processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock)
async def _async_serving_chat_init(): async def _async_serving_chat_init():
engine = MockEngine() engine = MockEngine()
model_config = await engine.get_model_config()
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS) models = OpenAIServingModels(engine, BASE_MODEL_PATHS)
serving_completion = OpenAIServingChat( serving_completion = OpenAIServingChat(
engine, engine,
model_config,
models, models,
response_role="assistant", response_role="assistant",
chat_template=CHAT_TEMPLATE, chat_template=CHAT_TEMPLATE,
@ -311,8 +306,11 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine, MockModelConfig()) serving_chat = _build_serving_chat(mock_engine)
messages = [{"role": "user", "content": "what is 1+1?"}] messages = [{"role": "user", "content": "what is 1+1?"}]
async def return_model_name(*args): async def return_model_name(*args):
@ -338,8 +336,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine, MockModelConfig()) serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
@ -368,9 +369,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
# Initialize the serving chat # Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine, mock_model_config) serving_chat = _build_serving_chat(mock_engine)
# Test Case 1: No max_tokens specified in request # Test Case 1: No max_tokens specified in request
req = ChatCompletionRequest( req = ChatCompletionRequest(
@ -410,9 +414,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
# Initialize the serving chat # Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine, mock_model_config) serving_chat = _build_serving_chat(mock_engine)
# Test case 1: No max_tokens specified, defaults to context_window # Test case 1: No max_tokens specified, defaults to context_window
req = ChatCompletionRequest( req = ChatCompletionRequest(
@ -453,9 +460,12 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
# Initialize the serving chat # Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine, mock_model_config) serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
@ -496,8 +506,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine, mock_model_config) serving_chat = _build_serving_chat(mock_engine)
# Test cache_salt # Test cache_salt
req = ChatCompletionRequest( req = ChatCompletionRequest(

View File

@ -22,10 +22,12 @@ def serving() -> OpenAIServing:
model_config = Mock(spec=ModelConfig) model_config = Mock(spec=ModelConfig)
model_config.max_model_len = 32768 model_config.max_model_len = 32768
models = Mock(spec=OpenAIServingModels) models = Mock(spec=OpenAIServingModels)
models.model_config = model_config
models.processor = Mock()
models.io_processor = Mock()
serving = OpenAIServing( serving = OpenAIServing(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=None, request_logger=None,
) )

View File

@ -25,15 +25,17 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
async def _async_serving_models_init() -> OpenAIServingModels: async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config = MagicMock(spec=ModelConfig)
mock_engine_client = MagicMock(spec=EngineClient) mock_engine_client = MagicMock(spec=EngineClient)
# Set the max_model_len attribute to avoid missing attribute # Set the max_model_len attribute to avoid missing attribute
mock_model_config = MagicMock(spec=ModelConfig)
mock_model_config.max_model_len = 2048 mock_model_config.max_model_len = 2048
mock_engine_client.model_config = mock_model_config
mock_engine_client.processor = MagicMock()
mock_engine_client.io_processor = MagicMock()
serving_models = OpenAIServingModels( serving_models = OpenAIServingModels(
engine_client=mock_engine_client, engine_client=mock_engine_client,
base_model_paths=BASE_MODEL_PATHS, base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None, lora_modules=None,
) )
await serving_models.init_static_loras() await serving_models.init_static_loras()

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from unittest.mock import AsyncMock, MagicMock from unittest.mock import MagicMock
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@ -70,11 +70,14 @@ class TestInitializeToolSessions:
"""Create a real OpenAIServingResponses instance for testing""" """Create a real OpenAIServingResponses instance for testing"""
# Create minimal mocks for required dependencies # Create minimal mocks for required dependencies
engine_client = MagicMock() engine_client = MagicMock()
engine_client.get_model_config = AsyncMock()
model_config = MagicMock() model_config = MagicMock()
model_config.hf_config.model_type = "test" model_config.hf_config.model_type = "test"
model_config.get_diff_sampling_param.return_value = {} model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.processor = MagicMock()
engine_client.io_processor = MagicMock()
models = MagicMock() models = MagicMock()
@ -83,7 +86,6 @@ class TestInitializeToolSessions:
# Create the actual instance # Create the actual instance
instance = OpenAIServingResponses( instance = OpenAIServingResponses(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,
@ -132,18 +134,20 @@ class TestValidateGeneratorInput:
"""Create a real OpenAIServingResponses instance for testing""" """Create a real OpenAIServingResponses instance for testing"""
# Create minimal mocks for required dependencies # Create minimal mocks for required dependencies
engine_client = MagicMock() engine_client = MagicMock()
engine_client.get_model_config = AsyncMock()
model_config = MagicMock() model_config = MagicMock()
model_config.hf_config.model_type = "test" model_config.hf_config.model_type = "test"
model_config.get_diff_sampling_param.return_value = {} model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.processor = MagicMock()
engine_client.io_processor = MagicMock()
models = MagicMock() models = MagicMock()
# Create the actual instance # Create the actual instance
instance = OpenAIServingResponses( instance = OpenAIServingResponses(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=None, request_logger=None,
chat_template=None, chat_template=None,

View File

@ -7,6 +7,7 @@ from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@ -106,7 +107,8 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
) )
def test_preprocessor_text_no_mm_inputs(model_id, prompt): def test_preprocessor_text_no_mm_inputs(model_id, prompt):
model_config = ModelConfig(model=model_id) model_config = ModelConfig(model=model_id)
input_preprocessor = InputPreprocessor(model_config) tokenizer = init_tokenizer_from_configs(model_config)
input_preprocessor = InputPreprocessor(model_config, tokenizer)
with pytest.raises(ValueError, match="does not support multimodal inputs"): with pytest.raises(ValueError, match="does not support multimodal inputs"):
input_preprocessor.preprocess(prompt) input_preprocessor.preprocess(prompt)
@ -127,8 +129,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
) )
def test_preprocessor_always_mm_code_path(model_id, prompt): def test_preprocessor_always_mm_code_path(model_id, prompt):
model_config = ModelConfig(model=model_id) model_config = ModelConfig(model=model_id)
input_preprocessor = InputPreprocessor(model_config) tokenizer = init_tokenizer_from_configs(model_config)
tokenizer = input_preprocessor.tokenizer input_preprocessor = InputPreprocessor(model_config, tokenizer)
# HF processor adds sep token # HF processor adds sep token
sep_token_id = tokenizer.vocab[tokenizer.sep_token] sep_token_id = tokenizer.vocab[tokenizer.sep_token]

View File

@ -65,7 +65,7 @@ def _mk_processor(
device_config=DeviceConfig(device="cpu"), device_config=DeviceConfig(device="cpu"),
) )
return Processor(vllm_config) return Processor(vllm_config, tokenizer=None)
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):

View File

@ -459,7 +459,7 @@ def test_all_logprobs(example_prompts):
results_logprobs_all = runner.llm.generate( results_logprobs_all = runner.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_all example_prompts, sampling_params=sampling_params_logprobs_all
) )
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()
for i in range(len(results_logprobs_all)): for i in range(len(results_logprobs_all)):
logprobs = results_logprobs_all[i].outputs[0].logprobs logprobs = results_logprobs_all[i].outputs[0].logprobs

View File

@ -186,7 +186,7 @@ async def run_vllm_async(
engine_args, engine_args,
disable_frontend_multiprocessing=disable_frontend_multiprocessing, disable_frontend_multiprocessing=disable_frontend_multiprocessing,
) as llm: ) as llm:
model_config = await llm.get_model_config() model_config = llm.model_config
assert all( assert all(
model_config.max_model_len model_config.max_model_len
>= (request.prompt_len + request.expected_output_len) >= (request.prompt_len + request.expected_output_len)

View File

@ -1,26 +1,23 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any, Optional, Union from typing import Any, Optional, Union
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors.interface import IOProcessor from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor
logger = init_logger(__name__) logger = init_logger(__name__)
@ -28,6 +25,11 @@ logger = init_logger(__name__)
class EngineClient(ABC): class EngineClient(ABC):
"""Protocol class for Clients to Engine""" """Protocol class for Clients to Engine"""
vllm_config: VllmConfig
model_config: ModelConfig
processor: Processor
io_processor: Optional[IOProcessor]
@property @property
@abstractmethod @abstractmethod
def is_running(self) -> bool: ... def is_running(self) -> bool: ...
@ -61,180 +63,6 @@ class EngineClient(ABC):
"""Generate outputs for a request.""" """Generate outputs for a request."""
... ...
async def beam_search(
    self,
    prompt: PromptType,
    request_id: str,
    params: BeamSearchParams,
    lora_request: Optional[LoRARequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
    """Run beam search on top of `self.generate` and yield the final result.

    Each decoding step issues one single-token `generate` call per live
    beam (requesting `2 * beam_width` logprobs), expands every beam with
    the returned candidates, and keeps the `beam_width` best according to
    the length-penalised sort key.

    Args:
        prompt: A text prompt or a prompt dict. Explicit encoder/decoder
            prompts and embedding prompts are not supported.
        request_id: Caller-supplied id; reported on the yielded
            `RequestOutput`.
        params: Beam search settings (width, max tokens, temperature,
            length penalty, EOS handling).
        lora_request: Optional LoRA adapter attached to the initial beam
            and inherited by its expansions.

    Yields:
        Exactly one finished `RequestOutput` holding up to `beam_width`
        `CompletionOutput`s, best beam first.

    Raises:
        NotImplementedError: For explicit encoder/decoder prompts or
            prompts that resolve to embeddings.
    """
    beam_width = params.beam_width
    max_tokens = params.max_tokens
    ignore_eos = params.ignore_eos
    temperature = params.temperature
    length_penalty = params.length_penalty
    include_stop_str_in_output = params.include_stop_str_in_output

    preprocessor = await self.get_input_preprocessor()
    tokenizer = preprocessor.get_tokenizer()
    eos_token_id = tokenizer.eos_token_id

    if is_explicit_encoder_decoder_prompt(prompt):
        raise NotImplementedError
    else:
        processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

    if processed_inputs["type"] == "embeds":
        raise NotImplementedError

    # This is a workaround to fix multimodal beam search; this is a
    # bandaid fix for 2 small problems:
    # 1. Multi_modal_data on the processed_inputs currently resolves to
    #    `None`.
    # 2. preprocessing above expands the multimodal placeholders. However,
    #    this happens again in generation, so the double expansion causes
    #    a mismatch.
    # TODO - would be ideal to handle this more gracefully.
    if isinstance(prompt, str):
        prompt_text = prompt
        prompt_token_ids = []
        multi_modal_data = None
    else:
        prompt_text = prompt.get("prompt")
        prompt_token_ids = prompt.get("prompt_token_ids", [])
        multi_modal_data = prompt.get("multi_modal_data")

    mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

    # Number of prompt tokens; used to slice generated tokens off a beam.
    tokenized_length = len(prompt_token_ids)

    sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

    # One token per step; 2 * beam_width candidates so that enough
    # non-EOS continuations remain after EOS beams are set aside.
    beam_search_params = SamplingParams(
        logprobs=2 * beam_width,
        max_tokens=1,
        temperature=temperature,
    )
    all_beams = [
        BeamSearchSequence(
            tokens=prompt_token_ids,
            cum_logprob=0,
            logprobs=[],
            multi_modal_data=multi_modal_data,
            mm_processor_kwargs=mm_processor_kwargs,
            lora_request=lora_request,
        )
    ]
    completed = []

    for _ in range(max_tokens):
        # Every beam has terminated: nothing left to expand.
        if not all_beams:
            break

        # Internal id for this step's sub-requests. Kept in its own name so
        # the caller's `request_id` is not clobbered and can be reported on
        # the final RequestOutput.
        step_request_id = f"beam_search-{random_uuid()}"

        tasks = []
        for i, beam in enumerate(all_beams):
            individual_prompt = TokensPrompt(
                prompt_token_ids=beam.tokens,
                multi_modal_data=beam.multi_modal_data,
                mm_processor_kwargs=beam.mm_processor_kwargs,
            )
            task = asyncio.create_task(
                collect_from_async_generator(
                    self.generate(
                        individual_prompt,
                        beam_search_params,
                        f"{step_request_id}-{i}",
                        lora_request=beam.lora_request,
                    )
                )
            )
            tasks.append(task)

        output = await asyncio.gather(*tasks)

        # Each sub-request is consumed to completion; keep its first output.
        output = [x[0] for x in output]

        new_beams = []
        for current_beam, result in zip(all_beams, output):
            if result.outputs[0].logprobs is not None:
                logprobs = result.outputs[0].logprobs[0]
                for token_id, logprob_obj in logprobs.items():
                    if token_id == eos_token_id and not ignore_eos:
                        completed.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id]
                                if include_stop_str_in_output
                                else current_beam.tokens,
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob
                                + logprob_obj.logprob,
                                finish_reason="stop",
                                stop_reason=eos_token_id,
                            )
                        )
                    else:
                        new_beams.append(
                            BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=current_beam.cum_logprob
                                + logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.mm_processor_kwargs,
                            )
                        )

        sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
        all_beams = sorted_beams[:beam_width]

    # Beams still alive when the token budget ran out compete with the
    # EOS-terminated ones for the final ranking.
    completed.extend(all_beams)
    sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
    best_beams = sorted_completed[:beam_width]

    for beam in best_beams:
        if beam.tokens[-1] == eos_token_id and not ignore_eos:
            # Skip the eos token in the text.
            tokens = beam.tokens[tokenized_length:-1]
        else:
            tokens = beam.tokens[tokenized_length:]
        beam.text = tokenizer.decode(tokens)

    yield RequestOutput(
        request_id=request_id,
        prompt=prompt_text,
        outputs=[
            CompletionOutput(
                text=beam.text,
                cumulative_logprob=beam.cum_logprob,
                token_ids=beam.tokens[tokenized_length:],
                index=i,
                logprobs=beam.logprobs,
                finish_reason=beam.finish_reason
                if beam.finish_reason is not None
                else "length",
                stop_reason=beam.stop_reason,
            )
            for (i, beam) in enumerate(best_beams)
        ],
        finished=True,
        prompt_token_ids=prompt_token_ids,
        prompt_logprobs=None,
    )
@abstractmethod @abstractmethod
def encode( def encode(
self, self,
@ -259,29 +87,11 @@ class EngineClient(ABC):
""" """
... ...
@abstractmethod
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
...
@abstractmethod
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...
@abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor:
"""Get the input processor of the vLLM engine."""
...
@abstractmethod @abstractmethod
async def get_tokenizer(self) -> AnyTokenizer: async def get_tokenizer(self) -> AnyTokenizer:
"""Get the tokenizer""" """Get the tokenizer"""
... ...
async def get_io_processor(self) -> IOProcessor:
raise NotImplementedError
@abstractmethod @abstractmethod
async def is_tracing_enabled(self) -> bool: ... async def is_tracing_enabled(self) -> bool: ...

View File

@ -66,7 +66,6 @@ from vllm.outputs import (
RequestOutput, RequestOutput,
ScoringRequestOutput, ScoringRequestOutput,
) )
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
@ -79,7 +78,6 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, as_iter, is_list_of from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.engine.processor import Processor
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING: if TYPE_CHECKING:
@ -335,21 +333,13 @@ class LLM:
self.request_counter = Counter() self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None self.default_sampling_params: Union[dict[str, Any], None] = None
supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore supported_tasks = self.llm_engine.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
logger.info("Supported_tasks: %s", supported_tasks)
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
# Load the Input/Output processor plugin if any self.model_config = self.llm_engine.model_config
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin self.processor = self.llm_engine.processor
self.io_processor = get_io_processor( self.io_processor = self.llm_engine.io_processor
self.llm_engine.vllm_config, io_processor_plugin
)
@property
def model_config(self):
return self.llm_engine.model_config
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer() return self.llm_engine.get_tokenizer()
@ -364,18 +354,9 @@ class LLM:
else: else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def _get_processor(self) -> Processor:
if not hasattr(self, "_processor"):
vllm_config = self.llm_engine.vllm_config
self._processor = Processor(vllm_config)
return self._processor
def get_default_sampling_params(self) -> SamplingParams: def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None: if self.default_sampling_params is None:
self.default_sampling_params = ( self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.llm_engine.model_config.get_diff_sampling_param()
)
if self.default_sampling_params: if self.default_sampling_params:
return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams.from_optional(**self.default_sampling_params)
return SamplingParams() return SamplingParams()
@ -423,7 +404,7 @@ class LLM:
considered legacy and may be deprecated in the future. You should considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter. instead pass them via the `inputs` parameter.
""" """
model_config = self.llm_engine.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
if runner_type != "generate": if runner_type != "generate":
raise ValueError( raise ValueError(
@ -463,7 +444,7 @@ class LLM:
# isn't multimodal, leave the lora as is. # isn't multimodal, leave the lora as is.
if ( if (
lora_config is None lora_config is None
or not self.llm_engine.model_config.is_multimodal_model or not self.model_config.is_multimodal_model
or (lora_config and lora_config.default_mm_loras is None) or (lora_config and lora_config.default_mm_loras is None)
): ):
return lora_request return lora_request
@ -495,15 +476,13 @@ class LLM:
if ( if (
not default_mm_loras not default_mm_loras
or not isinstance(prompt, dict) or not isinstance(prompt, dict)
or "multi_modal_data" not in prompt or not (mm_data := prompt.get("multi_modal_data") or {})
): ):
return lora_request return lora_request
prompt = cast(Union[TextPrompt, TokensPrompt], prompt) intersection = set(
mm_data.keys() # type: ignore
intersection = set(prompt["multi_modal_data"].keys()).intersection( ).intersection(default_mm_loras.keys())
default_mm_loras.keys()
)
if not intersection: if not intersection:
return lora_request return lora_request
if len(intersection) > 1: if len(intersection) > 1:
@ -819,7 +798,7 @@ class LLM:
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config() model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tools, tools,
@ -1031,7 +1010,7 @@ class LLM:
pooling_task, pooling_task,
) )
model_config = self.llm_engine.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
if runner_type != "pooling": if runner_type != "pooling":
raise ValueError( raise ValueError(
@ -1276,7 +1255,7 @@ class LLM:
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
) -> list[ScoringRequestOutput]: ) -> list[ScoringRequestOutput]:
model_config = self.llm_engine.model_config model_config = self.model_config
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
raise ValueError("Score API is not supported for Mistral tokenizer") raise ValueError("Score API is not supported for Mistral tokenizer")
@ -1287,7 +1266,6 @@ class LLM:
if pooling_params is None: if pooling_params is None:
pooling_params = PoolingParams(task="score") pooling_params = PoolingParams(task="score")
model_config = self.llm_engine.model_config
pooling_params.verify("score", model_config) pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]() pooling_params_list = list[PoolingParams]()
@ -1301,8 +1279,6 @@ class LLM:
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
model_config = self.llm_engine.model_config
for q, d in input_pairs: for q, d in input_pairs:
_, engine_prompt = get_score_prompt( _, engine_prompt = get_score_prompt(
model_config=model_config, model_config=model_config,
@ -1380,7 +1356,7 @@ class LLM:
A list of `ScoringRequestOutput` objects containing the A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts. generated scores in the same order as the input prompts.
""" """
model_config = self.llm_engine.model_config model_config = self.model_config
runner_type = model_config.runner_type runner_type = model_config.runner_type
if runner_type != "pooling": if runner_type != "pooling":
raise ValueError( raise ValueError(
@ -1658,8 +1634,7 @@ class LLM:
tokenization_kwargs, tokenization_kwargs,
) )
processor = self._get_processor() engine_request = self.processor.process_inputs(
engine_request = processor.process_inputs(
request_id, request_id,
engine_prompt, engine_prompt,
params, params,

View File

@ -1601,10 +1601,11 @@ def build_app(args: Namespace) -> FastAPI:
async def init_app_state( async def init_app_state(
engine_client: EngineClient, engine_client: EngineClient,
vllm_config: VllmConfig,
state: State, state: State,
args: Namespace, args: Namespace,
) -> None: ) -> None:
vllm_config = engine_client.vllm_config
if args.served_model_name is not None: if args.served_model_name is not None:
served_model_names = args.served_model_name served_model_names = args.served_model_name
else: else:
@ -1622,11 +1623,9 @@ async def init_app_state(
state.engine_client = engine_client state.engine_client = engine_client
state.log_stats = not args.disable_log_stats state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config state.vllm_config = vllm_config
model_config = vllm_config.model_config
supported_tasks = await engine_client.get_supported_tasks() supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
logger.info("Supported_tasks: %s", supported_tasks)
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
if resolved_chat_template is not None: if resolved_chat_template is not None:
@ -1688,7 +1687,6 @@ async def init_app_state(
state.openai_serving_models = OpenAIServingModels( state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=lora_modules, lora_modules=lora_modules,
) )
@ -1696,7 +1694,6 @@ async def init_app_state(
state.openai_serving_responses = ( state.openai_serving_responses = (
OpenAIServingResponses( OpenAIServingResponses(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
@ -1717,7 +1714,6 @@ async def init_app_state(
state.openai_serving_chat = ( state.openai_serving_chat = (
OpenAIServingChat( OpenAIServingChat(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
args.response_role, args.response_role,
request_logger=request_logger, request_logger=request_logger,
@ -1740,7 +1736,6 @@ async def init_app_state(
state.openai_serving_completion = ( state.openai_serving_completion = (
OpenAIServingCompletion( OpenAIServingCompletion(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
@ -1754,7 +1749,6 @@ async def init_app_state(
state.openai_serving_pooling = ( state.openai_serving_pooling = (
OpenAIServingPooling( OpenAIServingPooling(
engine_client, engine_client,
vllm_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
@ -1768,7 +1762,6 @@ async def init_app_state(
state.openai_serving_embedding = ( state.openai_serving_embedding = (
OpenAIServingEmbedding( OpenAIServingEmbedding(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
@ -1782,7 +1775,6 @@ async def init_app_state(
state.openai_serving_classification = ( state.openai_serving_classification = (
ServingClassification( ServingClassification(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
@ -1793,7 +1785,6 @@ async def init_app_state(
state.openai_serving_scores = ( state.openai_serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
@ -1803,7 +1794,6 @@ async def init_app_state(
) )
state.openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
@ -1814,7 +1804,6 @@ async def init_app_state(
state.openai_serving_transcription = ( state.openai_serving_transcription = (
OpenAIServingTranscription( OpenAIServingTranscription(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
@ -1825,7 +1814,6 @@ async def init_app_state(
state.openai_serving_translation = ( state.openai_serving_translation = (
OpenAIServingTranslation( OpenAIServingTranslation(
engine_client, engine_client,
model_config,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
@ -1946,12 +1934,11 @@ async def run_server_worker(
maybe_register_tokenizer_info_endpoint(args) maybe_register_tokenizer_info_endpoint(args)
app = build_app(args) app = build_app(args)
vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, app.state, args)
await init_app_state(engine_client, vllm_config, app.state, args)
logger.info( logger.info(
"Starting vLLM API server %d on %s", "Starting vLLM API server %d on %s",
vllm_config.parallel_config._api_process_rank, engine_client.vllm_config.parallel_config._api_process_rank,
listen_address, listen_address,
) )
shutdown_task = await serve_http( shutdown_task = await serve_http(

View File

@ -14,7 +14,6 @@ import torch
from prometheus_client import start_http_server from prometheus_client import start_http_server
from tqdm import tqdm from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
@ -328,7 +327,6 @@ async def run_request(
async def run_batch( async def run_batch(
engine_client: EngineClient, engine_client: EngineClient,
vllm_config: VllmConfig,
args: Namespace, args: Namespace,
) -> None: ) -> None:
if args.served_model_name is not None: if args.served_model_name is not None:
@ -345,22 +343,19 @@ async def run_batch(
BaseModelPath(name=name, model_path=args.model) for name in served_model_names BaseModelPath(name=name, model_path=args.model) for name in served_model_names
] ]
model_config = vllm_config.model_config model_config = engine_client.model_config
supported_tasks = await engine_client.get_supported_tasks() supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported_tasks: %s", supported_tasks) logger.info("Supported tasks: %s", supported_tasks)
# Create the openai serving objects. # Create the openai serving objects.
openai_serving_models = OpenAIServingModels( openai_serving_models = OpenAIServingModels(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths, base_model_paths=base_model_paths,
lora_modules=None, lora_modules=None,
) )
openai_serving_chat = ( openai_serving_chat = (
OpenAIServingChat( OpenAIServingChat(
engine_client, engine_client,
model_config,
openai_serving_models, openai_serving_models,
args.response_role, args.response_role,
request_logger=request_logger, request_logger=request_logger,
@ -374,7 +369,6 @@ async def run_batch(
openai_serving_embedding = ( openai_serving_embedding = (
OpenAIServingEmbedding( OpenAIServingEmbedding(
engine_client, engine_client,
model_config,
openai_serving_models, openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
chat_template=None, chat_template=None,
@ -392,7 +386,6 @@ async def run_batch(
openai_serving_scores = ( openai_serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
model_config,
openai_serving_models, openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
) )
@ -509,9 +502,7 @@ async def main(args: Namespace):
usage_context=UsageContext.OPENAI_BATCH_RUNNER, usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False, disable_frontend_multiprocessing=False,
) as engine_client: ) as engine_client:
vllm_config = await engine_client.get_vllm_config() await run_batch(engine_client, args)
await run_batch(engine_client, vllm_config, args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -15,7 +15,6 @@ from fastapi import Request
from openai_harmony import Message as OpenAIMessage from openai_harmony import Message as OpenAIMessage
from pydantic import TypeAdapter from pydantic import TypeAdapter
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
@ -81,7 +80,6 @@ class OpenAIServingChat(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
response_role: str, response_role: str,
*, *,
@ -101,7 +99,6 @@ class OpenAIServingChat(OpenAIServing):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -138,7 +135,7 @@ class OpenAIServingChat(OpenAIServing):
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools: if self.enable_auto_tools:
try: try:
if tool_parser == "pythonic" and model_config.model.startswith( if tool_parser == "pythonic" and self.model_config.model.startswith(
"meta-llama/Llama-3.2" "meta-llama/Llama-3.2"
): ):
logger.warning( logger.warning(
@ -169,7 +166,7 @@ class OpenAIServingChat(OpenAIServing):
else: else:
self.tool_call_id_type = "random" self.tool_call_id_type = "random"
self.use_harmony = model_config.hf_config.model_type == "gpt_oss" self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params: if "stop_token_ids" not in self.default_sampling_params:
self.default_sampling_params["stop_token_ids"] = [] self.default_sampling_params["stop_token_ids"] = []
@ -338,7 +335,7 @@ class OpenAIServingChat(OpenAIServing):
) )
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
request_id=request_id, request_id=request_id,
params=sampling_params, params=sampling_params,

View File

@ -8,7 +8,6 @@ import numpy as np
from fastapi import Request from fastapi import Request
from typing_extensions import override from typing_extensions import override
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
@ -128,7 +127,6 @@ class ServingClassification(ClassificationMixin):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -136,7 +134,6 @@ class ServingClassification(ClassificationMixin):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,

View File

@ -10,7 +10,6 @@ from typing import Optional, Union, cast
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
@ -44,7 +43,6 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -55,7 +53,6 @@ class OpenAIServingCompletion(OpenAIServing):
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -201,7 +198,7 @@ class OpenAIServingCompletion(OpenAIServing):
# but pre-commit in CI fails without it. # but pre-commit in CI fails without it.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], engine_prompt) engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], engine_prompt)
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
request_id=request_id, request_id=request_id,
params=sampling_params, params=sampling_params,

View File

@ -10,7 +10,6 @@ import torch
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never, override from typing_extensions import assert_never, override
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
@ -597,7 +596,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -608,7 +606,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json import json
import sys import sys
import time import time
@ -15,17 +16,13 @@ from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import TypedDict from typing import TypedDict
else: else:
from typing_extensions import TypedDict from typing_extensions import TypedDict
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
@ -68,9 +65,14 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import PromptComponents, get_prompt_components from vllm.inputs.parse import (
PromptComponents,
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -78,7 +80,7 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
MultiModalDataDict, MultiModalDataDict,
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tracing import ( from vllm.tracing import (
@ -89,11 +91,13 @@ from vllm.tracing import (
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import ( from vllm.utils import (
AsyncMicrobatchTokenizer, AsyncMicrobatchTokenizer,
collect_from_async_generator,
is_list_of, is_list_of,
make_async, make_async,
merge_async_iterators, merge_async_iterators,
random_uuid, random_uuid,
) )
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__) logger = init_logger(__name__)
@ -240,7 +244,6 @@ class OpenAIServing:
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -251,8 +254,6 @@ class OpenAIServing:
super().__init__() super().__init__()
self.engine_client = engine_client self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.models = models self.models = models
@ -268,12 +269,194 @@ class OpenAIServing:
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack self.log_error_stack = log_error_stack
async def _get_processor(self) -> Processor: self.processor = self.models.processor
if not hasattr(self, "_processor"): self.io_processor = self.models.io_processor
vllm_config = await self.engine_client.get_vllm_config() self.model_config = self.models.model_config
self._processor = Processor(vllm_config) self.max_model_len = self.model_config.max_model_len
return self._processor async def beam_search(
self,
prompt: PromptType,
request_id: str,
params: BeamSearchParams,
lora_request: Optional[LoRARequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
processor = self.processor
tokenizer = processor.tokenizer
if tokenizer is None:
raise ValueError(
"You cannot use beam search when `skip_tokenizer_init` is True"
)
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(
prompt
)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
# This is a workaround to fix multimodal beam search; this is a
# bandaid fix for 2 small problems:
# 1. Multi_modal_data on the processed_inputs currently resolves to
# `None`.
# 2. preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_text: Optional[str]
prompt_token_ids: list[int]
multi_modal_data: Optional[MultiModalDataDict]
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt") # type: ignore
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: Optional[dict[str, Any]] = processed_inputs.get(
"mm_processor_kwargs"
) # type: ignore
tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
beam_search_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(
tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
lora_request=lora_request,
)
]
completed = []
for _ in range(max_tokens):
prompts_batch, lora_req_batch = zip(
*[
(
EngineTokensPrompt(
prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs,
),
beam.lora_request,
)
for beam in all_beams
]
)
tasks = []
request_id_batch = f"{request_id}-{random_uuid()}"
for i, (individual_prompt, lora_req) in enumerate(
zip(prompts_batch, lora_req_batch)
):
request_id_item = f"{request_id_batch}-beam-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.engine_client.generate(
individual_prompt,
beam_search_params,
request_id_item,
lora_request=lora_req,
)
)
)
tasks.append(task)
output = [x[0] for x in await asyncio.gather(*tasks)]
new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
if token_id == eos_token_id and not ignore_eos:
completed.append(
BeamSearchSequence(
tokens=current_beam.tokens + [token_id]
if include_stop_str_in_output
else current_beam.tokens,
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob,
finish_reason="stop",
stop_reason=eos_token_id,
)
)
else:
new_beams.append(
BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs,
)
)
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
if beam.tokens[-1] == eos_token_id and not ignore_eos:
# Skip the eos token in the text.
tokens = beam.tokens[tokenized_length:-1]
else:
tokens = beam.tokens[tokenized_length:]
beam.text = tokenizer.decode(tokens)
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
text=beam.text, # type: ignore
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.logprobs,
finish_reason=beam.finish_reason
if beam.finish_reason is not None
else "length",
stop_reason=beam.stop_reason,
)
for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
""" """
@ -938,8 +1121,7 @@ class OpenAIServing:
self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
) )
processor = await self._get_processor() engine_request = self.processor.process_inputs(
engine_request = processor.process_inputs(
request_id, request_id,
engine_prompt, engine_prompt,
params, params,

View File

@ -7,7 +7,6 @@ from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Union from typing import Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ErrorInfo, ErrorInfo,
@ -51,18 +50,14 @@ class OpenAIServingModels:
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: list[BaseModelPath], base_model_paths: list[BaseModelPath],
*, *,
lora_modules: Optional[list[LoRAModulePath]] = None, lora_modules: Optional[list[LoRAModulePath]] = None,
): ):
super().__init__() super().__init__()
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.engine_client = engine_client self.engine_client = engine_client
self.model_config = model_config self.base_model_paths = base_model_paths
self.static_lora_modules = lora_modules self.static_lora_modules = lora_modules
self.lora_requests: dict[str, LoRARequest] = {} self.lora_requests: dict[str, LoRARequest] = {}
@ -75,6 +70,11 @@ class OpenAIServingModels:
) )
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.processor = self.engine_client.processor
self.io_processor = self.engine_client.io_processor
self.model_config = self.engine_client.model_config
self.max_model_len = self.model_config.max_model_len
async def init_static_loras(self): async def init_static_loras(self):
"""Loads all static LoRA modules. """Loads all static LoRA modules.
Raises if any fail to load""" Raises if any fail to load"""

View File

@ -13,7 +13,6 @@ import torch
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
@ -34,7 +33,6 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.utils import merge_async_iterators from vllm.utils import merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
@ -60,7 +58,6 @@ class OpenAIServingPooling(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
vllm_config: VllmConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -71,7 +68,6 @@ class OpenAIServingPooling(OpenAIServing):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=vllm_config.model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,
@ -80,8 +76,6 @@ class OpenAIServingPooling(OpenAIServing):
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
io_processor_plugin = self.model_config.io_processor_plugin
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
async def create_pooling( async def create_pooling(
self, self,

View File

@ -49,7 +49,6 @@ from openai.types.responses.response_reasoning_item import (
from openai_harmony import Message as OpenAIHarmonyMessage from openai_harmony import Message as OpenAIHarmonyMessage
from vllm import envs from vllm import envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
@ -109,7 +108,6 @@ class OpenAIServingResponses(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -127,7 +125,6 @@ class OpenAIServingResponses(OpenAIServing):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -176,7 +173,7 @@ class OpenAIServingResponses(OpenAIServing):
"the store." "the store."
) )
self.use_harmony = model_config.hf_config.model_type == "gpt_oss" self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
logger.warning( logger.warning(
"For gpt-oss, we ignore --enable-auto-tool-choice " "For gpt-oss, we ignore --enable-auto-tool-choice "

View File

@ -7,7 +7,6 @@ from typing import Any, Optional, Union
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
@ -47,7 +46,6 @@ class ServingScores(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -55,7 +53,6 @@ class ServingScores(OpenAIServing):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,

View File

@ -6,7 +6,6 @@ from typing import Any, Final, Optional, Union
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
@ -32,7 +31,6 @@ class OpenAIServingTokenization(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -43,7 +41,6 @@ class OpenAIServingTokenization(OpenAIServing):
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack, log_error_stack=log_error_stack,

View File

@ -5,7 +5,6 @@ from typing import Optional, Union
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
@ -34,7 +33,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -43,7 +41,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -95,7 +92,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -104,7 +100,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,

View File

@ -12,7 +12,6 @@ import numpy as np
from fastapi import Request from fastapi import Request
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
@ -53,7 +52,6 @@ class OpenAISpeechToText(OpenAIServing):
def __init__( def __init__(
self, self,
engine_client: EngineClient, engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
@ -63,7 +61,6 @@ class OpenAISpeechToText(OpenAIServing):
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
model_config=model_config,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -74,7 +71,7 @@ class OpenAISpeechToText(OpenAIServing):
self.task_type = task_type self.task_type = task_type
self.asr_config = self.model_cls.get_speech_to_text_config( self.asr_config = self.model_cls.get_speech_to_text_config(
model_config, task_type self.model_config, task_type
) )
self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB

View File

@ -20,13 +20,13 @@ class TextPrompt(TypedDict):
prompt: str prompt: str
"""The input text to be tokenized before passing to the model.""" """The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalDataDict"] multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[dict[str, Any]] mm_processor_kwargs: NotRequired[Optional[dict[str, Any]]]
""" """
Optional multi-modal processor kwargs to be forwarded to the Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities multimodal input mapper & processor. Note that if multiple modalities
@ -61,13 +61,13 @@ class TokensPrompt(TypedDict):
token_type_ids: NotRequired[list[int]] token_type_ids: NotRequired[list[int]]
"""A list of token type IDs to pass to the cross encoder model.""" """A list of token type IDs to pass to the cross encoder model."""
multi_modal_data: NotRequired["MultiModalDataDict"] multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
""" """
Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[dict[str, Any]] mm_processor_kwargs: NotRequired[Optional[dict[str, Any]]]
""" """
Optional multi-modal processor kwargs to be forwarded to the Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities multimodal input mapper & processor. Note that if multiple modalities

View File

@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.jsontree import json_iter_leaves from vllm.utils.jsontree import json_iter_leaves
from .data import ( from .data import (
@ -45,20 +45,17 @@ class InputPreprocessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config
self.tokenizer = tokenizer
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
if model_config.skip_tokenizer_init:
self.tokenizer = None
else:
self.tokenizer = init_tokenizer_from_configs(model_config)
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
@ -351,8 +348,8 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
inputs = self._process_multimodal( inputs = self._process_multimodal(
prompt_token_ids, prompt_token_ids,
parsed_content.get("multi_modal_data", {}), parsed_content.get("multi_modal_data") or {},
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
@ -380,8 +377,8 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
inputs = self._process_multimodal( inputs = self._process_multimodal(
prompt_text, prompt_text,
parsed_content.get("multi_modal_data", {}), parsed_content.get("multi_modal_data") or {},
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs") or {},
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )

View File

@ -12,23 +12,23 @@ import numpy as np
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
@ -104,8 +104,16 @@ class AsyncLLM(EngineClient):
"logger list; enabling logging without default stat loggers" "logger list; enabling logging without default stat loggers"
) )
# Processor (converts Inputs --> EngineCoreRequests). if self.model_config.skip_tokenizer_init:
self.processor = Processor(vllm_config, mm_registry=mm_registry) tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self.processor = Processor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput). # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
@ -245,10 +253,6 @@ class AsyncLLM(EngineClient):
cancel_task_threadsafe(getattr(self, "output_handler", None)) cancel_task_threadsafe(getattr(self, "output_handler", None))
@property
def tokenizer(self) -> Optional[AnyTokenizer]:
return self.processor.tokenizer
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async() return await self.engine_core.get_supported_tasks_async()
@ -615,14 +619,13 @@ class AsyncLLM(EngineClient):
logger.info("Request %s failed.", request_id) logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e raise EngineGenerateError() from e
async def get_vllm_config(self) -> VllmConfig: @property
return self.vllm_config def tokenizer(self) -> Optional[AnyTokenizer]:
return self.processor.tokenizer
async def get_model_config(self) -> ModelConfig: @tokenizer.setter
return self.model_config def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
self.processor.tokenizer = tokenizer
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor
async def get_tokenizer(self) -> AnyTokenizer: async def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:

View File

@ -19,11 +19,12 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
@ -95,8 +96,16 @@ class LLMEngine:
self.dp_group = None self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
# Processor (convert Inputs --> EngineCoreRequests) if self.model_config.skip_tokenizer_init:
self.processor = Processor(vllm_config, mm_registry=mm_registry) tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self.processor = Processor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput). # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
@ -204,14 +213,6 @@ class LLMEngine:
def validate_outputs(cls, outputs, output_type): def validate_outputs(cls, outputs, output_type):
return outputs return outputs
@property
def tokenizer(self) -> Optional[AnyTokenizer]:
return self.processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
self.processor.tokenizer = tokenizer
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks() return self.engine_core.get_supported_tasks()
@ -313,12 +314,6 @@ class LLMEngine:
return processed_outputs.request_outputs return processed_outputs.request_outputs
def get_vllm_config(self):
return self.vllm_config
def get_model_config(self):
return self.model_config
def start_profile(self): def start_profile(self):
self.engine_core.profile(True) self.engine_core.profile(True)
@ -345,6 +340,14 @@ class LLMEngine:
assert self.log_stats, "Stat logging disabled" assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot() return get_metrics_snapshot()
@property
def tokenizer(self) -> Optional[AnyTokenizer]:
return self.processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
self.processor.tokenizer = tokenizer
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(

View File

@ -37,6 +37,7 @@ class Processor:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None: ) -> None:
self.vllm_config = vllm_config self.vllm_config = vllm_config
@ -52,6 +53,7 @@ class Processor:
self.input_preprocessor = InputPreprocessor( self.input_preprocessor = InputPreprocessor(
self.model_config, self.model_config,
tokenizer,
mm_registry, mm_registry,
mm_processor_cache=self.mm_processor_cache, mm_processor_cache=self.mm_processor_cache,
) )