# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.config import (
    CacheConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    SpeculativeConfig,
    VllmConfig,
)
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalKwargsItem,
    PlaceholderRange,
)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: bool | None = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: int | None = None,
    num_speculative_tokens: int | None = None,
    skip_tokenizer_init: bool = False,
    async_scheduling: bool = False,
) -> Scheduler | AsyncScheduler:
    """Create scheduler under test.

    Args:
        model: model under test
        max_num_seqs: max sequences to schedule
        max_num_batched_tokens: max num tokens to batch
        enable_prefix_caching: optionally force APC config (True/False)
            or use default (None)

    Returns:
        {class}`Scheduler` instance
    """
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        async_scheduling=async_scheduling,
    )
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Cache config, optionally force APC
    kwargs_cache = (
        {}
        if enable_prefix_caching is None
        else {"enable_prefix_caching": enable_prefix_caching}
    )
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = (
        KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )
        if use_kv_connector
        else None
    )

    speculative_config: SpeculativeConfig | None = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens
        )

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
    return scheduler_cls(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        block_size=block_size,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


_none_hash_initialized = False

def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: list[list[PlaceholderRange]] | None = None,
    max_tokens: int = 16,
    stop_token_ids: list[int] | None = None,
    prompt_logprobs: int | None = None,
    same_prompt: bool = False,
    block_size: int = 16,
) -> list[Request]:
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(sha256)
        _none_hash_initialized = True

    block_hasher = get_request_block_hasher(block_size, sha256)
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=max_tokens,
        stop_token_ids=stop_token_ids,
        prompt_logprobs=prompt_logprobs,
    )
    requests = []
    for i in range(num_requests):
        mm_features = []
        if mm_positions is not None:
            mm_position = mm_positions[i]
            for j, position in enumerate(mm_position):
                # Dummy hash for each mm item should be unique
                # since encoder cache tracks entries by hash
                identifier = f"hash{i}_{j}"
                mm_feature = MultiModalFeatureSpec(
                    data=MultiModalKwargsItem.dummy("dummy_m"),
                    mm_position=position,
                    identifier=identifier,
                    modality="image",
                )
                mm_features.append(mm_feature)
        prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            pooling_params=None,
            mm_features=mm_features if mm_features else None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=block_hasher,
        )
        requests.append(request)
    return requests
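

# Illustrative usage sketch (not part of the shared helpers): tests typically
# build a scheduler and a batch of requests with the helpers above, then drive
# the scheduler. The call pattern below assumes the v1 Scheduler exposes
# add_request() and schedule(); treat it as a sketch and adapt to the test.
#
#     scheduler = create_scheduler(enable_prefix_caching=True)
#     requests = create_requests(num_requests=2, num_tokens=32)
#     for request in requests:
#         scheduler.add_request(request)
#     scheduler_output = scheduler.schedule()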