# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch

from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import (MultiModalBatchedField,
                                    MultiModalFieldElem, MultiModalKwargsItem,
                                    PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Union[Scheduler, AsyncScheduler]:
    '''Create scheduler under test.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config (True/False)
                             or use default (None)

    Returns:
      {class}`Scheduler` instance
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
    prompt_logprobs: Optional[int] = None,
    same_prompt: bool = False,
) -> list[Request]:
    '''Create dummy requests for scheduler tests.

    Args:
      num_requests: number of requests to create
      num_tokens: prompt length of each request
      mm_positions: optional per-request multimodal placeholder positions
      max_tokens: max tokens to generate per request
      stop_token_ids: optional stop token ids
      prompt_logprobs: optional number of prompt logprobs to return
      same_prompt: use an identical prompt for every request

    Returns:
      list of {class}`Request` instances
    '''
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_elem = MultiModalFieldElem(
                modality="dummy_m",
                key="dummy_k",
                data=None,
                field=MultiModalBatchedField(),
            )
            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
            mm_kwargs = [mm_item] * len(mm_position)
        else:
            mm_position = None
            mm_kwargs = None
        prompt_token_ids = ([0] * num_tokens
                            if same_prompt else [i] * num_tokens)
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            multi_modal_kwargs=mm_kwargs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
        )
        requests.append(request)
    return requests
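

# Minimal usage sketch (not part of the original helpers): shows how the two
# factories above are typically combined in a v1 scheduler test. It assumes
# the v1 `Scheduler.add_request()` / `Scheduler.schedule()` interface and that
# the "facebook/opt-125m" config is resolvable locally or via the HF hub, so
# treat this as illustrative rather than as part of the test suite.
if __name__ == "__main__":
    scheduler = create_scheduler(max_num_seqs=4)
    for request in create_requests(num_requests=4, num_tokens=8):
        scheduler.add_request(request)
    output = scheduler.schedule()
    print(f"scheduled {len(output.scheduled_new_reqs)} new requests")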