# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

import torch

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
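
# Dummy EOS token id attached to every synthetic request created below.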
EOS_TOKEN_ID = 50256


def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
    '''Create scheduler under test.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default
                             (None)

    Returns:
      {class}`Scheduler` instance
    '''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
async_scheduling=async_scheduling,
)
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype="float16",
seed=42,
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
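    # Optional KV connector config; only constructed when use_kv_connector
    # is set, so tests can exercise the KV-transfer code paths.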
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
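    # Optional ngram-based speculative decoding config.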
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
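    # A single full-attention KV cache group covering one dummy layer.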
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
return scheduler_cls(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
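

# Illustrative usage sketch (the argument values below are arbitrary): the
# helper returns a plain Scheduler unless async_scheduling=True is passed,
# in which case an AsyncScheduler is returned instead.
#
#     scheduler = create_scheduler(
#         model="facebook/opt-125m",
#         max_num_seqs=4,
#         max_num_batched_tokens=1024,
#         enable_prefix_caching=True,
#     )
#     assert isinstance(scheduler, Scheduler)
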
_none_hash_initialized = False


def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None,
same_prompt: bool = False,
block_size: int = 16,
) -> list[Request]:
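    # init_none_hash sets up module-level hashing state in kv_cache_utils,
    # so make sure it only runs once per process.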
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
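        # Attach dummy multimodal items and hashes when placeholder positions
        # are provided for this request.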
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
mm_hashes = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
request_id=f"{i}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=mm_hashes,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
requests.append(request)
return requests
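

# Illustrative usage sketch combining both helpers (counts are arbitrary and
# this assumes the v1 Scheduler.add_request/schedule interface exercised by
# the scheduler tests):
#
#     scheduler = create_scheduler(max_num_seqs=8)
#     for request in create_requests(num_requests=4, num_tokens=32):
#         scheduler.add_request(request)
#     scheduler_output = scheduler.schedule()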