mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 04:04:57 +08:00
[Bugfix] Fix SHM cache initialization (#26427)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
dc7976dd9f
commit
4bdf7ac593
@ -113,15 +113,17 @@ def mock_serving_setup():
|
||||
mock_engine.generate.reset_mock()
|
||||
mock_engine.add_lora.reset_mock()
|
||||
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
models = OpenAIServingModels(
|
||||
engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config,
|
||||
)
|
||||
|
||||
serving_completion = OpenAIServingCompletion(
|
||||
mock_engine, mock_model_config, models, request_logger=None
|
||||
mock_engine, models, request_logger=None
|
||||
)
|
||||
|
||||
serving_completion._process_inputs = AsyncMock(
|
||||
|
||||
@ -245,17 +245,13 @@ class MockModelConfig:
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
def _build_serving_chat(
|
||||
engine: AsyncLLM, model_config: MockModelConfig
|
||||
) -> OpenAIServingChat:
|
||||
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
models = OpenAIServingModels(
|
||||
engine_client=engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=model_config,
|
||||
)
|
||||
serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
@ -280,18 +276,17 @@ def _build_serving_chat(
|
||||
|
||||
@dataclass
|
||||
class MockEngine:
|
||||
async def get_model_config(self):
|
||||
return MockModelConfig()
|
||||
model_config: MockModelConfig = field(default_factory=MockModelConfig)
|
||||
processor: MagicMock = field(default_factory=MagicMock)
|
||||
io_processor: MagicMock = field(default_factory=MagicMock)
|
||||
|
||||
|
||||
async def _async_serving_chat_init():
|
||||
engine = MockEngine()
|
||||
model_config = await engine.get_model_config()
|
||||
|
||||
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
|
||||
models = OpenAIServingModels(engine, BASE_MODEL_PATHS)
|
||||
serving_completion = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
@ -311,8 +306,11 @@ async def test_serving_chat_returns_correct_model_name():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
messages = [{"role": "user", "content": "what is 1+1?"}]
|
||||
|
||||
async def return_model_name(*args):
|
||||
@ -338,8 +336,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
@ -368,9 +369,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test Case 1: No max_tokens specified in request
|
||||
req = ChatCompletionRequest(
|
||||
@ -410,9 +414,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test case 1: No max_tokens specified, defaults to context_window
|
||||
req = ChatCompletionRequest(
|
||||
@ -453,9 +460,12 @@ async def test_serving_chat_could_load_correct_generation_config():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
@ -496,8 +506,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test cache_salt
|
||||
req = ChatCompletionRequest(
|
||||
|
||||
@ -22,10 +22,12 @@ def serving() -> OpenAIServing:
|
||||
model_config = Mock(spec=ModelConfig)
|
||||
model_config.max_model_len = 32768
|
||||
models = Mock(spec=OpenAIServingModels)
|
||||
models.model_config = model_config
|
||||
models.processor = Mock()
|
||||
models.io_processor = Mock()
|
||||
|
||||
serving = OpenAIServing(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=None,
|
||||
)
|
||||
|
||||
@ -25,15 +25,17 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
|
||||
|
||||
|
||||
async def _async_serving_models_init() -> OpenAIServingModels:
|
||||
mock_model_config = MagicMock(spec=ModelConfig)
|
||||
mock_engine_client = MagicMock(spec=EngineClient)
|
||||
# Set the max_model_len attribute to avoid missing attribute
|
||||
mock_model_config = MagicMock(spec=ModelConfig)
|
||||
mock_model_config.max_model_len = 2048
|
||||
mock_engine_client.model_config = mock_model_config
|
||||
mock_engine_client.processor = MagicMock()
|
||||
mock_engine_client.io_processor = MagicMock()
|
||||
|
||||
serving_models = OpenAIServingModels(
|
||||
engine_client=mock_engine_client,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config,
|
||||
lora_modules=None,
|
||||
)
|
||||
await serving_models.init_static_loras()
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import AsyncExitStack
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@ -70,11 +70,14 @@ class TestInitializeToolSessions:
|
||||
"""Create a real OpenAIServingResponses instance for testing"""
|
||||
# Create minimal mocks for required dependencies
|
||||
engine_client = MagicMock()
|
||||
engine_client.get_model_config = AsyncMock()
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.hf_config.model_type = "test"
|
||||
model_config.get_diff_sampling_param.return_value = {}
|
||||
engine_client.model_config = model_config
|
||||
|
||||
engine_client.processor = MagicMock()
|
||||
engine_client.io_processor = MagicMock()
|
||||
|
||||
models = MagicMock()
|
||||
|
||||
@ -83,7 +86,6 @@ class TestInitializeToolSessions:
|
||||
# Create the actual instance
|
||||
instance = OpenAIServingResponses(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=None,
|
||||
chat_template=None,
|
||||
@ -132,18 +134,20 @@ class TestValidateGeneratorInput:
|
||||
"""Create a real OpenAIServingResponses instance for testing"""
|
||||
# Create minimal mocks for required dependencies
|
||||
engine_client = MagicMock()
|
||||
engine_client.get_model_config = AsyncMock()
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.hf_config.model_type = "test"
|
||||
model_config.get_diff_sampling_param.return_value = {}
|
||||
engine_client.model_config = model_config
|
||||
|
||||
engine_client.processor = MagicMock()
|
||||
engine_client.io_processor = MagicMock()
|
||||
|
||||
models = MagicMock()
|
||||
|
||||
# Create the actual instance
|
||||
instance = OpenAIServingResponses(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=None,
|
||||
chat_template=None,
|
||||
|
||||
@ -7,6 +7,7 @@ from vllm.config import ModelConfig
|
||||
from vllm.inputs import zip_enc_dec_prompts
|
||||
from vllm.inputs.parse import parse_raw_prompts
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@ -106,7 +107,8 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
|
||||
)
|
||||
def test_preprocessor_text_no_mm_inputs(model_id, prompt):
|
||||
model_config = ModelConfig(model=model_id)
|
||||
input_preprocessor = InputPreprocessor(model_config)
|
||||
tokenizer = init_tokenizer_from_configs(model_config)
|
||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
||||
|
||||
with pytest.raises(ValueError, match="does not support multimodal inputs"):
|
||||
input_preprocessor.preprocess(prompt)
|
||||
@ -127,8 +129,8 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt):
|
||||
)
|
||||
def test_preprocessor_always_mm_code_path(model_id, prompt):
|
||||
model_config = ModelConfig(model=model_id)
|
||||
input_preprocessor = InputPreprocessor(model_config)
|
||||
tokenizer = input_preprocessor.tokenizer
|
||||
tokenizer = init_tokenizer_from_configs(model_config)
|
||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
||||
|
||||
# HF processor adds sep token
|
||||
sep_token_id = tokenizer.vocab[tokenizer.sep_token]
|
||||
|
||||
@ -65,7 +65,7 @@ def _mk_processor(
|
||||
device_config=DeviceConfig(device="cpu"),
|
||||
)
|
||||
|
||||
return Processor(vllm_config)
|
||||
return Processor(vllm_config, tokenizer=None)
|
||||
|
||||
|
||||
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
|
||||
|
||||
@ -459,7 +459,7 @@ def test_all_logprobs(example_prompts):
|
||||
results_logprobs_all = runner.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all
|
||||
)
|
||||
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
|
||||
vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()
|
||||
|
||||
for i in range(len(results_logprobs_all)):
|
||||
logprobs = results_logprobs_all[i].outputs[0].logprobs
|
||||
|
||||
@ -186,7 +186,7 @@ async def run_vllm_async(
|
||||
engine_args,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
) as llm:
|
||||
model_config = await llm.get_model_config()
|
||||
model_config = llm.model_config
|
||||
assert all(
|
||||
model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
|
||||
@ -1,26 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import Device, collect_from_async_generator, random_uuid
|
||||
from vllm.utils import Device
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -28,6 +25,11 @@ logger = init_logger(__name__)
|
||||
class EngineClient(ABC):
|
||||
"""Protocol class for Clients to Engine"""
|
||||
|
||||
vllm_config: VllmConfig
|
||||
model_config: ModelConfig
|
||||
processor: Processor
|
||||
io_processor: Optional[IOProcessor]
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_running(self) -> bool: ...
|
||||
@ -61,180 +63,6 @@ class EngineClient(ABC):
|
||||
"""Generate outputs for a request."""
|
||||
...
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
beam_width = params.beam_width
|
||||
max_tokens = params.max_tokens
|
||||
ignore_eos = params.ignore_eos
|
||||
temperature = params.temperature
|
||||
length_penalty = params.length_penalty
|
||||
include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
preprocessor = await self.get_input_preprocessor()
|
||||
tokenizer = preprocessor.get_tokenizer()
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
|
||||
|
||||
if processed_inputs["type"] == "embeds":
|
||||
raise NotImplementedError
|
||||
|
||||
# This is a workaround to fix multimodal beam search; this is a
|
||||
# bandaid fix for 2 small problems:
|
||||
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
||||
# `None`.
|
||||
# 2. preprocessing above expands the multimodal placeholders. However,
|
||||
# this happens again in generation, so the double expansion causes
|
||||
# a mismatch.
|
||||
# TODO - would be ideal to handle this more gracefully.
|
||||
if isinstance(prompt, str):
|
||||
prompt_text = prompt
|
||||
prompt_token_ids = []
|
||||
multi_modal_data = None
|
||||
else:
|
||||
prompt_text = prompt.get("prompt")
|
||||
prompt_token_ids = prompt.get("prompt_token_ids", [])
|
||||
multi_modal_data = prompt.get("multi_modal_data")
|
||||
|
||||
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
beam_search_params = SamplingParams(
|
||||
logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
)
|
||||
all_beams = [
|
||||
BeamSearchSequence(
|
||||
tokens=prompt_token_ids,
|
||||
cum_logprob=0,
|
||||
logprobs=[],
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
prompts_batch, lora_req_batch = zip(
|
||||
*[
|
||||
(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=beam.tokens,
|
||||
multi_modal_data=beam.multi_modal_data,
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs,
|
||||
),
|
||||
beam.lora_request,
|
||||
)
|
||||
for beam in all_beams
|
||||
]
|
||||
)
|
||||
|
||||
tasks = []
|
||||
|
||||
request_id = f"beam_search-{random_uuid()}"
|
||||
for i, (individual_prompt, lora_req) in enumerate(
|
||||
zip(prompts_batch, lora_req_batch)
|
||||
):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.generate(
|
||||
individual_prompt,
|
||||
beam_search_params,
|
||||
request_id_item,
|
||||
lora_request=lora_req,
|
||||
)
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
output = await asyncio.gather(*tasks)
|
||||
|
||||
output = [x[0] for x in output]
|
||||
|
||||
new_beams = []
|
||||
for i, current_beam in enumerate(all_beams):
|
||||
result = output[i]
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
if token_id == eos_token_id and not ignore_eos:
|
||||
completed.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id]
|
||||
if include_stop_str_in_output
|
||||
else current_beam.tokens,
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
finish_reason="stop",
|
||||
stop_reason=eos_token_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_beams.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
||||
all_beams = sorted_beams[:beam_width]
|
||||
|
||||
completed.extend(all_beams)
|
||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
if beam.tokens[-1] == eos_token_id and not ignore_eos:
|
||||
# Skip the eos token in the text.
|
||||
tokens = beam.tokens[tokenized_length:-1]
|
||||
else:
|
||||
tokens = beam.tokens[tokenized_length:]
|
||||
beam.text = tokenizer.decode(tokens)
|
||||
|
||||
yield RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt_text,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
text=beam.text,
|
||||
cumulative_logprob=beam.cum_logprob,
|
||||
token_ids=beam.tokens[tokenized_length:],
|
||||
index=i,
|
||||
logprobs=beam.logprobs,
|
||||
finish_reason=beam.finish_reason
|
||||
if beam.finish_reason is not None
|
||||
else "length",
|
||||
stop_reason=beam.stop_reason,
|
||||
)
|
||||
for (i, beam) in enumerate(best_beams)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def encode(
|
||||
self,
|
||||
@ -259,29 +87,11 @@ class EngineClient(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_vllm_config(self) -> VllmConfig:
|
||||
"""Get the vllm configuration of the vLLM engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_input_preprocessor(self) -> InputPreprocessor:
|
||||
"""Get the input processor of the vLLM engine."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_tokenizer(self) -> AnyTokenizer:
|
||||
"""Get the tokenizer"""
|
||||
...
|
||||
|
||||
async def get_io_processor(self) -> IOProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_tracing_enabled(self) -> bool: ...
|
||||
|
||||
|
||||
@ -66,7 +66,6 @@ from vllm.outputs import (
|
||||
RequestOutput,
|
||||
ScoringRequestOutput,
|
||||
)
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
@ -79,7 +78,6 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, Device, as_iter, is_list_of
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -335,21 +333,13 @@ class LLM:
|
||||
self.request_counter = Counter()
|
||||
self.default_sampling_params: Union[dict[str, Any], None] = None
|
||||
|
||||
supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore
|
||||
|
||||
logger.info("Supported_tasks: %s", supported_tasks)
|
||||
|
||||
supported_tasks = self.llm_engine.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
self.supported_tasks = supported_tasks
|
||||
|
||||
# Load the Input/Output processor plugin if any
|
||||
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(
|
||||
self.llm_engine.vllm_config, io_processor_plugin
|
||||
)
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
return self.llm_engine.model_config
|
||||
self.model_config = self.llm_engine.model_config
|
||||
self.processor = self.llm_engine.processor
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
|
||||
def get_tokenizer(self) -> AnyTokenizer:
|
||||
return self.llm_engine.get_tokenizer()
|
||||
@ -364,18 +354,9 @@ class LLM:
|
||||
else:
|
||||
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
|
||||
|
||||
def _get_processor(self) -> Processor:
|
||||
if not hasattr(self, "_processor"):
|
||||
vllm_config = self.llm_engine.vllm_config
|
||||
self._processor = Processor(vllm_config)
|
||||
|
||||
return self._processor
|
||||
|
||||
def get_default_sampling_params(self) -> SamplingParams:
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = (
|
||||
self.llm_engine.model_config.get_diff_sampling_param()
|
||||
)
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
if self.default_sampling_params:
|
||||
return SamplingParams.from_optional(**self.default_sampling_params)
|
||||
return SamplingParams()
|
||||
@ -423,7 +404,7 @@ class LLM:
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the `inputs` parameter.
|
||||
"""
|
||||
model_config = self.llm_engine.model_config
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "generate":
|
||||
raise ValueError(
|
||||
@ -463,7 +444,7 @@ class LLM:
|
||||
# isn't multimodal, leave the lora as is.
|
||||
if (
|
||||
lora_config is None
|
||||
or not self.llm_engine.model_config.is_multimodal_model
|
||||
or not self.model_config.is_multimodal_model
|
||||
or (lora_config and lora_config.default_mm_loras is None)
|
||||
):
|
||||
return lora_request
|
||||
@ -495,15 +476,13 @@ class LLM:
|
||||
if (
|
||||
not default_mm_loras
|
||||
or not isinstance(prompt, dict)
|
||||
or "multi_modal_data" not in prompt
|
||||
or not (mm_data := prompt.get("multi_modal_data") or {})
|
||||
):
|
||||
return lora_request
|
||||
|
||||
prompt = cast(Union[TextPrompt, TokensPrompt], prompt)
|
||||
|
||||
intersection = set(prompt["multi_modal_data"].keys()).intersection(
|
||||
default_mm_loras.keys()
|
||||
)
|
||||
intersection = set(
|
||||
mm_data.keys() # type: ignore
|
||||
).intersection(default_mm_loras.keys())
|
||||
if not intersection:
|
||||
return lora_request
|
||||
if len(intersection) > 1:
|
||||
@ -819,7 +798,7 @@ class LLM:
|
||||
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
model_config = self.llm_engine.get_model_config()
|
||||
model_config = self.model_config
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tools,
|
||||
@ -1031,7 +1010,7 @@ class LLM:
|
||||
pooling_task,
|
||||
)
|
||||
|
||||
model_config = self.llm_engine.model_config
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
raise ValueError(
|
||||
@ -1276,7 +1255,7 @@ class LLM:
|
||||
pooling_params: Optional[PoolingParams] = None,
|
||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
model_config = self.llm_engine.model_config
|
||||
model_config = self.model_config
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError("Score API is not supported for Mistral tokenizer")
|
||||
@ -1287,7 +1266,6 @@ class LLM:
|
||||
if pooling_params is None:
|
||||
pooling_params = PoolingParams(task="score")
|
||||
|
||||
model_config = self.llm_engine.model_config
|
||||
pooling_params.verify("score", model_config)
|
||||
pooling_params_list = list[PoolingParams]()
|
||||
|
||||
@ -1301,8 +1279,6 @@ class LLM:
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
model_config = self.llm_engine.model_config
|
||||
|
||||
for q, d in input_pairs:
|
||||
_, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
@ -1380,7 +1356,7 @@ class LLM:
|
||||
A list of `ScoringRequestOutput` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
model_config = self.llm_engine.model_config
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
raise ValueError(
|
||||
@ -1658,8 +1634,7 @@ class LLM:
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
processor = self._get_processor()
|
||||
engine_request = processor.process_inputs(
|
||||
engine_request = self.processor.process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
params,
|
||||
|
||||
@ -1601,10 +1601,11 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
async def init_app_state(
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
state: State,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
vllm_config = engine_client.vllm_config
|
||||
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
@ -1622,11 +1623,9 @@ async def init_app_state(
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
state.vllm_config = vllm_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
|
||||
logger.info("Supported_tasks: %s", supported_tasks)
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
if resolved_chat_template is not None:
|
||||
@ -1688,7 +1687,6 @@ async def init_app_state(
|
||||
|
||||
state.openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
)
|
||||
@ -1696,7 +1694,6 @@ async def init_app_state(
|
||||
state.openai_serving_responses = (
|
||||
OpenAIServingResponses(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
@ -1717,7 +1714,6 @@ async def init_app_state(
|
||||
state.openai_serving_chat = (
|
||||
OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
@ -1740,7 +1736,6 @@ async def init_app_state(
|
||||
state.openai_serving_completion = (
|
||||
OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
@ -1754,7 +1749,6 @@ async def init_app_state(
|
||||
state.openai_serving_pooling = (
|
||||
OpenAIServingPooling(
|
||||
engine_client,
|
||||
vllm_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
@ -1768,7 +1762,6 @@ async def init_app_state(
|
||||
state.openai_serving_embedding = (
|
||||
OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
@ -1782,7 +1775,6 @@ async def init_app_state(
|
||||
state.openai_serving_classification = (
|
||||
ServingClassification(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
@ -1793,7 +1785,6 @@ async def init_app_state(
|
||||
state.openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
@ -1803,7 +1794,6 @@ async def init_app_state(
|
||||
)
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
@ -1814,7 +1804,6 @@ async def init_app_state(
|
||||
state.openai_serving_transcription = (
|
||||
OpenAIServingTranscription(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
@ -1825,7 +1814,6 @@ async def init_app_state(
|
||||
state.openai_serving_translation = (
|
||||
OpenAIServingTranslation(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
@ -1946,12 +1934,11 @@ async def run_server_worker(
|
||||
maybe_register_tokenizer_info_endpoint(args)
|
||||
app = build_app(args)
|
||||
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
await init_app_state(engine_client, vllm_config, app.state, args)
|
||||
await init_app_state(engine_client, app.state, args)
|
||||
|
||||
logger.info(
|
||||
"Starting vLLM API server %d on %s",
|
||||
vllm_config.parallel_config._api_process_rank,
|
||||
engine_client.vllm_config.parallel_config._api_process_rank,
|
||||
listen_address,
|
||||
)
|
||||
shutdown_task = await serve_http(
|
||||
|
||||
@ -14,7 +14,6 @@ import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -328,7 +327,6 @@ async def run_request(
|
||||
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
@ -345,22 +343,19 @@ async def run_batch(
|
||||
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
|
||||
]
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
model_config = engine_client.model_config
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported_tasks: %s", supported_tasks)
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
)
|
||||
openai_serving_chat = (
|
||||
OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
@ -374,7 +369,6 @@ async def run_batch(
|
||||
openai_serving_embedding = (
|
||||
OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
@ -392,7 +386,6 @@ async def run_batch(
|
||||
openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
@ -509,9 +502,7 @@ async def main(args: Namespace):
|
||||
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
|
||||
disable_frontend_multiprocessing=False,
|
||||
) as engine_client:
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
|
||||
await run_batch(engine_client, vllm_config, args)
|
||||
await run_batch(engine_client, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -15,7 +15,6 @@ from fastapi import Request
|
||||
from openai_harmony import Message as OpenAIMessage
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
@ -81,7 +80,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
response_role: str,
|
||||
*,
|
||||
@ -101,7 +99,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
@ -138,7 +135,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||
if self.enable_auto_tools:
|
||||
try:
|
||||
if tool_parser == "pythonic" and model_config.model.startswith(
|
||||
if tool_parser == "pythonic" and self.model_config.model.startswith(
|
||||
"meta-llama/Llama-3.2"
|
||||
):
|
||||
logger.warning(
|
||||
@ -169,7 +166,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
else:
|
||||
self.tool_call_id_type = "random"
|
||||
|
||||
self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
|
||||
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
|
||||
if self.use_harmony:
|
||||
if "stop_token_ids" not in self.default_sampling_params:
|
||||
self.default_sampling_params["stop_token_ids"] = []
|
||||
@ -338,7 +335,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
generator = self.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
|
||||
@ -8,7 +8,6 @@ import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@ -128,7 +127,6 @@ class ServingClassification(ClassificationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -136,7 +134,6 @@ class ServingClassification(ClassificationMixin):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
|
||||
@ -10,7 +10,6 @@ from typing import Optional, Union, cast
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@ -44,7 +43,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -55,7 +53,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
@ -201,7 +198,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
generator = self.beam_search(
|
||||
prompt=engine_prompt,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
|
||||
@ -10,7 +10,6 @@ import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never, override
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -597,7 +596,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -608,7 +606,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
@ -15,17 +16,13 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
@ -68,9 +65,14 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import PromptComponents, get_prompt_components
|
||||
from vllm.inputs.parse import (
|
||||
PromptComponents,
|
||||
get_prompt_components,
|
||||
is_explicit_encoder_decoder_prompt,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -78,7 +80,7 @@ from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error
|
||||
MultiModalDataDict,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tracing import (
|
||||
@ -89,11 +91,13 @@ from vllm.tracing import (
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
collect_from_async_generator,
|
||||
is_list_of,
|
||||
make_async,
|
||||
merge_async_iterators,
|
||||
random_uuid,
|
||||
)
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -240,7 +244,6 @@ class OpenAIServing:
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -251,8 +254,6 @@ class OpenAIServing:
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.models = models
|
||||
|
||||
@ -268,12 +269,194 @@ class OpenAIServing:
|
||||
self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
async def _get_processor(self) -> Processor:
|
||||
if not hasattr(self, "_processor"):
|
||||
vllm_config = await self.engine_client.get_vllm_config()
|
||||
self._processor = Processor(vllm_config)
|
||||
self.processor = self.models.processor
|
||||
self.io_processor = self.models.io_processor
|
||||
self.model_config = self.models.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
|
||||
return self._processor
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
beam_width = params.beam_width
|
||||
max_tokens = params.max_tokens
|
||||
ignore_eos = params.ignore_eos
|
||||
temperature = params.temperature
|
||||
length_penalty = params.length_penalty
|
||||
include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
processor = self.processor
|
||||
tokenizer = processor.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"You cannot use beam search when `skip_tokenizer_init` is True"
|
||||
)
|
||||
|
||||
eos_token_id: int = tokenizer.eos_token_id # type: ignore
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs(
|
||||
prompt
|
||||
)
|
||||
|
||||
if processed_inputs["type"] == "embeds":
|
||||
raise NotImplementedError
|
||||
|
||||
# This is a workaround to fix multimodal beam search; this is a
|
||||
# bandaid fix for 2 small problems:
|
||||
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
||||
# `None`.
|
||||
# 2. preprocessing above expands the multimodal placeholders. However,
|
||||
# this happens again in generation, so the double expansion causes
|
||||
# a mismatch.
|
||||
# TODO - would be ideal to handle this more gracefully.
|
||||
prompt_text: Optional[str]
|
||||
prompt_token_ids: list[int]
|
||||
multi_modal_data: Optional[MultiModalDataDict]
|
||||
if isinstance(prompt, str):
|
||||
prompt_text = prompt
|
||||
prompt_token_ids = []
|
||||
multi_modal_data = None
|
||||
else:
|
||||
prompt_text = prompt.get("prompt") # type: ignore
|
||||
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
|
||||
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
|
||||
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = processed_inputs.get(
|
||||
"mm_processor_kwargs"
|
||||
) # type: ignore
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
beam_search_params = SamplingParams(
|
||||
logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
)
|
||||
all_beams = [
|
||||
BeamSearchSequence(
|
||||
tokens=prompt_token_ids,
|
||||
cum_logprob=0,
|
||||
logprobs=[],
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
prompts_batch, lora_req_batch = zip(
|
||||
*[
|
||||
(
|
||||
EngineTokensPrompt(
|
||||
prompt_token_ids=beam.tokens,
|
||||
multi_modal_data=beam.multi_modal_data,
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs,
|
||||
),
|
||||
beam.lora_request,
|
||||
)
|
||||
for beam in all_beams
|
||||
]
|
||||
)
|
||||
|
||||
tasks = []
|
||||
request_id_batch = f"{request_id}-{random_uuid()}"
|
||||
|
||||
for i, (individual_prompt, lora_req) in enumerate(
|
||||
zip(prompts_batch, lora_req_batch)
|
||||
):
|
||||
request_id_item = f"{request_id_batch}-beam-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.engine_client.generate(
|
||||
individual_prompt,
|
||||
beam_search_params,
|
||||
request_id_item,
|
||||
lora_request=lora_req,
|
||||
)
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
output = [x[0] for x in await asyncio.gather(*tasks)]
|
||||
|
||||
new_beams = []
|
||||
for i, current_beam in enumerate(all_beams):
|
||||
result = output[i]
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
if token_id == eos_token_id and not ignore_eos:
|
||||
completed.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id]
|
||||
if include_stop_str_in_output
|
||||
else current_beam.tokens,
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
finish_reason="stop",
|
||||
stop_reason=eos_token_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_beams.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
||||
all_beams = sorted_beams[:beam_width]
|
||||
|
||||
completed.extend(all_beams)
|
||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
if beam.tokens[-1] == eos_token_id and not ignore_eos:
|
||||
# Skip the eos token in the text.
|
||||
tokens = beam.tokens[tokenized_length:-1]
|
||||
else:
|
||||
tokens = beam.tokens[tokenized_length:]
|
||||
beam.text = tokenizer.decode(tokens)
|
||||
|
||||
yield RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt_text,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
text=beam.text, # type: ignore
|
||||
cumulative_logprob=beam.cum_logprob,
|
||||
token_ids=beam.tokens[tokenized_length:],
|
||||
index=i,
|
||||
logprobs=beam.logprobs,
|
||||
finish_reason=beam.finish_reason
|
||||
if beam.finish_reason is not None
|
||||
else "length",
|
||||
stop_reason=beam.stop_reason,
|
||||
)
|
||||
for (i, beam) in enumerate(best_beams)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
|
||||
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
||||
"""
|
||||
@ -938,8 +1121,7 @@ class OpenAIServing:
|
||||
self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
|
||||
)
|
||||
|
||||
processor = await self._get_processor()
|
||||
engine_request = processor.process_inputs(
|
||||
engine_request = self.processor.process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
params,
|
||||
|
||||
@ -7,7 +7,6 @@ from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorInfo,
|
||||
@ -51,18 +50,14 @@ class OpenAIServingModels:
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[list[LoRAModulePath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
@ -75,6 +70,11 @@ class OpenAIServingModels:
|
||||
)
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.processor = self.engine_client.processor
|
||||
self.io_processor = self.engine_client.io_processor
|
||||
self.model_config = self.engine_client.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
|
||||
async def init_static_loras(self):
|
||||
"""Loads all static LoRA modules.
|
||||
Raises if any fail to load"""
|
||||
|
||||
@ -13,7 +13,6 @@ import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -34,7 +33,6 @@ from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -60,7 +58,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -71,7 +68,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=vllm_config.model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
@ -80,8 +76,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
io_processor_plugin = self.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
|
||||
@ -49,7 +49,6 @@ from openai.types.responses.response_reasoning_item import (
|
||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
@ -109,7 +108,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -127,7 +125,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
@ -176,7 +173,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
"the store."
|
||||
)
|
||||
|
||||
self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
|
||||
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
|
||||
if self.use_harmony:
|
||||
logger.warning(
|
||||
"For gpt-oss, we ignore --enable-auto-tool-choice "
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Any, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
@ -47,7 +46,6 @@ class ServingScores(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -55,7 +53,6 @@ class ServingScores(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
|
||||
@ -6,7 +6,6 @@ from typing import Any, Final, Optional, Union
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
@ -32,7 +31,6 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -43,7 +41,6 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import Optional, Union

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
@ -34,7 +33,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
@ -43,7 +41,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
    ):
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -95,7 +92,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
@ -104,7 +100,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
    ):
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,

@ -12,7 +12,6 @@ import numpy as np
from fastapi import Request

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
@ -53,7 +52,6 @@ class OpenAISpeechToText(OpenAIServing):
    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
@ -63,7 +61,6 @@ class OpenAISpeechToText(OpenAIServing):
    ):
        super().__init__(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
@ -74,7 +71,7 @@ class OpenAISpeechToText(OpenAIServing):
        self.task_type = task_type

        self.asr_config = self.model_cls.get_speech_to_text_config(
            model_config, task_type
            self.model_config, task_type
        )

        self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB

@ -20,13 +20,13 @@ class TextPrompt(TypedDict):
    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired["MultiModalDataDict"]
    multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any]]
    mm_processor_kwargs: NotRequired[Optional[dict[str, Any]]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
@ -61,13 +61,13 @@ class TokensPrompt(TypedDict):
    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

    multi_modal_data: NotRequired["MultiModalDataDict"]
    multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any]]
    mm_processor_kwargs: NotRequired[Optional[dict[str, Any]]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities

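For context, a brief usage sketch (not part of the diff): with the two fields typed as NotRequired[Optional[...]], callers may either omit the multimodal keys or pass them explicitly as None, whereas previously only omission was type-correct.

from vllm.inputs import TextPrompt

text_only: TextPrompt = {"prompt": "What is in this image?"}
explicit_none: TextPrompt = {
    "prompt": "What is in this image?",
    "multi_modal_data": None,  # explicitly no multimodal data
    "mm_processor_kwargs": None,
}
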
@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
    MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.jsontree import json_iter_leaves

from .data import (
@ -45,20 +45,17 @@ class InputPreprocessor:
    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[AnyTokenizer],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        if model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            self.tokenizer = init_tokenizer_from_configs(model_config)

    def get_tokenizer(self) -> AnyTokenizer:
        if self.tokenizer is None:
            raise ValueError(
@ -351,8 +348,8 @@ class InputPreprocessor:
        if self.model_config.is_multimodal_model:
            inputs = self._process_multimodal(
                prompt_token_ids,
                parsed_content.get("multi_modal_data", {}),
                parsed_content.get("mm_processor_kwargs"),
                parsed_content.get("multi_modal_data") or {},
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
@ -380,8 +377,8 @@ class InputPreprocessor:
        if self.model_config.is_multimodal_model:
            inputs = self._process_multimodal(
                prompt_text,
                parsed_content.get("multi_modal_data", {}),
                parsed_content.get("mm_processor_kwargs"),
                parsed_content.get("multi_modal_data") or {},
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

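The switch from .get(key, {}) to .get(key) or {} matters exactly because the prompt fields above may now be present with an explicit None value: a default argument to dict.get only applies when the key is missing. A tiny standalone illustration:

parsed_content = {"prompt": "hi", "multi_modal_data": None}

print(parsed_content.get("multi_modal_data", {}))    # None  (default not used)
print(parsed_content.get("multi_modal_data") or {})  # {}    (falls back as intended)
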
@ -12,23 +12,23 @@ import numpy as np
import torch

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
@ -104,8 +104,16 @@ class AsyncLLM(EngineClient):
                "logger list; enabling logging without default stat loggers"
            )

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(vllm_config, mm_registry=mm_registry)
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_configs(self.model_config)

        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
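To make the new flow explicit, here is a small sketch of the tokenizer construction the hunk above performs before building the Processor; the build_tokenizer helper name is hypothetical, not from the diff, but the branch mirrors the code shown:

from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs


def build_tokenizer(model_config):
    # Same branch as above: no tokenizer at all when initialization is skipped.
    if model_config.skip_tokenizer_init:
        return None
    return init_tokenizer_from_configs(model_config)
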
@ -245,10 +253,6 @@ class AsyncLLM(EngineClient):

        cancel_task_threadsafe(getattr(self, "output_handler", None))

    @property
    def tokenizer(self) -> Optional[AnyTokenizer]:
        return self.processor.tokenizer

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return await self.engine_core.get_supported_tasks_async()

@ -615,14 +619,13 @@ class AsyncLLM(EngineClient):
            logger.info("Request %s failed.", request_id)
            raise EngineGenerateError() from e

    async def get_vllm_config(self) -> VllmConfig:
        return self.vllm_config
    @property
    def tokenizer(self) -> Optional[AnyTokenizer]:
        return self.processor.tokenizer

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_input_preprocessor(self) -> InputPreprocessor:
        return self.processor.input_preprocessor
    @tokenizer.setter
    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
        self.processor.tokenizer = tokenizer

    async def get_tokenizer(self) -> AnyTokenizer:
        if self.tokenizer is None:

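The net effect of the two AsyncLLM hunks above is that the config getter coroutines give way to a tokenizer property pair that simply delegates to the processor. An illustrative, stdlib-only sketch of that delegation pattern; the _Processor and _Engine names are placeholders, not the real vLLM classes:

class _Processor:
    def __init__(self) -> None:
        self.tokenizer = None


class _Engine:
    def __init__(self) -> None:
        self.processor = _Processor()

    @property
    def tokenizer(self):
        # Reads always reflect the processor's current tokenizer.
        return self.processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer) -> None:
        # Writes are forwarded, keeping the processor as the single source of truth.
        self.processor.tokenizer = tokenizer
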
@ -19,11 +19,12 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine import EngineCoreRequest
@ -95,8 +96,16 @@ class LLMEngine:
        self.dp_group = None
        self.should_execute_dummy_batch = False

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config, mm_registry=mm_registry)
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_configs(self.model_config)

        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
@ -204,14 +213,6 @@ class LLMEngine:
    def validate_outputs(cls, outputs, output_type):
        return outputs

    @property
    def tokenizer(self) -> Optional[AnyTokenizer]:
        return self.processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
        self.processor.tokenizer = tokenizer

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.engine_core.get_supported_tasks()

@ -313,12 +314,6 @@ class LLMEngine:

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        return self.vllm_config

    def get_model_config(self):
        return self.model_config

    def start_profile(self):
        self.engine_core.profile(True)

@ -345,6 +340,14 @@ class LLMEngine:
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    @property
    def tokenizer(self) -> Optional[AnyTokenizer]:
        return self.processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None:
        self.processor.tokenizer = tokenizer

    def get_tokenizer(self) -> AnyTokenizer:
        if self.tokenizer is None:
            raise ValueError(

@ -37,6 +37,7 @@ class Processor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: Optional[AnyTokenizer],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        self.vllm_config = vllm_config
@ -52,6 +53,7 @@ class Processor:

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            tokenizer,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )

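Taken together with the engine-side hunks, the wiring after this change looks roughly like the sketch below. The build_processor helper is hypothetical; the Processor and InputPreprocessor signatures assumed here are the ones shown in these hunks.

from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.v1.engine.processor import Processor


def build_processor(vllm_config):
    model_config = vllm_config.model_config
    # The engine creates the tokenizer once (or skips it) ...
    tokenizer = (
        None
        if model_config.skip_tokenizer_init
        else init_tokenizer_from_configs(model_config)
    )
    # ... and threads it through Processor into InputPreprocessor.
    return Processor(vllm_config, tokenizer)
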