vllm/tests/v1/core/utils.py
Harry Mellor 8fcaaf6a16
Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-12 09:51:31 -07:00

191 lines
6.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import (
CacheConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
# End-of-sequence token id attached to every request built by create_requests.
# NOTE(review): presumably an arbitrary fixed id for tests rather than the
# tokenizer's real EOS for the default model — confirm before relying on it.
EOS_TOKEN_ID: int = 50_256
def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: bool | None = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: int | None = None,
    num_speculative_tokens: int | None = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Scheduler | AsyncScheduler:
    """Create a scheduler under test.

    Args:
        model: model under test.
        max_num_seqs: max sequences to schedule.
        max_num_batched_tokens: max num tokens to batch.
        enable_prefix_caching: optionally force APC config
            (True/False) or use the ``CacheConfig`` default (None).
        long_prefill_token_threshold: forwarded to ``SchedulerConfig``.
        disable_chunked_mm_input: forwarded to ``SchedulerConfig``.
        use_kv_connector: if True, attach a ``KVTransferConfig`` using the
            ``SharedStorageConnector`` with role ``kv_both`` and
            ``shared_storage_path="local_storage"``.
        num_blocks: number of KV cache blocks given to the scheduler (also
            mirrored onto ``CacheConfig.num_gpu_blocks``).
        block_size: tokens per KV cache block.
        max_model_len: max model length; defaults to
            ``max_num_batched_tokens`` when None.
        num_speculative_tokens: if not None, build a ``SpeculativeConfig``
            with ``model="ngram"`` and this many speculative tokens.
        skip_tokenizer_init: forwarded to ``ModelConfig``.
        async_scheduling: if True, build an {class}`AsyncScheduler`
            instead of a {class}`Scheduler`.

    Returns:
        {class}`Scheduler` or {class}`AsyncScheduler` instance
    """
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        # Chunked prefill is always enabled for schedulers built here.
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Cache config, optionally force APC; passing nothing keeps the
    # CacheConfig default.
    kwargs_cache = (
        {}
        if enable_prefix_caching is None
        else {"enable_prefix_caching": enable_prefix_caching}
    )
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = (
        KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )
        if use_kv_connector
        else None
    )
    speculative_config: SpeculativeConfig | None = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens
        )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        # Caller-controlled; the default (10000) is meant to be large enough
        # to hold all requests in typical tests.
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        # A single full-attention group over one dummy layer named "layer".
        # NOTE(review): positional args presumably are
        # (block_size, num_kv_heads, head_size, dtype, use_mla) — confirm
        # against FullAttentionSpec.
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    # Keep the cache config in agreement with the KV cache config above.
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        block_size=block_size,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
# Module-level guard so that init_none_hash(sha256) is only invoked on the
# first call to create_requests.
_none_hash_initialized = False


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: list[list[PlaceholderRange]] | None = None,
    max_tokens: int = 16,
    stop_token_ids: list[int] | None = None,
    prompt_logprobs: int | None = None,
    same_prompt: bool = False,
    block_size: int = 16,
) -> list[Request]:
    """Build ``num_requests`` test requests that share one ``SamplingParams``.

    Args:
        num_requests: how many requests to create; request ids are the
            decimal strings "0", "1", ...
        num_tokens: prompt length of every request.
        mm_positions: optional per-request placeholder ranges. Entry ``i`` is
            used for request ``i`` (so it must have at least
            ``num_requests`` entries), and each range becomes one dummy
            image ``MultiModalFeatureSpec``.
        max_tokens: forwarded to ``SamplingParams``.
        stop_token_ids: forwarded to ``SamplingParams``.
        prompt_logprobs: forwarded to ``SamplingParams``.
        same_prompt: if True every prompt is all ``0`` tokens; otherwise
            request ``i`` gets a prompt made only of token ``i``.
        block_size: block size passed to the request block hasher.

    Returns:
        The created requests, ordered by index.
    """
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(sha256)
        _none_hash_initialized = True

    hasher = get_request_block_hasher(block_size, sha256)
    params = SamplingParams(
        ignore_eos=False,
        max_tokens=max_tokens,
        stop_token_ids=stop_token_ids,
        prompt_logprobs=prompt_logprobs,
    )

    def _features_for(req_idx: int) -> list[MultiModalFeatureSpec]:
        # One dummy image feature per placeholder range of this request.
        if mm_positions is None:
            return []
        return [
            MultiModalFeatureSpec(
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=placeholder,
                # Identifiers must be unique per item: the encoder cache
                # keys its entries by this hash.
                identifier=f"hash{req_idx}_{item_idx}",
                modality="image",
            )
            for item_idx, placeholder in enumerate(mm_positions[req_idx])
        ]

    result: list[Request] = []
    for req_idx in range(num_requests):
        features = _features_for(req_idx)
        token_value = 0 if same_prompt else req_idx
        result.append(
            Request(
                request_id=f"{req_idx}",
                prompt_token_ids=[token_value] * num_tokens,
                sampling_params=params,
                pooling_params=None,
                # An empty feature list is passed as None.
                mm_features=features or None,
                eos_token_id=EOS_TOKEN_ID,
                block_hasher=hasher,
            )
        )
    return result