# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import (MultiModalBatchedField,
                                    MultiModalFieldElem, MultiModalKwargsItem,
                                    PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                         init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
    '''Create a scheduler under test.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max number of tokens to batch per step
      enable_prefix_caching: optionally force APC on or off (True/False)
                             or use the default (None)

    Returns:
      {class}`Scheduler` instance
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Cache config, optionally force APC on or off.
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
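

# Illustrative usage sketch (not referenced by any test): the keyword
# arguments below show how callers typically opt into prefix caching, the
# shared-storage KV connector, or ngram speculative decoding via the helper
# above. The specific values are arbitrary.
def _example_create_scheduler_with_features():
    return create_scheduler(
        enable_prefix_caching=True,  # force APC on instead of the default
        use_kv_connector=True,  # wire up the SharedStorageConnector
        num_speculative_tokens=2,  # enable ngram speculative decoding
    )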


_none_hash_initialized = False


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
    prompt_logprobs: Optional[int] = None,
    same_prompt: bool = False,
    block_size: int = 16,
) -> list[Request]:
    # init_none_hash initializes module-level state, so only run it once
    # per process.
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(hash)
        _none_hash_initialized = True

    block_hasher = get_request_block_hasher(block_size, hash)

    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_elem = MultiModalFieldElem(
                modality="dummy_m",
                key="dummy_k",
                data=None,
                field=MultiModalBatchedField(),
            )
            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
            mm_kwargs = [mm_item] * len(mm_position)
            mm_hashes = ["hash"] * len(mm_position)
        else:
            mm_position = None
            mm_kwargs = None
            mm_hashes = None
        # With same_prompt, all requests share an identical prompt; otherwise
        # each request gets a distinct prompt.
        prompt_token_ids = ([0] * num_tokens
                            if same_prompt else [i] * num_tokens)
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            multi_modal_kwargs=mm_kwargs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=mm_hashes,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=block_hasher,
        )
        requests.append(request)
    return requests
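

# Illustrative usage sketch (not referenced by any test): build a scheduler,
# enqueue a batch of requests, and run a single scheduling step. The
# add_request()/schedule() calls and the num_scheduled_tokens field assume
# the v1 Scheduler/SchedulerOutput interface; the sizes below are arbitrary.
def _example_single_scheduling_step():
    scheduler = create_scheduler(max_num_seqs=4, max_num_batched_tokens=512)
    for request in create_requests(num_requests=4, num_tokens=32):
        scheduler.add_request(request)
    output = scheduler.schedule()
    # num_scheduled_tokens maps request_id -> tokens scheduled this step; all
    # four 32-token prompts fit within the 512-token budget.
    assert len(output.num_scheduled_tokens) == 4
    return output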