# Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# (synced 2025-12-10 04:05:01 +08:00; 191 lines, 6.0 KiB, Python)
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
from vllm.config import (
|
|
CacheConfig,
|
|
KVTransferConfig,
|
|
ModelConfig,
|
|
SchedulerConfig,
|
|
SpeculativeConfig,
|
|
VllmConfig,
|
|
)
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalFeatureSpec,
|
|
MultiModalKwargsItem,
|
|
PlaceholderRange,
|
|
)
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.utils import sha256
|
|
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
|
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
|
from vllm.v1.core.sched.scheduler import Scheduler
|
|
from vllm.v1.kv_cache_interface import (
|
|
FullAttentionSpec,
|
|
KVCacheConfig,
|
|
KVCacheGroupSpec,
|
|
)
|
|
from vllm.v1.request import Request
|
|
from vllm.v1.structured_output import StructuredOutputManager
|
|
|
|
# End-of-sequence token id attached to every Request built by create_requests().
EOS_TOKEN_ID: int = 50_256
|
|
|
|
|
|
def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
    """Create a scheduler under test.

    Args:
        model: model name passed to ModelConfig.
        max_num_seqs: maximum number of sequences to schedule.
        max_num_batched_tokens: maximum number of tokens to batch; also used
            as ``max_model_len`` when that argument is None.
        enable_prefix_caching: force automatic prefix caching on (True) or
            off (False); None keeps the CacheConfig default.
        long_prefill_token_threshold: forwarded to SchedulerConfig.
        disable_chunked_mm_input: forwarded to SchedulerConfig.
        use_kv_connector: if True, attach a KVTransferConfig using
            "SharedStorageConnector" with role "kv_both" and shared storage
            path "local_storage"; otherwise no KV transfer config is set.
        num_blocks: number of KV-cache blocks in the KVCacheConfig handed to
            the scheduler (also copied onto ``cache_config.num_gpu_blocks``).
        block_size: block size used for both CacheConfig and the single
            FullAttentionSpec KV-cache group.
        max_model_len: maximum model length; defaults to
            ``max_num_batched_tokens``.
        num_speculative_tokens: if not None, enable speculative decoding with
            an "ngram" SpeculativeConfig proposing this many tokens.
        skip_tokenizer_init: forwarded to ModelConfig.
        async_scheduling: forwarded to SchedulerConfig; when True an
            AsyncScheduler is built instead of a Scheduler.

    Returns:
        A ``Scheduler`` instance, or an ``AsyncScheduler`` instance when
        ``async_scheduling`` is True.
    """
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Only pass enable_prefix_caching when the caller forced a value, so that
    # None leaves CacheConfig's own default in effect.
    kwargs_cache = (
        {}
        if enable_prefix_caching is None
        else {"enable_prefix_caching": enable_prefix_caching}
    )
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = (
        KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )
        if use_kv_connector
        else None
    )

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens
        )

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    # A single KV-cache group with one dummy layer named "layer".
    # NOTE(review): positional FullAttentionSpec args are presumably
    # (block_size, num_kv_heads, head_size, dtype, use_mla) — confirm against
    # vllm.v1.kv_cache_interface.
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    # Keep cache_config's block count in sync with the KVCacheConfig above.
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
|
|
|
|
|
|
_none_hash_initialized = False
|
|
|
|
|
|
def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
    prompt_logprobs: Optional[int] = None,
    same_prompt: bool = False,
    block_size: int = 16,
) -> list[Request]:
    """Build ``num_requests`` dummy requests for scheduler tests.

    On the first call in this module, the none-hash is initialized with
    ``sha256``. Every request shares one SamplingParams object and one block
    hasher created with ``block_size``.

    Args:
        num_requests: number of requests to create; request ids are the
            decimal strings "0", "1", ...
        num_tokens: prompt length for every request.
        mm_positions: optional per-request placeholder lists;
            ``mm_positions[i]`` supplies the placeholders for request ``i``.
            Each placeholder becomes a dummy "image" feature.
        max_tokens: forwarded to SamplingParams.
        stop_token_ids: forwarded to SamplingParams.
        prompt_logprobs: forwarded to SamplingParams.
        same_prompt: if True, every prompt is all zeros; otherwise request
            ``i`` has a prompt made of ``num_tokens`` copies of ``i``.
        block_size: block size for the request block hasher.

    Returns:
        The list of created Request objects, in index order.
    """
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(sha256)
        _none_hash_initialized = True

    hasher = get_request_block_hasher(block_size, sha256)
    params = SamplingParams(
        ignore_eos=False,
        max_tokens=max_tokens,
        stop_token_ids=stop_token_ids,
        prompt_logprobs=prompt_logprobs,
    )

    def _features_for(req_idx: int) -> list[MultiModalFeatureSpec]:
        # One dummy image feature per placeholder. The identifier must be
        # unique per item because the encoder cache tracks entries by it.
        if mm_positions is None:
            return []
        return [
            MultiModalFeatureSpec(
                data=MultiModalKwargsItem.dummy("dummy_m"),
                mm_position=placeholder,
                identifier=f"hash{req_idx}_{item_idx}",
                modality="image",
            )
            for item_idx, placeholder in enumerate(mm_positions[req_idx])
        ]

    result: list[Request] = []
    for req_idx in range(num_requests):
        features = _features_for(req_idx)
        token_value = 0 if same_prompt else req_idx
        result.append(
            Request(
                request_id=str(req_idx),
                prompt_token_ids=[token_value] * num_tokens,
                sampling_params=params,
                pooling_params=None,
                mm_features=features or None,
                eos_token_id=EOS_TOKEN_ID,
                block_hasher=hasher,
            )
        )
    return result
|