From 787e59629cb374561ef1add9f939dab9b0686dc9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 8 Sep 2025 16:42:26 -0700 Subject: [PATCH] wip Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner copy.py | 3205 +++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 602 +++-- vllm/v1/worker/gpu_worker_states.py | 296 +-- 3 files changed, 3581 insertions(+), 522 deletions(-) create mode 100644 vllm/v1/worker/gpu_model_runner copy.py diff --git a/vllm/v1/worker/gpu_model_runner copy.py b/vllm/v1/worker/gpu_model_runner copy.py new file mode 100644 index 0000000000000..d55af6185ec90 --- /dev/null +++ b/vllm/v1/worker/gpu_model_runner copy.py @@ -0,0 +1,3205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +import itertools +import time +from collections import defaultdict +from collections.abc import Iterator +from contextlib import contextmanager +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from tqdm import tqdm + +import vllm.envs as envs +from vllm.attention import Attention, AttentionType +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention +from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.monitor import set_cudagraph_capturing_enabled +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config, update_config) +from vllm.distributed.eplb.eplb_state import EplbState +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks +from vllm.distributed.parallel_state import ( + get_pp_group, get_tp_group, graph_capture, is_global_first_rank, + prepare_communication_buffer_for_model) +from vllm.forward_context import (BatchDescriptor, DPMetadata, + set_forward_context) +from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.models.interfaces import (is_mixture_of_experts, + supports_eagle3, + supports_transcription) +from vllm.model_executor.models.interfaces_base import ( + VllmModelForPooling, is_pooling_model, is_text_generation_model) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, + PlaceholderRange) +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.pooling_params import PoolingParams +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + get_dtype_size, is_pin_memory_available, round_up, + supports_dynamo) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + create_fast_prefill_custom_backend) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from 
vllm.v1.kv_cache_interface import (AttentionSpec, + ChunkedLocalAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, + DraftTokenIds, LogprobsLists, LogprobsTensors, + ModelRunnerOutput, SamplerOutput) +from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessors +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.medusa import MedusaProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.gpu_block_table import BlockTables +from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs +from vllm.v1.worker.gpu_worker_states import RequestState +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin, KVConnectorOutput) +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +from .utils import (AttentionGroup, MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) + +if TYPE_CHECKING: + import xgrammar as xgr + + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.core.sched.output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + + +# Wrapper for ModelRunnerOutput to support overlapped execution. +class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): + + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampled_token_ids: torch.Tensor, + invalid_req_indices: list[int], + async_output_copy_stream: torch.cuda.Stream, + ): + self._model_runner_output = model_runner_output + self._invalid_req_indices = invalid_req_indices + + # Event on the copy stream so we can synchronize the non-blocking copy. + self._async_copy_ready_event = torch.cuda.Event() + + # Keep a reference to the device tensor to avoid it being + # deallocated until we finish copying it to the host. + self._sampled_token_ids = sampled_token_ids + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._sampled_token_ids_cpu = self._sampled_token_ids.to( + 'cpu', non_blocking=True) + self._async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + + This function blocks until the copy is finished. 
+ """ + self._async_copy_ready_event.synchronize() + + # Release the device tensor once the copy has completed + del self._sampled_token_ids + + valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + for i in self._invalid_req_indices: + valid_sampled_token_ids[i].clear() + + output = self._model_runner_output + output.sampled_token_ids = valid_sampled_token_ids + return output + + +class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + + from vllm.model_executor.models.utils import set_cpu_offload_max_bytes + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.is_pooling_model = (model_config.runner_type == 'pooling') + self.is_multimodal_raw_input_only_model = ( + model_config.is_multimodal_raw_input_only_model) + + self.max_model_len = model_config.max_model_len + self.dcp_world_size = self.parallel_config.decode_context_parallel_size + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + + # Model-related. + self.num_query_heads = model_config.get_num_attention_heads( + parallel_config) + self.hidden_size = model_config.get_hidden_size() + self.attention_chunk_size = model_config.attention_chunk_size + # Only relevant for models using ALiBi (e.g, MPT) + self.use_alibi = check_use_alibi(model_config) + + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.uses_mrope = model_config.uses_mrope + self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( + model_config) + + # Sampler + self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) + + self.eplb_state: Optional[EplbState] = None + """ + State of the expert parallelism load balancer. + + Will be lazily initialized when the model is loaded. + """ + + # Lazy initializations + # self.model: nn.Module # Set after load_model + # Initialize in initialize_kv_cache + self.kv_caches: list[torch.Tensor] = [] + # indexes: [kv_cache_group_id][attn_group] + self.attn_groups: list[list[AttentionGroup]] = [] + # self.kv_cache_config: KVCacheConfig + + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} + + self.use_aux_hidden_state_outputs = False + # Set up speculative decoding. + # NOTE(Jiayi): currently we put the entire draft model on + # the last PP rank. This is not ideal if there are many + # layers in the draft model. 
+ if self.speculative_config and get_pp_group().is_last_rank: + if self.speculative_config.method == "ngram": + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.use_eagle(): + self.drafter = EagleProposer(self.vllm_config, self.device, + self) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True + elif self.speculative_config.method == "medusa": + self.drafter = MedusaProposer( + vllm_config=self.vllm_config, + device=self.device) # type: ignore + else: + raise ValueError("Unknown speculative decoding method: " + f"{self.speculative_config.method}") + self.rejection_sampler = RejectionSampler() + + # Request states. + self.req_states = RequestState( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + max_num_cached_reqs=2 * self.max_num_reqs, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.cache_config.block_size], + ) + + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.async_output_copy_stream = torch.cuda.Stream() if \ + self.use_async_scheduling else None + + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. + # The convention is different. + # self.cudagraph_batch_sizes sorts in ascending order. + # The batch sizes in the config are in descending order. + if self.compilation_config.cudagraph_capture_sizes and \ + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + self.cudagraph_batch_sizes = list( + reversed(self.compilation_config.cudagraph_capture_sizes)) + + # Cache the device properties. + self._init_device_properties() + + # Persistent buffers for CUDA graphs. + self.input_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.int32) + # Because inputs_embeds may be bfloat16 and we don't need a numpy + # version of this tensor, avoid a RuntimeError by not creating a + # numpy buffer. + self.inputs_embeds = self._make_buffer(self.max_num_tokens, + self.hidden_size, + dtype=self.dtype, + numpy=False) + self.positions = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, + dtype=torch.int32) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.cu_num_draft_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + # NOTE: `mrope_positions` is implemented with one additional dummy + # position on purpose to make it non-contiguous so that it can work + # with torch compile. + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + self.mrope_positions = self._make_buffer( + (3, self.max_num_tokens + 1), dtype=torch.int64) + + # None in the first PP rank. The rest are set after load_model. + self.intermediate_tensors: Optional[IntermediateTensors] = None + + self.idx_mapping = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + + # Layer pairings for cross-layer KV sharing. 
+ # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. + self.shared_kv_cache_layers: dict[str, str] = {} + self.kv_sharing_fast_prefill_eligible_layers: set[str] = set() + + self.kv_sharing_fast_prefill_logits_indices = None + if self.cache_config.kv_sharing_fast_prefill: + self.kv_sharing_fast_prefill_logits_indices = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=self.device) + + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens + + # Cudagraph dispatcher for runtime cudagraph dispatching. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + + self.mm_budget = (MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) if self.supports_mm_inputs else None) + + # Attention layers that are only in the KVCacheConfig of the runner + # (e.g., KV sharing, encoder-only attention), but not in the + # KVCacheConfig of the scheduler. + self.runner_only_attn_layers: set[str] = set() + + # Cached outputs. + self._draft_token_ids: Optional[Union[list[list[int]], + torch.Tensor]] = None + self._draft_req_ids: Optional[list[str]] = None + self.transfer_event = torch.cuda.Event() + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_model_len, 1), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + # Bfloat16 torch tensors cannot be directly cast to a numpy array, so + # if a bfloat16 buffer is needed without a corresponding numpy array, + # don't bother instantiating the numpy array. + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy) + + def _init_model_kwargs(self, num_tokens: int): + return {} + model_kwargs = dict[str, Any]() + + if not self.is_pooling_model: + return model_kwargs + + num_reqs = self.input_batch.num_reqs + pooling_params = self.input_batch.get_pooling_params() + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens.gpu[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + + # Note: used for model runner override. + def _init_device_properties(self) -> None: + """Initialize attributes from torch.cuda.get_device_properties + """ + self.device_properties = torch.cuda.get_device_properties(self.device) + self.num_sms = self.device_properties.multi_processor_count + + # Note: used for model runner override. + def _sync_device(self) -> None: + torch.cuda.synchronize() + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. 
+ + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + for req_id in scheduler_output.finished_req_ids: + self.req_states.remove_request(req_id) + self.encoder_cache.pop(req_id, None) + + # Free the cached encoder outputs. + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + req_indices: list[int] = [] + cu_num_new_blocks: list[list[int]] = [ + [0] for _ in range(self.block_tables.num_kv_cache_groups) + ] + new_block_ids: list[list[int]] = [ + [] for _ in range(self.block_tables.num_kv_cache_groups) + ] + overwrite: list[bool] = [] + + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + self.req_states.add_request( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + sampling_params=new_req_data.sampling_params, + ) + + req_index = self.req_states.req_id_to_index[req_id] + req_indices.append(req_index) + for i, block_ids in enumerate(new_req_data.block_ids): + x = cu_num_new_blocks[i][-1] + cu_num_new_blocks[i].append(x + len(block_ids)) + new_block_ids[i].extend(block_ids) + overwrite.append(True) + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._init_mrope_positions(req_id) + + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + req_index = self.req_states.req_id_to_index[req_id] + + # Update input batch. + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = cached_reqs.new_token_ids[i] + self.req_states.append_token_ids(req_index, new_token_ids) + + req_new_block_ids = cached_reqs.new_block_ids[i] + if req_new_block_ids is not None: + req_indices.append(req_index) + for group_id, block_ids in enumerate(req_new_block_ids): + x = cu_num_new_blocks[group_id][-1] + cu_num_new_blocks[group_id].append(x + len(block_ids)) + new_block_ids[group_id].extend(block_ids) + # If the request is resumed from preemption, we need to + # overwrite the existing block IDs. 
+ overwrite.append(cached_reqs.resumed_from_preemption[i]) + + self.req_states.num_computed_tokens.np[req_index] = ( + cached_reqs.num_computed_tokens[i]) + + if req_indices: + self.block_tables.append_block_ids( + req_indices=req_indices, + cu_num_new_blocks=cu_num_new_blocks, + new_block_ids=new_block_ids, + overwrite=overwrite, + ) + + def _init_mrope_positions(self, req_id: str) -> None: + req_idx = self.req_states.req_id_to_index[req_id] + req_data = self.req_states.req_data[req_idx] + prompt_len = self.req_states.num_prompt_tokens.np[req_idx] + prompt_token_ids = self.req_states.token_ids.np[req_idx, :prompt_len] + + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_item in req_data.mm_kwargs: + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_data.mrope_positions, req_data.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + def _extract_mm_kwargs( + self, + scheduler_output: "SchedulerOutput", + ) -> BatchedTensorInputs: + if not scheduler_output or not self.is_multimodal_raw_input_only_model: + return {} + + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + mm_kwargs.extend(req.mm_kwargs) + + # Input all modalities at once + mm_kwargs_combined: BatchedTensorInputs = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + mm_kwargs_combined.update(mm_kwargs_group) + + return mm_kwargs_combined + + def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: + if not self.is_multimodal_raw_input_only_model: + return {} + + mm_budget = self.mm_budget + assert mm_budget is not None + + dummy_modality = mm_budget.get_modality_with_max_tokens() + return self._get_mm_dummy_batch(dummy_modality, num_seqs) + + def _prepare_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> InputBatch: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = len(scheduler_output.num_scheduled_tokens) + + # batch_idx -> req_id + req_ids = sorted(scheduler_output.num_scheduled_tokens, + key=scheduler_output.num_scheduled_tokens.get) + # batch_idx -> req_idx + idx_mapping_list = [ + self.req_states.req_id_to_index[req_id] for req_id in req_ids + ] + self.idx_mapping.np[:num_reqs] = idx_mapping_list + idx_mapping_np = self.idx_mapping.np[:num_reqs] + idx_mapping = self.idx_mapping.copy_to_gpu(num_reqs) + # req_id -> batch_idx + req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)} + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. 
+ block_tables = self.block_tables.compute_block_tables(idx_mapping) + + # Get the number of scheduled tokens for each request. + num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[i] for i in req_ids], + dtype=np.int32) + + prepare_inputs( + idx_mapping_np, + self.req_states.token_ids.np, + self.req_states.num_computed_tokens.np, + num_scheduled_tokens, + self.input_ids.np, + self.query_start_loc.np, + self.seq_lens.np, + self.positions.np, + ) + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self.positions.copy_to_gpu(total_num_scheduled_tokens) + + # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens + # tensors from CPU to GPU, because they may include paddings needed + # for full CUDA graph mode. + self.query_start_loc.copy_to_gpu() + self.seq_lens.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + max_query_len = int(num_scheduled_tokens.max()) + seq_lens = self.seq_lens.gpu[:num_reqs] + max_seq_len = int(self.seq_lens.np[:num_reqs].max()) + + # Compute the slot mappings on GPUs. + slot_mappings = self.block_tables.compute_slot_mappings( + query_start_loc, self.positions.gpu[:total_num_scheduled_tokens]) + + if self.uses_mrope: + self._calc_mrope_positions(req_ids, num_scheduled_tokens) + # Optimization: To avoid gather and scatter, copy the whole M-RoPE + # tensor from CPU to GPU although only a part of it is used. + self.mrope_positions.copy_to_gpu() + + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + logits_indices = query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + spec_decode_metadata = self._prepare_spec_decode_metadata( + req_ids, + scheduler_output.scheduled_spec_decode_tokens, + query_start_loc, + ) + logits_indices = spec_decode_metadata.logits_indices + + logits_indices_padded = None + if self.cache_config.kv_sharing_fast_prefill: + logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices) + + # Used in the below loop. + query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + num_computed_tokens_np = self.req_states.num_computed_tokens.np[ + idx_mapping_np] + num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np) + spec_decode_common_attn_metadata = None + + attn_metadata: dict[str, Any] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens, ), + dtype=torch.int64, + device=self.device, + ) + num_common_prefix_blocks = 0 + else: + blk_table_tensor = block_tables[kv_cache_group_id] + slot_mapping = slot_mappings[kv_cache_group_id] + num_common_prefix_blocks = ( + scheduler_output. 
+ num_common_prefix_blocks[kv_cache_group_id]) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + logits_indices_padded=logits_indices_padded, + num_logits_indices=logits_indices.size(0), + causal=True, + ) + + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + for attn_group in self.attn_groups[kv_cache_group_id]: + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + builder = attn_group.metadata_builder + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + num_computed_tokens_np, + num_common_prefix_blocks, + kv_cache_group_spec.kv_cache_spec, + builder, + ) + + attn_metadata_i = (builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + )) + + for layer_name in attn_group.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + # # Hot-Swap lora model + # if self.lora_config: + # self.set_active_loras(input_batch, num_scheduled_tokens) + + return InputBatch( + req_ids=req_ids, + num_scheduled_tokens=num_scheduled_tokens, + req_id_to_batch_idx=req_id_to_batch_idx, + idx_mapping=idx_mapping, + idx_mapping_np=idx_mapping_np, + num_reqs=num_reqs, + total_num_tokens=total_num_scheduled_tokens, + max_query_len=max_query_len, + attn_metadata=attn_metadata, + spec_decode_metadata=spec_decode_metadata, + spec_decode_common_attn_metadata=spec_decode_common_attn_metadata, + logits_indices=logits_indices, + ) + + def _compute_cascade_attn_prefix_len( + self, + num_scheduled_tokens: np.ndarray, + num_computed_tokens: np.ndarray, + num_common_prefix_blocks: int, + kv_cache_spec: KVCacheSpec, + attn_metadata_builder: AttentionMetadataBuilder, + ) -> int: + """Compute the length of the common prefix for cascade attention. + + NOTE(woosuk): The common prefix length returned by this function + represents the length used specifically for cascade attention, not the + actual number of tokens shared between requests. When cascade attention + is disabled (use_cascade=False), this function returns 0 even if + requests share common tokens. Additionally, the common prefix length is + truncated to a multiple of the block size and may be further truncated + due to implementation details explained below. + + Args: + num_scheduled_tokens: Number of tokens scheduled per request. + num_common_prefix_blocks: Number of shared KV cache blocks. + + Returns: + int: Length of common prefix in tokens. + """ + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size + if common_prefix_len == 0: + # Common case. + return 0 + + # NOTE(woosuk): Cascade attention uses two attention kernels: one + # for the common prefix and the other for the rest. For the first + # kernel, we concatenate all the query tokens (possibly from + # different requests) and treat them as if they are from the same + # request. Then, we use bi-directional attention to process the + # common prefix in the KV cache. Importantly, this means that the + # first kernel does not do any masking. 
+ + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. + # However, this is wrong because D in Request 1 should not attend to + # E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to include + # the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 3 (i.e., [A, B, C]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is not + # a fundamental problem, our current implementation does not support + # this case. + common_prefix_len = min(common_prefix_len, num_computed_tokens.min()) + # common_prefix_len should be a multiple of the block size. + common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * + kv_cache_spec.block_size) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + use_local_attention = ( + isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) + or (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None)) + assert isinstance(kv_cache_spec, AttentionSpec) + use_cascade = attn_metadata_builder.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=self.num_query_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, + use_alibi=self.use_alibi, + use_sliding_window=use_sliding_window, + use_local_attention=use_local_attention, + num_sms=self.num_sms, + ) + return common_prefix_len if use_cascade else 0 + + def _prepare_spec_decode_metadata( + self, + req_ids: list[str], + req_id_to_draft_token_ids: dict[str, list[int]], + query_start_loc: torch.Tensor, + ) -> SpecDecodeMetadata: + # Get the number of draft tokens for each request. 
+        num_reqs = len(req_ids)
+        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+        for i, req_id in enumerate(req_ids):
+            draft_token_ids = req_id_to_draft_token_ids.get(req_id)
+            if draft_token_ids:
+                num_draft_tokens[i] = len(draft_token_ids)
+        np.cumsum(num_draft_tokens,
+                  dtype=np.int32,
+                  out=self.cu_num_draft_tokens.np[:num_reqs])
+        cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
+        return self.req_states.make_spec_decode_metadata(
+            query_start_loc,
+            cu_num_draft_tokens,
+            self.cu_num_draft_tokens.np[:num_reqs],
+            self.input_ids.gpu,
+        )
+
+    def _calc_mrope_positions(
+        self,
+        req_ids: list[str],
+        query_lens: np.ndarray,
+    ):
+        mrope_pos_ptr = 0
+        for i, req_id in enumerate(req_ids):
+            req_idx = self.req_states.req_id_to_index[req_id]
+            req_data = self.req_states.req_data[req_idx]
+            assert req_data.mrope_positions is not None
+
+            num_computed_tokens = self.req_states.num_computed_tokens.np[req_idx]
+            num_scheduled_tokens = query_lens[i]
+            num_prompt_tokens = self.req_states.num_prompt_tokens.np[req_idx]
+
+            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                prompt_part_len = max(0,
+                                      num_prompt_tokens - num_computed_tokens)
+                completion_part_len = max(
+                    0, num_scheduled_tokens - prompt_part_len)
+            else:
+                prompt_part_len = num_scheduled_tokens
+                completion_part_len = 0
+
+            assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+            if prompt_part_len > 0:
+                # prompt's mrope_positions are pre-computed
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + prompt_part_len
+                src_start = num_computed_tokens
+                src_end = num_computed_tokens + prompt_part_len
+
+                self.mrope_positions.cpu[:, dst_start:dst_end] = (
+                    req_data.mrope_positions[:, src_start:src_end])
+                mrope_pos_ptr += prompt_part_len
+
+            if completion_part_len > 0:
+                # compute completion's mrope_positions on-the-fly
+                dst_start = mrope_pos_ptr
+                dst_end = mrope_pos_ptr + completion_part_len
+
+                MRotaryEmbedding.get_next_input_positions_tensor(
+                    out=self.mrope_positions.np,
+                    out_offset=dst_start,
+                    mrope_position_delta=req_data.mrope_position_delta,
+                    context_len=num_computed_tokens + prompt_part_len,
+                    num_new_tokens=completion_part_len,
+                )
+
+                mrope_pos_ptr += completion_part_len
+
+    def _prepare_kv_sharing_fast_prefill(
+        self,
+        logits_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        assert self.kv_sharing_fast_prefill_logits_indices is not None
+        num_logits = logits_indices.shape[0]
+        assert num_logits > 0
+        self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices)
+        # There may be leftover indices in logits_indices[num_logits:]
+        # from previous iterations, whose values may be greater than the
+        # batch size in the current iteration. To ensure indices are always
+        # valid, we fill the padded indices with the last index.
+        self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
+            logits_indices[-1].item())
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and num_logits <= self.cudagraph_batch_sizes[-1]):
+            # Use piecewise CUDA graphs.
+            # Add padding to the batch size.
+            num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
+        else:
+            num_logits_padded = num_logits
+        logits_indices_padded = (
+            self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
+        return logits_indices_padded
+
+    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+        # Batch the multi-modal inputs.
+ mm_kwargs = list[MultiModalKwargsItem]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_idx = self.req_states.req_id_to_index[req_id] + req_data = self.req_states.req_data[req_idx] + + for mm_input_id in encoder_input_ids: + mm_hash = req_data.mm_hashes[mm_input_id] + mm_kwargs.append(req_data.mm_kwargs[mm_input_id]) + mm_hashes_pos.append( + (mm_hash, req_data.mm_positions[mm_input_id])) + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. + # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. + encoder_outputs = [] + for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. + curr_group_outputs = self.model.get_multimodal_embeddings( + **mm_kwargs_group) + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=num_items, + ) + + for output in curr_group_outputs: + encoder_outputs.append(output) + + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( + output, + is_embed=pos_info.is_embed, + ) + + def _gather_mm_embeddings( + self, + input_batch: InputBatch, + shift_computed_tokens: int = 0, + ) -> list[torch.Tensor]: + mm_embeds: list[torch.Tensor] = [] + for i, req_id in enumerate(input_batch.req_ids): + num_scheduled_tokens = input_batch.num_scheduled_tokens[i] + req_idx = self.req_states.req_id_to_index[req_id] + num_computed_tokens = ( + self.req_states.num_computed_tokens.np[req_idx] + + shift_computed_tokens) + req_data = self.req_states.req_data[req_idx] + mm_positions = req_data.mm_positions + mm_hashes = req_data.mm_hashes + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx + + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." 
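+                # Worked example (illustrative numbers only): with
+                # num_computed_tokens=10, num_scheduled_tokens=5, and a
+                # placeholder at start_pos=12 spanning num_encoder_tokens=8,
+                # this step covers positions [10, 15) while the placeholder
+                # covers [12, 20), so start_idx = max(10 - 12, 0) = 0 and
+                # end_idx = min(10 - 12 + 5, 8) = 3; only encoder_output[0:3]
+                # is gathered now and the rest is left for later steps.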
+ + if (is_embed := pos_info.is_embed) is not None: + is_embed = is_embed[start_idx:end_idx] + + mm_embeds_item = gather_mm_placeholders( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + mm_embeds.append(mm_embeds_item) + return mm_embeds + + def get_model(self) -> nn.Module: + # get raw model out of the cudagraph wrapper. + if isinstance(self.model, CUDAGraphWrapper): + return self.model.unwrap() + return self.model + + def get_supported_generation_tasks(self) -> list[GenerationTask]: + model = self.get_model() + supported_tasks = list[GenerationTask]() + + if is_text_generation_model(model): + supported_tasks.append("generate") + + if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + + supported_tasks.append("transcription") + + return supported_tasks + + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + supported_tasks = list(model.pooler.get_supported_tasks()) + + if (self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks): + supported_tasks.remove("encode") + + logger.debug_once("Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it.") + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once( + "Score API is only enabled for num_labels == 1.") + + return supported_tasks + + def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + tasks = list[SupportedTask]() + + if self.model_config.runner_type == "generate": + tasks.extend(self.get_supported_generation_tasks()) + if self.model_config.runner_type == "pooling": + tasks.extend(self.get_supported_pooling_tasks()) + + return tuple(tasks) + + def apply_grammar_bitmask( + self, + scheduler_output: "SchedulerOutput", + logits: torch.Tensor, + ): + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is None: + return + + # We receive the structured output bitmask from the scheduler, + # compacted to contain bitmasks only for structured output requests. + # The order of the requests in the bitmask is not guaranteed to be the + # same as the order of the requests in the gpu runner's batch. We need + # to sort the bitmask to match the order of the requests used here. + + # Get the batch indices of the structured output requests. + # Keep track of the number of speculative tokens scheduled for every + # request in the batch, as the logit indices are offset by this amount. + struct_out_req_batch_indices: dict[str, int] = {} + cumulative_offset = 0 + seq = sorted(self.req_states.req_id_to_index.items(), key=lambda x: x[1]) + for req_id, batch_index in seq: + logit_index = batch_index + cumulative_offset + cumulative_offset += len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + if req_id in scheduler_output.structured_output_request_ids: + struct_out_req_batch_indices[req_id] = logit_index + + out_indices = [] + + # Reorder the bitmask to match the order of the requests in the batch. 
+ sorted_bitmask = np.full(shape=(logits.shape[0], + grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype) + cumulative_index = 0 + seq = sorted(scheduler_output.structured_output_request_ids.items(), + key=lambda x: x[1]) + for req_id, _ in seq: + logit_index = struct_out_req_batch_indices[req_id] + num_spec_tokens = len( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + for i in range(1 + num_spec_tokens): + sorted_bitmask[logit_index + i] = \ + grammar_bitmask[cumulative_index + i] + out_indices.append(logit_index + i) + cumulative_index += 1 + num_spec_tokens + grammar_bitmask = sorted_bitmask + + # If the length of out indices and the logits have the same shape + # we don't need to pass indices to the kernel, + # since the bitmask is already aligned with the logits. + skip_out_indices = len(out_indices) == logits.shape[0] + + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() + + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask.to(self.device, non_blocking=True), + indices=out_indices if not skip_out_indices else None, + ) + + def sync_and_slice_intermediate_tensors( + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool) -> IntermediateTensors: + + assert self.intermediate_tensors is not None + + tp = self.vllm_config.parallel_config.tensor_parallel_size + enabled_sp = self.compilation_config.pass_config. \ + enable_sequence_parallelism + if enabled_sp: + # When sequence parallelism is enabled, we always pad num_tokens + # to be a multiple of tensor_parallel_size (tp) earlier + assert num_tokens % tp == 0 + is_residual_scattered = tp > 1 and enabled_sp \ + and num_tokens % tp == 0 + + # When sequence parallelism is enabled, the "residual" tensor is sharded + # across tensor parallel ranks, so each rank only needs its own slice. + if sync_self: + assert intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + is_scattered = k == "residual" and is_residual_scattered + copy_len = num_tokens // tp if is_scattered else \ + num_tokens + self.intermediate_tensors[k][:copy_len].copy_( + v[:copy_len], non_blocking=True) + + return IntermediateTensors({ + k: + v[:num_tokens // tp] + if k == "residual" and is_residual_scattered else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + def eplb_step(self, + is_dummy: bool = False, + is_profile: bool = False) -> None: + """ + Step for the EPLB (Expert Parallelism Load Balancing) state. + """ + if not self.parallel_config.enable_eplb: + return + + assert self.eplb_state is not None + model = self.get_model() + assert is_mixture_of_experts(model) + self.eplb_state.step( + model, + is_dummy, + is_profile, + log_stats=self.parallel_config.eplb_config.log_balancedness, + ) + + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + # For DP: Don't pad when setting enforce_eager. + # This lets us set enforce_eager on the prefiller in a P/D setup and + # still use CUDA graphs (enabled by this padding) on the decoder. + # + # TODO(tms) : There are many cases where padding is enabled for + # prefills, causing unnecessary and excessive padding of activations. + + if dp_size == 1 or self.vllm_config.model_config.enforce_eager: + # Early exit. 
+ return 0, None + + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + + def _pool( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + num_scheduled_tokens_np: np.ndarray, + kv_connector_output: Optional[KVConnectorOutput], + ) -> ModelRunnerOutput: + assert self.input_batch.num_reqs ==\ + len(self.input_batch.pooling_params), \ + "Either all or none of the requests in" \ + " a batch must be pooling request" + + hidden_states = hidden_states[:num_scheduled_tokens] + pooling_metadata = self.input_batch.get_pooling_metadata() + pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), + device=hidden_states.device) + seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + + # Pooling models D2H & synchronize occurs in pooler.py:build_output + raw_pooler_output = self.model.pooler( + hidden_states=hidden_states, pooling_metadata=pooling_metadata) + + pooler_output: list[Optional[torch.Tensor]] = [] + for raw_output, seq_len, prompt_len in zip( + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): + + output = raw_output.data if seq_len == prompt_len else None + pooler_output.append(output) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_batch_idx, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + ) + + def _preprocess( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + + # if self.cache_config.kv_sharing_fast_prefill: + # assert not self.input_batch.num_prompt_logprobs, ( + # "--kv-sharing-fast-prefill produces incorrect logprobs for " + # "prompt tokens, tokens, please disable it when the requests " + # "need prompt logprobs") + + # Prepare the decoder inputs. + input_batch = self._prepare_inputs(scheduler_output) + + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + else: + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.compilation_config.pass_config. 
+                    enable_sequence_parallelism and tp_size > 1:
+                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
+            else:
+                num_input_tokens = num_scheduled_tokens
+
+        # Padding for DP
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
+        num_input_tokens += num_pad
+
+        # _prepare_inputs may reorder the batch, so we must gather multi
+        # modal outputs after that to ensure the correct order
+        if self.supports_mm_inputs and get_pp_group().is_first_rank:
+            # Run the multimodal encoder if any.
+            self._execute_mm_encoder(scheduler_output)
+            mm_embeds = self._gather_mm_embeddings(input_batch)
+
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids.gpu[:num_scheduled_tokens],
+                multimodal_embeddings=mm_embeds or None,
+            )
+
+            # TODO(woosuk): Avoid the copy. Optimize.
+            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
+            input_ids = None
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
+            model_kwargs = {
+                **self._init_model_kwargs(num_scheduled_tokens),
+                **self._extract_mm_kwargs(scheduler_output),
+            }
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the CUDA graph.
+            input_ids = self.input_ids.gpu[:num_input_tokens]
+            inputs_embeds = None
+            model_kwargs = self._init_model_kwargs(num_input_tokens)
+        if self.uses_mrope:
+            positions = self.mrope_positions.gpu[:, :num_input_tokens]
+        else:
+            positions = self.positions.gpu[:num_input_tokens]
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_input_tokens, intermediate_tensors, True)
+
+        uniform_decode = (input_batch.max_query_len
+                          == self.uniform_decode_query_len
+                          and num_scheduled_tokens
+                          == input_batch.num_reqs * input_batch.max_query_len)
+        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
+                                           uniform_decode=uniform_decode)
+        cudagraph_runtime_mode, batch_descriptor = \
+            self.cudagraph_dispatcher.dispatch(batch_descriptor)
+
+        # Run the model.
+        # Use persistent buffers for CUDA graphs.
+        with set_forward_context(
+                input_batch.attn_metadata,
+                self.vllm_config,
+                num_tokens=num_input_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
+        ), self.maybe_get_kv_connector_output(
+                scheduler_output) as kv_connector_output:
+
+            model_output = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                intermediate_tensors=intermediate_tensors,
+                inputs_embeds=inputs_embeds,
+                **model_kwargs,
+            )
+
+        if self.use_aux_hidden_state_outputs:
+            hidden_states, aux_hidden_states = model_output
+        else:
+            hidden_states = model_output
+            aux_hidden_states = None
+
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        broadcast_pp_output = \
+            self.parallel_config.distributed_executor_backend \
+            == "external_launcher" and len(get_pp_group().ranks) > 0
+        if not get_pp_group().is_last_rank:
+            # For mid-pipeline stages, return the hidden states.
+ assert isinstance(hidden_states, IntermediateTensors) + if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + get_pp_group().send_tensor_dict(hidden_states.tensors, + all_gather_group=get_tp_group()) + logits = None + else: + sample_hidden_states = hidden_states[input_batch.logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + def _sample( + self, logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata] + ) -> SamplerOutput: + # Sample the next token and get logprobs if needed. + sampling_metadata = self.req_states.make_sampling_metadata( + input_batch.idx_mapping) + if input_batch.spec_decode_metadata is None: + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[ + input_batch.spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[ + input_batch.spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + input_batch.spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + + return sampler_output + + def _bookkeeping_sync( + self, scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, num_scheduled_tokens: int + ) -> tuple[ + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], + ]: + num_nans_in_logits = {} + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits( + input_batch.req_ids, logits) + + # TODO(woosuk): The following loop can be slow since it iterates over + # the requests one by one. Optimize. + discard_sampled_tokens_req_indices: list[int] = [] + for i, req_id in enumerate(input_batch.req_ids): + req_idx = self.req_states.req_id_to_index[req_id] + seq_len = (self.req_states.num_computed_tokens.np[req_idx] + + input_batch.num_scheduled_tokens[i]) + if seq_len < self.req_states.num_tokens.np[req_idx]: + # Ignore the sampled token for partial prefills. + # Rewind the generator state as if the token was not sampled. 
+ # This relies on cuda-specific torch-internal impl details + generator = self.req_states.generators.get(req_idx) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + # Record the index of the request that should not be sampled, + # so that we could clear the sampled tokens before returning. + discard_sampled_tokens_req_indices.append(i) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + + # NOTE: GPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + invalid_req_indices = [] + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, self.vocab_size) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for i, req_id in enumerate(input_batch.req_ids): + sampled_ids = valid_sampled_token_ids[i] + if not sampled_ids: + continue + req_idx = self.req_states.req_id_to_index[req_id] + + start_idx = self.req_states.num_tokens.np[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.max_model_len, ( + "Sampled token IDs exceed the max model length. 
" + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.max_model_len}") + + self.req_states.token_ids.np[req_idx, + start_idx:end_idx] = sampled_ids + self.req_states.num_tokens.np[req_idx] = end_idx + + if self.speculative_config: + assert input_batch.spec_decode_common_attn_metadata is not None + self._draft_token_ids = self.propose_draft_token_ids( + input_batch, + valid_sampled_token_ids, + sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + ) + self._draft_req_ids = input_batch.req_ids + + with record_function_or_nullcontext("Postprocess"): + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None + + # Broadcast PP output for external_launcher (torchrun) + # to make sure we are synced across pp ranks + # TODO: Support overlapping mirco-batches + # https://github.com/vllm-project/vllm/issues/18019 + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 + if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. + assert isinstance(hidden_states, IntermediateTensors) + if not broadcast_pp_output: + hidden_states.kv_connector_output = kv_connector_output + return hidden_states + get_pp_group().send_tensor_dict( + hidden_states.tensors, all_gather_group=get_tp_group()) + logits = None + else: + if self.is_pooling_model: + return self._pool(hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np, + kv_connector_output) + + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + if broadcast_pp_output: + model_output_broadcast_data = { + "logits": logits.contiguous(), + } if logits is not None else {} + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) + assert model_output_broadcast_data is not None + logits = model_output_broadcast_data["logits"] + + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + self.apply_grammar_bitmask(scheduler_output, logits) + + with record_function_or_nullcontext("Sample"): + sampler_output = self._sample(logits, spec_decode_metadata) + + with record_function_or_nullcontext("Bookkeep"): + assert isinstance(hidden_states, torch.Tensor) + ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync(scheduler_output, sampler_output, + logits, hidden_states, + num_scheduled_tokens) + + if self.speculative_config: + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + with record_function_or_nullcontext("EPLB"): + self.eplb_step() + + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) + 
+ if not self.use_async_scheduling: + return output + + return AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = self._draft_token_ids + self._draft_token_ids = None + + assert self._draft_req_ids + req_ids = self._draft_req_ids + self._draft_req_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + + def propose_draft_token_ids( + self, + input_batch: InputBatch, + sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: Optional[torch.Tensor], + ) -> Union[list[list[int]], torch.Tensor]: + num_scheduled_tokens = input_batch.total_num_tokens + if self.speculative_config.method == "ngram": + assert isinstance(self.drafter, NgramProposer) + draft_token_ids = self.propose_ngram_draft_token_ids( + input_batch, sampled_token_ids) + elif self.speculative_config.method == "medusa": + assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. + hidden_states = sample_hidden_states + else: + indices = [] + offset = 0 + for num_draft, tokens in zip( + input_batch.spec_decode_metadata.num_draft_tokens, + sampled_token_ids): + indices.append(offset + len(tokens) - 1) + offset += num_draft + 1 + indices = torch.tensor(indices, device=self.device) + hidden_states = sample_hidden_states[indices] + + draft_token_ids = self.drafter.propose( + target_hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + elif self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # TODO(woosuk): Refactor the loop. + req_ids = input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = self.req_states[req_id] + seq_len = (req_state.num_computed_tokens + + input_batch.num_scheduled_tokens[i]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + + if input_batch.spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions.gpu[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + else: + # TODO(woosuk): Refactor this. 
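+ # Number of rejected positions per request:
+ #   rejected_i = num_draft_i + 1 - len(sampled_token_ids[i])  (0 if no drafts),
+ # e.g. 3 draft tokens with 2 tokens ultimately kept -> 3 + 1 - 2 = 2 rejected.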
+ num_draft_tokens = ( + input_batch.spec_decode_metadata.num_draft_tokens) + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + input_batch.spec_decode_common_attn_metadata, + num_rejected_tokens_cpu) + + target_token_ids = self.input_ids.gpu[token_indices] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions.gpu[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + mm_embeds = None + if self.supports_mm_inputs: + mm_embeds = self._gather_mm_embeddings(input_batch, + shift_computed_tokens=1) + + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + mm_embeds=mm_embeds, + ) + return draft_token_ids + + def propose_ngram_draft_token_ids( + self, + input_batch: InputBatch, + sampled_token_ids: list[list[int]], + ) -> list[list[int]]: + # TODO(woosuk): Optimize. + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # # Skip requests that require sampling parameters that are not + # # supported with speculative decoding. + # req_id = input_batch.req_ids[i] + # if req_id in self.requests.spec_decode_unsupported_reqs: + # draft_token_ids.append([]) + # continue + + num_tokens = self.req_states.num_tokens.np[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + + drafter_output = self.drafter.propose( + self.req_states.token_ids.np[i, :num_tokens]) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids + + def update_config(self, overrides: dict[str, Any]) -> None: + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + + def load_model(self, eep_scale_up: bool = False) -> None: + """ + Args: + eep_scale_up: the model loading is for elastic EP scale up. 
+ """ + logger.info("Starting to load model %s...", self.model_config.model) + if eep_scale_up: + from vllm.distributed.parallel_state import get_ep_group + num_local_physical_experts = torch.empty(1, + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = int(num_local_physical_experts.item()) + new_ep_size = get_ep_group().world_size + global_expert_load, old_global_expert_indices = ( + EplbState.recv_state()) + num_logical_experts = global_expert_load.shape[1] + self.parallel_config.eplb_config.num_redundant_experts = ( + num_local_physical_experts * new_ep_size - num_logical_experts) + assert old_global_expert_indices.shape[ + 1] % num_local_physical_experts == 0 + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + else: + global_expert_load = None + old_global_expert_indices = None + rank_mapping = None + + with DeviceMemoryProfiler() as m: + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) + if hasattr(self, "drafter"): + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) + if self.use_aux_hidden_state_outputs: + if supports_eagle3(self.model): + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) + else: + raise RuntimeError( + "Model does not support EAGLE3 interface but " + "aux_hidden_state_outputs was requested") + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) + prepare_communication_buffer_for_model(self.model) + + if is_mixture_of_experts( + self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", + self.model_config.model) + self.eplb_state = EplbState.build( + self.model, + self.device, + self.parallel_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + + if ( + self.vllm_config.compilation_config.level == \ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + ): + backend = self.vllm_config.compilation_config.init_backend( + self.vllm_config) + compilation_counter.dynamo_as_is_count += 1 + self.model.compile( + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + return + # for other compilation levels, cudagraph behavior is controlled by + # CudagraphWraper and CudagraphDispatcher of vllm. + + # wrap the model with full cudagraph wrapper if needed. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + + def reload_weights(self) -> None: + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." 
+ model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) + + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + model = self.get_model() + TensorizerLoader.save_model( + model, + tensorizer_config=tensorizer_config, + model_config=self.model_config, + ) + + def _get_prompt_logprobs_dict( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, Optional[LogprobsTensors]]: + num_prompt_logprobs_dict = self.req_states.num_prompt_logprobs + if not num_prompt_logprobs_dict: + return {} + + in_progress_dict = self.req_states.in_progress_prompt_logprobs_cpu + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + + # Since prompt logprobs are a rare feature, prioritize simple, + # maintainable loop over optimal performance. + completed_prefill_reqs = [] + for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): + num_tokens = num_scheduled_tokens[req_id] + + # Get metadata for this request. + request = self.req_states[req_id] + num_prompt_tokens = len(request.prompt_token_ids) + prompt_token_ids = torch.tensor(request.prompt_token_ids).to( + self.device, non_blocking=True) + + # Set up target LogprobsTensors object. + logprobs_tensors = in_progress_dict.get(req_id) + if not logprobs_tensors: + # Create empty logprobs CPU tensors for the entire prompt. + # If chunked, we'll copy in slice by slice. + logprobs_tensors = LogprobsTensors.empty_cpu( + num_prompt_tokens - 1, num_prompt_logprobs + 1) + in_progress_dict[req_id] = logprobs_tensors + + # Determine number of logits to retrieve. + start_idx = request.num_computed_tokens + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(req_id) + prompt_logprobs_dict[req_id] = logprobs_tensors + + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + + # Get the logits corresponding to this req's prompt tokens. + # If this is a partial request (i.e. chunked prefill), + # then there is prompt logprob generated for each index. + req_idx = self.input_batch.req_id_to_index[req_id] + offset = self.query_start_loc.np[req_idx].item() + prompt_hidden_states = hidden_states[offset:offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states, None) + + # Get the "target" tokens for each index. For prompt at index i, + # the token at prompt index i+1 is the "sampled" token we want + # to gather the logprob for. + tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + + # Compute prompt logprobs. + logprobs = self.sampler.compute_logprobs(logits) + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + logprobs, num_prompt_logprobs, tgt_token_ids) + + # Transfer GPU->CPU async. 
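+ # Write this chunk into the preallocated CPU tensors with non-blocking copies; they are synchronized once after the loop (see the _sync_device call below).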
+ chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_( + token_ids, non_blocking=True) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, + non_blocking=True) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_( + ranks, non_blocking=True) + + # Remove requests that have completed prefill from the batch + # num_prompt_logprobs_dict. + for req_id in completed_prefill_reqs: + del num_prompt_logprobs_dict[req_id] + del in_progress_dict[req_id] + + # Must synchronize the non-blocking GPU->CPU transfers. + if prompt_logprobs_dict: + self._sync_device() + + return prompt_logprobs_dict + + def _get_nans_in_logits( + self, + req_ids: list[str], + logits: Optional[torch.Tensor], + ) -> dict[str, int]: + try: + if logits is None: + return {req_id: 0 for req_id in req_ids} + + num_nans_in_logits = {} + num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() + for i, req_id in enumerate(req_ids): + num_nans_in_logits[req_id] = (int(num_nans_for_index[i]) + if num_nans_for_index is not None + and i < logits.shape[0] else 0) + return num_nans_in_logits + except IndexError: + return {} + + @contextmanager + def maybe_randomize_inputs(self, input_ids: torch.Tensor): + """ + Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. + This is to help balance expert-selection + - during profile_run + - during DP rank dummy run + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 + if not randomize_inputs: + yield + else: + import functools + + @functools.cache + def rand_input_ids() -> torch.Tensor: + return torch.randint_like( + self.input_ids.gpu, + low=0, + high=self.model_config.get_vocab_size(), + dtype=input_ids.dtype) + + logger.debug_once("Randomizing dummy data for DP Rank") + input_ids.copy_(rand_input_ids()[:input_ids.size(0)], + non_blocking=True) + yield + input_ids.fill_(0) + + def _get_mm_dummy_batch( + self, + modality: str, + max_items_per_batch: int, + ) -> BatchedTensorInputs: + """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_counts={modality: 1}, + cache=self.mm_budget.cache, + ) + dummy_mm_data = dummy_decoder_data.multi_modal_data + + # Result in the maximum GPU consumption of the model + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch + + return next(mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + )) + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + force_attention: bool = False, + uniform_decode: bool = False, + skip_eplb: bool = False, + is_profile: bool = False, + remove_lora: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run a dummy forward pass to warm up/profile run or capture the + CUDA graph for the model. + + Args: + num_tokens: Number of tokens to run the dummy forward pass. + cudagraph_runtime_mode: used to control the behavior. + - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run + - CUDAGraphMode.PIECEWISE: Piecewise cudagraph. + - CUDAGraphMode.FULL: Full cudagraph, attention metadata is + needed. 
+ force_attention: If True, always create attention metadata. Used to + warm up attention backend when mode is NONE. + uniform_decode: If True, the batch is a uniform decode batch. + skip_eplb: If True, skip EPLB state update. + is_profile: If True, this is a profile run. + remove_lora: If False, dummy LoRAs are not destroyed after the run + """ + assert cudagraph_runtime_mode in { + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL + } + + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + + # If cudagraph_mode.decode_mode() == FULL and + # cudagraph_mode.seperate_routine(). This means that we are using + # different graphs and/or modes for mixed prefill-decode batches vs. + # uniform decode batches. A uniform decode batch means that all + # requests have identical query length, except a potential virtual + # request (shorter) in the batch account for padding. + # Uniform decode batch could either be common pure decode, where + # max_query_len == 1, or speculative decode, where + # max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_decode else \ + num_tokens + + # Set num_scheduled_tokens based on num_tokens and max_num_seqs + # for dummy run with LoRA so that the num_reqs collectively + # has num_tokens in total. + assert num_tokens <= self.scheduler_config.max_num_batched_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + if uniform_decode: + num_reqs = cdiv(num_tokens, max_query_len) + assert num_reqs <= max_num_reqs, \ + "Do not capture num_reqs > max_num_reqs for uniform batch" + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens = np.array(num_scheduled_tokens_list, + dtype=np.int32) + + attn_metadata: Optional[dict[str, Any]] = None + + # If force_attention is True, we always capture attention. Otherwise, + # it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + attn_metadata = {} + + # Make sure max_model_len is used at the graph capture time. + self.seq_lens.np[:num_reqs] = self.max_model_len + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + + 1], + seq_lens=self.seq_lens.gpu[:num_reqs], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + num_computed_tokens_cpu=self.req_states.num_computed_tokens. + cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=self.max_model_len, + block_table_tensor=self.block_tables. + block_tables[kv_cache_group_id][:num_reqs], + slot_mapping=self.block_tables. 
+ slot_mappings[kv_cache_group_id][:num_tokens], + causal=True, + ) + + for attn_group in self.attn_groups[kv_cache_group_id]: + attn_metadata_i = attn_group.metadata_builder\ + .build_for_cudagraph_capture(common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + + with self.maybe_dummy_run_with_lora(self.lora_config, + num_scheduled_tokens, remove_lora): + if self.supports_mm_inputs: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + model_kwargs = { + **self._init_model_kwargs(num_tokens), + **self._dummy_mm_kwargs(num_reqs), + } + else: + input_ids = self.input_ids.gpu[:num_tokens] + inputs_embeds = None + model_kwargs = self._init_model_kwargs(num_tokens) + + if self.uses_mrope: + positions = self.mrope_positions.gpu[:, :num_tokens] + else: + positions = self.positions.gpu[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + intermediate_tensors = self.sync_and_slice_intermediate_tensors( + num_tokens, None, False) + if cudagraph_runtime_mode == CUDAGraphMode.NONE: + batch_descriptor = None + else: + # filter out the valid batch descriptor + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + # sanity check + assert cudagraph_runtime_mode == _cg_mode, ( + f"Cudagraph runtime mode mismatch at dummy_run. " + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): + outputs = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = outputs + else: + hidden_states = outputs + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.dummy_run(num_tokens) + + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. + if not skip_eplb: + self.eplb_step(is_dummy=True, is_profile=is_profile) + + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + return hidden_states, hidden_states[logit_indices] + + @torch.inference_mode() + def _dummy_sampler_run( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # The dummy hidden states may contain special values, + # like `inf` or `nan`. + # To avoid breaking the sampler, we use a random tensor here instead. 
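+ # torch.rand_like keeps the shape/dtype/device but guarantees finite values in [0, 1).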
+ hidden_states = torch.rand_like(hidden_states) + + logits = self.model.compute_logits(hidden_states, None) + num_reqs = logits.size(0) + + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.5), + all_greedy=False, + all_random=False, + top_p=dummy_tensors(0.9), + top_k=dummy_tensors(logits.size(1) - 1), + generators={}, + max_num_logprobs=None, + no_penalties=True, + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + token_ids=None, + num_tokens=None, + num_prompt_tokens=None, + ) + try: + sampler_output = self.sampler(logits=logits, + sampling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up sampler with " + f"{num_reqs} dummy requests. Please try lowering " + "`max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e + if self.speculative_config: + draft_token_ids = [[0] for _ in range(num_reqs)] + dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids, self.device) + + num_tokens = sum(len(ids) for ids in draft_token_ids) + # draft_probs = torch.randn( + # num_tokens, logits.shape[-1], device=self.device, + # dtype=logits.dtype) + draft_probs = None + target_logits = torch.randn(num_tokens, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) + # NOTE(woosuk): Here, we should use int32 because the sampler uses + # int32 for bonus_token_ids. If the dtype mismatches, re-compilation + # will occur at runtime. + bonus_token_ids = torch.zeros(num_reqs, + device=self.device, + dtype=torch.int32) + self.rejection_sampler( + dummy_spec_decode_metadata, + draft_probs, + target_logits, + bonus_token_ids, + dummy_metadata, + ) + return sampler_output + + def _dummy_pooler_run_task( + self, + hidden_states: torch.Tensor, + task: PoolingTask, + ) -> PoolerOutput: + num_tokens = hidden_states.shape[0] + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + req_num_tokens = num_tokens // num_reqs + + dummy_prompt_lens = torch.tensor( + num_scheduled_tokens_list, + device="cpu", + ) + dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), + dtype=torch.int32, + device=self.device) + + model = cast(VllmModelForPooling, self.get_model()) + dummy_pooling_params = PoolingParams(task=task) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(dummy_pooling_params) + + dummy_metadata = PoolingMetadata( + prompt_lens=dummy_prompt_lens, + prompt_token_ids=dummy_token_ids, + pooling_params=[dummy_pooling_params] * num_reqs, + ) + + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) + + try: + return model.pooler(hidden_states=hidden_states, + pooling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up pooler " + f"({task=}) with {num_reqs} dummy requests. 
Please try " + "lowering `max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e + + @torch.inference_mode() + def _dummy_pooler_run( + self, + hidden_states: torch.Tensor, + ) -> PoolerOutput: + # Find the task that has the largest output for subsequent steps + output_size = dict[PoolingTask, float]() + for task in self.get_supported_pooling_tasks(): + # Run a full batch with each task to ensure none of them OOMs + output = self._dummy_pooler_run_task(hidden_states, task) + output_size[task] = output.get_data_nbytes() + del output # Allow GC + + max_task = max(output_size.items(), key=lambda x: x[1])[0] + return self._dummy_pooler_run_task(hidden_states, max_task) + + def profile_run(self) -> None: + # Profile with multimodal encoder & encoder cache. + if self.supports_mm_inputs: + if self.model_config.multimodal_config.skip_mm_profiling: + logger.info( + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None + + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] + + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) + + # Run multimodal encoder. + dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) + + # Add `is_profile` here to pre-allocate communication buffers + hidden_states, last_hidden_states \ + = self._dummy_run(self.max_num_tokens, is_profile=True) + if get_pp_group().is_last_rank: + if self.is_pooling_model: + output = self._dummy_pooler_run(hidden_states) + else: + output = self._dummy_sampler_run(last_hidden_states) + else: + output = None + self._sync_device() + del hidden_states, output + self.encoder_cache.clear() + gc.collect() + + def capture_model(self) -> None: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + logger.warning( + "Skipping CUDA graph capture. To turn on CUDA graph capture, " + "ensure `cudagraph_mode` was not manually set to `NONE`") + return + else: + self.initialize_cudagraph_capture() + + compilation_counter.num_gpu_runner_capture_triggers += 1 + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + @contextmanager + def freeze_gc(): + # Optimize garbage collection during CUDA graph capture. + # Clean up, then freeze all remaining objects from being included + # in future collections. + gc.collect() + should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + + # Trigger CUDA graph capture for specific shapes. 
+ # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + set_cudagraph_capturing_enabled(True) + with freeze_gc(), graph_capture(device=self.device): + cudagraph_mode = self.compilation_config.cudagraph_mode + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + cudagraph_runtime_mode = cudagraph_mode.mixed_mode() + + compilation_cases = list(reversed(self.cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=False) + + # Capture full cudagraph for uniform decode batches if we have + # dont already have full mixed prefill-decode cudagraphs + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + cudagraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len + decode_cudagraph_batch_sizes = [ + x for x in self.cudagraph_batch_sizes if + x <= max_num_tokens and x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases=compilation_cases_decode, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True) + + # Disable cudagraph capturing globally, so any unexpected cudagraph + # capturing will be detected and raise an error after here. + # Note: We don't put it into graph_capture context manager because + # we may do lazy capturing in future that still allows capturing + # after here. + set_cudagraph_capturing_enabled(False) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) + + def _capture_cudagraphs(self, compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool): + assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ + cudagraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE] + + # Only rank 0 should print progress bar during capture + if is_global_first_rank(): + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA graphs ({}, {})".format( + "decode" if uniform_decode else "mixed prefill-decode", + cudagraph_runtime_mode.name)) + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens in compilation_cases: + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. 
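+ # Hence force_attention below is derived from the requested capture mode, so FULL-mode warmups still build attention metadata even though they run with CUDAGraphMode.NONE.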
+ force_attention = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False) + self.maybe_remove_all_loras(self.lora_config) + + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize the attention backends and attention metadata builders. + """ + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" + + def get_attn_backends_for_layers( + layer_names: list[str] + ) -> dict[type[AttentionBackend], list[str]]: + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + attn_backends = {} + attn_backend_layers = defaultdict(list) + # Dedupe based on full class name; this is a bit safer than + # using the class itself as the key because when we create dynamic + # attention backend subclasses (e.g. ChunkedLocalAttention) unless + # they are cached correctly, there will be different objects per + # layer. + for layer_name in layer_names: + attn_backend = layers[layer_name].get_attn_backend() + + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_fast_prefill_custom_backend( + "FastPrefill", + attn_backend, + ) + + key = attn_backend.full_cls_name() + attn_backends[key] = attn_backend + attn_backend_layers[key].append(layer_name) + return { + attn_backends[k]: v + for k, v in attn_backend_layers.items() + } + + def create_attn_groups( + attn_backends_map: dict[AttentionBackend, list[str]], + kv_cache_spec: KVCacheSpec, + ) -> list[AttentionGroup]: + attn_groups: list[AttentionGroup] = [] + for attn_backend, layer_names in attn_backends_map.items(): + attn_metadata_builder_i = attn_backend.get_builder_cls()( + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, + ) + attn_group = AttentionGroup(attn_backend, + attn_metadata_builder_i, + layer_names) + attn_groups.append(attn_group) + return attn_groups + + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) + self.attn_groups.append( + create_attn_groups(attn_backends, kv_cache_spec)) + + def initialize_cudagraph_capture(self) -> None: + min_cg_support = AttentionCGSupport.ALWAYS + min_cg_builder_name = None + + for attn_group in self._attn_group_iterator(): + builder = attn_group.metadata_builder + if builder.cudagraph_support.value < min_cg_support.value: + min_cg_support = builder.cudagraph_support + min_cg_builder_name = builder.__class__.__name__ + + # Flexible resolve the cudagraph mode + cudagraph_mode = self.compilation_config.cudagraph_mode + # check cudagraph for mixed batch is supported + if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ + and min_cg_support != AttentionCGSupport.ALWAYS: + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})") + if min_cg_support == AttentionCGSupport.NEVER: + # if not supported any full cudagraphs, just raise it. 
+ msg += "; please try cudagraph_mode=PIECEWISE, and "\ + "make sure compilation level is piecewise" + raise ValueError(msg) + + # attempt to resolve the full cudagraph related mode + if self.compilation_config.splitting_ops_contain_attention(): + msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_AND_PIECEWISE + else: + msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_DECODE_ONLY + logger.warning(msg) + + # check that if we are doing spec-decode + decode full-cudagraphs it is + # supported + if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 and min_cg_support.value + < AttentionCGSupport.UNIFORM_BATCH.value): + msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})") + if self.compilation_config.splitting_ops_contain_attention(): + msg += "; setting cudagraph_mode=PIECEWISE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE + else: + msg += "; setting cudagraph_mode=NONE" + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.NONE + logger.warning(msg) + + # double check that we can support full cudagraph if they are requested + # even after automatic downgrades + if cudagraph_mode.has_full_cudagraphs() \ + and min_cg_support == AttentionCGSupport.NEVER: + raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise") + + # Trigger cudagraph dispatching keys initialization here (after + # initializing attn backends). + self.cudagraph_dispatcher.initialize_cudagraph_keys( + self.compilation_config.cudagraph_mode, + self.uniform_decode_query_len) + + def _allocate_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. + + Args: + kv_cache_config: The KV cache config + Returns: + dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
+ """ + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=self.device) + for layer_name in kv_cache_tensor.shared_by: + kv_cache_raw_tensors[layer_name] = tensor + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" + return kv_cache_raw_tensors + + def _attn_group_iterator(self) -> Iterator[AttentionGroup]: + return itertools.chain.from_iterable(self.attn_groups) + + def _kv_cache_spec_attn_group_iterator( + self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + if not self.kv_cache_config.kv_cache_groups: + return + for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): + for attn_group in attn_groups: + yield self.kv_cache_config.kv_cache_groups[ + kv_cache_spec_id].kv_cache_spec, attn_group + + def _reshape_kv_cache_tensors( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape and dtype. + + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + kv_caches: dict[str, torch.Tensor] = {} + has_attn, has_mamba = False, False + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + attn_backend = group.backend + for layer_name in group.layer_names: + if layer_name in self.runner_only_attn_layers: + continue + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) + if isinstance(kv_cache_spec, AttentionSpec): + has_attn = True + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + try: + kv_cache_stride_order = \ + attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len( + kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple( + range(len(kv_cache_shape))) + # The allocation respects the backend-defined stride order + # to ensure the semantic remains consistent for each + # backend. We first obtain the generic kv cache shape and + # then permute it according to the stride order which could + # result in a non-contiguous tensor. + kv_cache_shape = tuple(kv_cache_shape[i] + for i in kv_cache_stride_order) + # Maintain original KV shape view. 
+ inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape).permute( + *inv_order) + elif isinstance(kv_cache_spec, MambaSpec): + has_mamba = True + raw_tensor = kv_cache_raw_tensors[layer_name] + state_tensors = [] + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size + + kv_caches[layer_name] = state_tensors + else: + raise NotImplementedError + + if has_attn and has_mamba: + self._update_hybrid_attention_mamba_layout(kv_caches) + + return kv_caches + + def _update_hybrid_attention_mamba_layout( + self, kv_caches: dict[str, torch.Tensor]) -> None: + """ + Update the layout of attention layers from (2, num_blocks, ...) to + (num_blocks, 2, ...). + + Args: + kv_caches: The KV cache buffer of each layer. + """ + + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for layer_name in group.layer_names: + kv_cache = kv_caches[layer_name] + if (isinstance(kv_cache_spec, AttentionSpec) + and kv_cache.shape[0] == 2): + assert kv_cache.shape[1] != 2, \ + "Fail to determine whether the layout is " \ + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + f"a tensor of shape {kv_cache.shape}" + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_(size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, + *kv_cache.stride()[2:])) + + def initialize_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initialize the memory buffer for KV cache. + + Args: + kv_cache_config: The KV cache config + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. 
+ Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + self.runner_only_attn_layers, + ) + + if self.cache_config.kv_sharing_fast_prefill: + # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other + # similar KV sharing setups, only the layers that generate KV caches + # are involved in the prefill phase, enabling prefill to early exit. + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) + for layer_name in reversed(attn_layers): + if layer_name in self.shared_kv_cache_layers: + self.kv_sharing_fast_prefill_eligible_layers.add( + layer_name) + else: + break + + def init_block_tables(self, kv_cache_config: KVCacheConfig) -> None: + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + self.block_tables = BlockTables( + block_sizes=block_sizes, + max_num_reqs=self.max_num_reqs, + max_num_cached_reqs=2 * self.max_num_reqs, + max_num_batched_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + pin_memory=self.pin_memory, + ) + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + kv_cache_config = deepcopy(kv_cache_config) + self.kv_cache_config = kv_cache_config + self.init_block_tables(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) + self.initialize_attn_backend(kv_cache_config) + kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # validate all draft model layers belong to the same kv cache + # group + self.drafter.validate_same_kv_cache_group(kv_cache_config) + + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + if self.device.type == 'xpu': + get_kv_transfer_group().set_host_xfer_buffer_ops( + copy_kv_blocks) + + if self.dcp_world_size > 1: + layer_names = self.attn_groups[0][0].layer_names + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + for layer in layers.values(): + assert layer.impl.need_to_return_lse_for_decode, ( + "DCP requires attention impls to return" + " the softmax lse for decode, but the impl " + f"{layer.impl.__class__.__name__} " + "does not return the softmax lse for decode.") + + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. 
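+ Encoder-only layers are also tracked in `self.runner_only_attn_layers`, so no persistent KV cache memory is allocated for them; the extra group mainly ensures attention metadata is still built for these layers.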
+ """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + kv_cache_spec: dict[str, KVCacheSpec] = {} + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + + # TODO: Support other attention modules, e.g., cross-attention + # TODO(lucas): move the attention specs into the model layers like + # the attention backends + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + use_mla=use_mla) + elif self.attention_chunk_size is not None \ + and isinstance(attn_module, ChunkedLocalAttention): + kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + attention_chunk_size=self.attention_chunk_size, + use_mla=use_mla) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. 
+ continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) + if len(mamba_layers) > 0: + if self.vllm_config.speculative_config is not None: + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for Mamba yet.") + max_model_len = self.vllm_config.model_config.max_model_len + + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) + + # Set block_size to max_model_len, so that mamba model will always + # have only one block in the KV cache. + for layer_name, mamba_module in mamba_layers.items(): + kv_cache_spec[layer_name] = MambaSpec( + shapes=mamba_module.get_state_shape(), + dtypes=mamba_module.get_state_dtype(), + block_size=max_model_len, + page_size_padded=page_size_padded, + mamba_type=mamba_module.mamba_type) + + return kv_cache_spec + + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a cuda wise stream sync, which + # would block other copy ops from other cuda streams. + # A cuda event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. + pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return pinned.tolist() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d55af6185ec90..d8a9c36870f7e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -49,6 +49,7 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, @@ -69,7 +70,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsLists, LogprobsTensors, ModelRunnerOutput, SamplerOutput) from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessors +from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -78,9 +79,9 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.gpu_worker_states import RequestState from vllm.v1.worker.gpu_block_table import BlockTables from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs -from vllm.v1.worker.gpu_worker_states import RequestState from 
vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -252,15 +253,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.rejection_sampler = RejectionSampler() # Request states. + self.max_num_cached_reqs = 2 * self.max_num_reqs self.req_states = RequestState( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, max_num_batched_tokens=self.max_num_tokens, - max_num_cached_reqs=2 * self.max_num_reqs, + max_num_cached_reqs=self.max_num_cached_reqs, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.cache_config.block_size], ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -282,6 +283,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Persistent buffers for CUDA graphs. self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, + dtype=torch.int32) + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. @@ -289,13 +295,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.hidden_size, dtype=self.dtype, numpy=False) - self.positions = self._make_buffer(self.max_num_tokens, - dtype=torch.int64) - self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, - dtype=torch.int32) - self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) - self.cu_num_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -315,9 +314,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None - self.idx_mapping = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) - # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it # means this layer will perform attention using the keys and values @@ -350,7 +346,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Cached outputs. self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None - self._draft_req_ids: Optional[list[str]] = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_model_len, 1), @@ -372,7 +367,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with_numpy=numpy) def _init_model_kwargs(self, num_tokens: int): - return {} model_kwargs = dict[str, Any]() if not self.is_pooling_model: @@ -424,7 +418,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): The SamplingMetadata is updated and copied to the GPU if there is a new/resumed/paused/finished request in the batch. """ - # Remove the finished requests from the persistent batch. + # Remove finished requests from the cached states. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and # then resubmitted with the same ID. 
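A minimal sketch of the persistent-buffer pattern behind the input_ids/positions/query_start_loc/seq_lens buffers added above (buffer names and sizes here are illustrative): each buffer is allocated once at its maximum size, and every step writes into and slices a prefix of the same storage so a captured CUDA graph can replay it.

import torch

max_num_tokens = 8
input_ids = torch.zeros(max_num_tokens, dtype=torch.int32)   # allocated once
positions = torch.zeros(max_num_tokens, dtype=torch.int64)   # allocated once

def fill_step(token_ids: list[int]) -> torch.Tensor:
    # Write this step's tokens into the head of the persistent buffers and
    # hand out a prefix view; no per-step allocation is performed.
    n = len(token_ids)
    input_ids[:n] = torch.tensor(token_ids, dtype=torch.int32)
    positions[:n] = torch.arange(n, dtype=torch.int64)
    return input_ids[:n]

assert fill_step([5, 7, 9]).tolist() == [5, 7, 9]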
In this case, we treat them as two @@ -439,12 +433,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.encoder_cache.pop(mm_hash, None) req_indices: list[int] = [] - cu_num_new_blocks: list[list[int]] = [ + cu_num_new_blocks = tuple( [0] for _ in range(self.block_tables.num_kv_cache_groups) - ] - new_block_ids: list[list[int]] = [ + ) + new_block_ids = tuple( [] for _ in range(self.block_tables.num_kv_cache_groups) - ] + ) overwrite: list[bool] = [] # Add new requests to the cached states. @@ -510,6 +504,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_data = self.req_states.req_data[req_idx] prompt_len = self.req_states.num_prompt_tokens.np[req_idx] prompt_token_ids = self.req_states.token_ids.np[req_idx, :prompt_len] + prompt_token_ids = prompt_token_ids.tolist() image_grid_thw = [] video_grid_thw = [] @@ -583,6 +578,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # batch_idx -> req_id req_ids = sorted(scheduler_output.num_scheduled_tokens, key=scheduler_output.num_scheduled_tokens.get) + num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[i] for i in req_ids], + dtype=np.int32) + # batch_idx -> req_idx idx_mapping_list = [ self.req_states.req_id_to_index[req_id] for req_id in req_ids @@ -629,11 +628,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): slot_mappings = self.block_tables.compute_slot_mappings( query_start_loc, self.positions.gpu[:total_num_scheduled_tokens]) + # Calculate M-RoPE positions. + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._calc_mrope_positions(req_ids, num_scheduled_tokens) - # Optimization: To avoid gather and scatter, copy the whole M-RoPE - # tensor from CPU to GPU although only a part of it is used. - self.mrope_positions.copy_to_gpu() + self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -722,7 +723,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - num_computed_tokens_np, num_common_prefix_blocks, kv_cache_group_spec.kv_cache_spec, builder, @@ -736,9 +736,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i - # # Hot-Swap lora model - # if self.lora_config: - # self.set_active_loras(input_batch, num_scheduled_tokens) + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) return InputBatch( req_ids=req_ids, @@ -758,7 +758,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, - num_computed_tokens: np.ndarray, num_common_prefix_blocks: int, kv_cache_spec: KVCacheSpec, attn_metadata_builder: AttentionMetadataBuilder, @@ -824,7 +823,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # and the second kernel will get an empty input. While this is not # a fundamental problem, our current implementation does not support # this case. 
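A quick numeric illustration of the constraint described above, matching the clamping and block-size rounding applied in the code just below (the values are made up):

import numpy as np

block_size = 16
num_computed_tokens = np.array([100, 37, 250])   # per request in the batch
common_prefix_len = 120                          # from the shared block count

# The usable common prefix cannot exceed the least-computed request ...
common_prefix_len = min(common_prefix_len, int(num_computed_tokens.min()))  # 37
# ... and must be a whole number of blocks.
common_prefix_len = common_prefix_len // block_size * block_size            # 32
assert common_prefix_len == 32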
- common_prefix_len = min(common_prefix_len, num_computed_tokens.min()) + num_reqs = len(num_scheduled_tokens) + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size) @@ -848,6 +850,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) return common_prefix_len if use_cascade else 0 + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + self.mrope_positions.cpu[:, dst_start:dst_end] = ( + req.mrope_positions[:, src_start:src_end]) + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + MRotaryEmbedding.get_next_input_positions_tensor( + out=self.mrope_positions.np, + out_offset=dst_start, + mrope_position_delta=req.mrope_position_delta, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + def _prepare_spec_decode_metadata( self, req_ids: list[str], @@ -872,58 +923,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_ids.gpu, ) - def _calc_mrope_positions( - self, - req_ids: list[str], - query_lens: np.ndarray, - ): - mrope_pos_ptr = 0 - for i, req_id in enumerate(req_ids): - req_idx = self.req_states.req_id_to_index[req_id] - req_data = self.req_states.req_data[req_idx] - assert req_data.mrope_positions is not None - - num_computed_tokens = self.req_states.num_computed_tokens.np[req_idx] - num_scheduled_tokens = query_lens[i] - num_prompt_tokens = self.req_states.num_prompt_tokens.np[req_idx] - - if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) - else: - prompt_part_len = num_scheduled_tokens - completion_part_len = 0 - - assert num_scheduled_tokens == prompt_part_len + completion_part_len - - if prompt_part_len > 0: - # prompt's mrope_positions are pre-computed - dst_start = mrope_pos_ptr - dst_end = mrope_pos_ptr + prompt_part_len - src_start = num_computed_tokens - src_end = num_computed_tokens + prompt_part_len - - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req_data.mrope_positions[:, src_start:src_end]) - mrope_pos_ptr += prompt_part_len - - if completion_part_len > 0: - # compute 
completion's mrope_positions on-the-fly - dst_start = mrope_pos_ptr - dst_end = mrope_pos_ptr + completion_part_len - - MRotaryEmbedding.get_next_input_positions_tensor( - out=self.mrope_positions.np, - out_offset=dst_start, - mrope_position_delta=req_data.mrope_position_delta, - context_len=num_computed_tokens + prompt_part_len, - num_new_tokens=completion_part_len, - ) - - mrope_pos_ptr += completion_part_len - def _prepare_kv_sharing_fast_prefill( self, logits_indices: torch.Tensor, @@ -959,14 +958,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # list of tuple (mm_hash, position_info) mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_idx = self.req_states.req_id_to_index[req_id] - req_data = self.req_states.req_data[req_idx] + req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: - mm_hash = req_data.mm_hashes[mm_input_id] - mm_kwargs.append(req_data.mm_kwargs[mm_input_id]) + mm_hash = req_state.mm_hashes[mm_input_id] + mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) mm_hashes_pos.append( - (mm_hash, req_data.mm_positions[mm_input_id])) + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1008,19 +1006,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _gather_mm_embeddings( self, - input_batch: InputBatch, + scheduler_output: "SchedulerOutput", shift_computed_tokens: int = 0, ) -> list[torch.Tensor]: mm_embeds: list[torch.Tensor] = [] - for i, req_id in enumerate(input_batch.req_ids): - num_scheduled_tokens = input_batch.num_scheduled_tokens[i] - req_idx = self.req_states.req_id_to_index[req_id] - num_computed_tokens = ( - self.req_states.num_computed_tokens.np[req_idx] + - shift_computed_tokens) - req_data = self.req_states.req_data[req_idx] - mm_positions = req_data.mm_positions - mm_hashes = req_data.mm_hashes + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = \ + req_state.num_computed_tokens + shift_computed_tokens + mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1135,7 +1132,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # request in the batch, as the logit indices are offset by this amount. 
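A worked example of the prompt/completion split performed by _calc_mrope_positions above: with 10 prompt tokens, 8 already computed, and 5 scheduled this step, 2 positions still come from the precomputed prompt M-RoPE and 3 are computed on the fly.

num_prompt_tokens, num_computed_tokens, num_scheduled_tokens = 10, 8, 5
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
    prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
    completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
else:
    prompt_part_len, completion_part_len = num_scheduled_tokens, 0
assert num_scheduled_tokens == prompt_part_len + completion_part_len
assert (prompt_part_len, completion_part_len) == (2, 3)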
struct_out_req_batch_indices: dict[str, int] = {} cumulative_offset = 0 - seq = sorted(self.req_states.req_id_to_index.items(), key=lambda x: x[1]) + seq = sorted(self.input_batch.req_id_to_index.items(), + key=lambda x: x[1]) for req_id, batch_index in seq: logit_index = batch_index + cumulative_offset cumulative_offset += len( @@ -1260,17 +1258,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _pool( self, hidden_states: torch.Tensor, - num_scheduled_tokens: int, - num_scheduled_tokens_np: np.ndarray, + input_batch: InputBatch, kv_connector_output: Optional[KVConnectorOutput], ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" - hidden_states = hidden_states[:num_scheduled_tokens] - pooling_metadata = self.input_batch.get_pooling_metadata() + pooling_metadata = self.req_states.get_pooling_metadata() pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), device=hidden_states.device) seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] @@ -1288,7 +1280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return ModelRunnerOutput( req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_batch_idx, + req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, @@ -1299,25 +1291,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _preprocess( self, scheduler_output: "SchedulerOutput", + input_batch: InputBatch, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) - - # if self.cache_config.kv_sharing_fast_prefill: - # assert not self.input_batch.num_prompt_logprobs, ( - # "--kv-sharing-fast-prefill produces incorrect logprobs for " - # "prompt tokens, tokens, please disable it when the requests " - # "need prompt logprobs") - - # Prepare the decoder inputs. - input_batch = self._prepare_inputs(scheduler_output) + ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, + Optional[IntermediateTensors], dict[str, Any]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -1342,8 +1320,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_input_tokens += num_pad - # _prepare_inputs may reorder the batch, so we must gather multi - # modal outputs after that to ensure the correct order if self.supports_mm_inputs and get_pp_group().is_first_rank: # Run the multimodal encoder if any. 
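The cumulative-offset bookkeeping above shifts each request's logit row by the extra speculative rows contributed by earlier requests in the batch. A toy run with made-up per-request spec-token counts (names are illustrative):

spec_tokens_per_req = {"r0": 2, "r1": 0, "r2": 3}   # illustrative counts
req_id_to_index = {"r0": 0, "r1": 1, "r2": 2}

logit_index = {}
cumulative_offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_index[req_id] = batch_index + cumulative_offset
    cumulative_offset += spec_tokens_per_req[req_id]

assert logit_index == {"r0": 0, "r1": 3, "r2": 4}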
self._execute_mm_encoder(scheduler_output) @@ -1386,81 +1362,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - uniform_decode = (input_batch.max_query_len - == self.uniform_decode_query_len - and num_scheduled_tokens - == input_batch.num_reqs * input_batch.max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) - - # Run the model. - # Use persistent buffers for CUDA graphs. - with set_forward_context( - input_batch.attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ), self.maybe_get_kv_connector_output( - scheduler_output) as kv_connector_output: - - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None - - # Broadcast PP output for external_launcher (torchrun) - # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches - # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 - if not get_pp_group().is_last_rank: - # For mid-pipeline stages, return the hidden states. - assert isinstance(hidden_states, IntermediateTensors) - if not broadcast_pp_output: - hidden_states.kv_connector_output = kv_connector_output - return hidden_states - get_pp_group().send_tensor_dict(hidden_states.tensors, - all_gather_group=get_tp_group()) - logits = None - else: - sample_hidden_states = hidden_states[input_batch.logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - if broadcast_pp_output: - model_output_broadcast_data = { - "logits": logits.contiguous(), - } if logits is not None else {} - model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) - assert model_output_broadcast_data is not None - logits = model_output_broadcast_data["logits"] - - # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - self.apply_grammar_bitmask(scheduler_output, logits) + return ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_across_dp, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) def _sample( - self, logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata] + self, + logits: Optional[torch.Tensor], + input_batch: InputBatch, ) -> SamplerOutput: # Sample the next token and get logprobs if needed. 
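The spec-decode sampling path that begins here relies on a PyTorch property called out in the comments below: indexing with an index tensor (advanced indexing) returns a new tensor rather than a view, so in-place edits to the bonus/target logits cannot corrupt the original logits. A quick standalone check of that property:

import torch

logits = torch.zeros(4, 3)
bonus_logits = logits[torch.tensor([1, 3])]    # advanced indexing -> copy
bonus_logits += 100.0                          # in-place update on the copy
assert logits.abs().sum().item() == 0.0        # original logits untouched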
- sampling_metadata = self.req_states.make_sampling_metadata( - input_batch.idx_mapping) - if input_batch.spec_decode_metadata is None: + sampling_metadata = input_batch.sampling_metadata + spec_decode_metadata = input_batch.spec_decode_metadata + if spec_decode_metadata is None: sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, @@ -1471,8 +1392,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # logits tensor. This means any in-place operations on bonus_logits # won't affect the original logits tensor. assert logits is not None - bonus_logits = logits[ - input_batch.spec_decode_metadata.bonus_logits_indices] + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] sampler_output = self.sampler( logits=bonus_logits, sampling_metadata=sampling_metadata, @@ -1482,10 +1402,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Just like `bonus_logits`, `target_logits` is a new tensor with # separate storage from the original `logits` tensor. Therefore, # it is safe to update `target_logits` in place. - target_logits = logits[ - input_batch.spec_decode_metadata.target_logits_indices] + target_logits = logits[spec_decode_metadata.target_logits_indices] output_token_ids = self.rejection_sampler( - input_batch.spec_decode_metadata, + spec_decode_metadata, None, # draft_probs target_logits, bonus_token_ids, @@ -1496,9 +1415,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return sampler_output def _bookkeeping_sync( - self, scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, num_scheduled_tokens: int + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, + num_scheduled_tokens: int, ) -> tuple[ dict[str, int], Optional[LogprobsLists], @@ -1510,21 +1432,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: - num_nans_in_logits = self._get_nans_in_logits( - input_batch.req_ids, logits) + num_nans_in_logits = self._get_nans_in_logits(logits) # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. - discard_sampled_tokens_req_indices: list[int] = [] - for i, req_id in enumerate(input_batch.req_ids): - req_idx = self.req_states.req_id_to_index[req_id] - seq_len = (self.req_states.num_computed_tokens.np[req_idx] + - input_batch.num_scheduled_tokens[i]) - if seq_len < self.req_states.num_tokens.np[req_idx]: + discard_sampled_tokens_req_indices = [] + for i, req_id in enumerate(self.input_batch.req_ids): + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + if seq_len < req_state.num_tokens: # Ignore the sampled token for partial prefills. # Rewind the generator state as if the token was not sampled. 
# This relies on cuda-specific torch-internal impl details - generator = self.req_states.generators.get(req_idx) + generator = self.input_batch.generators.get(i) if generator is not None: generator.set_offset(generator.get_offset() - 4) # Record the index of the request that should not be sampled, @@ -1568,46 +1489,127 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, self.vocab_size) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids = [] + invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices_set = set(invalid_req_indices) + assert sampled_token_ids.shape[-1] == 1 + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. - for i, req_id in enumerate(input_batch.req_ids): - sampled_ids = valid_sampled_token_ids[i] + req_ids = self.input_batch.req_ids + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue - req_idx = self.req_states.req_id_to_index[req_id] - start_idx = self.req_states.num_tokens.np[req_idx] + start_idx = self.input_batch.num_tokens_no_spec[req_idx] end_idx = start_idx + len(sampled_ids) assert end_idx <= self.max_model_len, ( "Sampled token IDs exceed the max model length. 
" f"Total number of tokens: {end_idx} > max_model_len: " f"{self.max_model_len}") - self.req_states.token_ids.np[req_idx, - start_idx:end_idx] = sampled_ids - self.req_states.num_tokens.np[req_idx] = end_idx + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx - if self.speculative_config: - assert input_batch.spec_decode_common_attn_metadata is not None - self._draft_token_ids = self.propose_draft_token_ids( - input_batch, - valid_sampled_token_ids, - sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, + req_id = req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + return ( + num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + with record_function_or_nullcontext("Preprocess"): + self._update_states(scheduler_output) + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect logprobs for " + "prompt tokens, tokens, please disable it when the requests" + " need prompt logprobs") + + # Prepare the decoder inputs. + input_batch = self._prepare_inputs(scheduler_output) + + ( + num_scheduled_tokens, + num_input_tokens, + num_tokens_across_dp, + input_ids, + inputs_embeds, + positions, + intermediate_tensors, + model_kwargs, + ) = self._preprocess(scheduler_output, input_batch, intermediate_tensors) + + uniform_decode = (max_query_len + == self.uniform_decode_query_len) and ( + num_scheduled_tokens + == self.input_batch.num_reqs * max_query_len) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + cudagraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) + + # Run the model. + # Use persistent buffers for CUDA graphs. 
+ with (set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ), record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as + kv_connector_output): + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, ) - self._draft_req_ids = input_batch.req_ids with record_function_or_nullcontext("Postprocess"): if self.use_aux_hidden_state_outputs: @@ -1655,9 +1657,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.apply_grammar_bitmask(scheduler_output, logits) with record_function_or_nullcontext("Sample"): - sampler_output = self._sample(logits, spec_decode_metadata) + sampler_output = self._sample(logits, input_batch) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("Postprocess"): assert isinstance(hidden_states, torch.Tensor) ( num_nans_in_logits, @@ -1677,7 +1679,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, valid_sampled_token_ids, - self.input_batch.sampling_metadata, + input_batch.sampling_metadata, hidden_states, sample_hidden_states, aux_hidden_states, @@ -1712,31 +1714,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None + req_ids = self.input_batch.req_ids if isinstance(self._draft_token_ids, torch.Tensor): draft_token_ids = self._draft_token_ids.tolist() else: draft_token_ids = self._draft_token_ids self._draft_token_ids = None - - assert self._draft_req_ids - req_ids = self._draft_req_ids - self._draft_req_ids = None return DraftTokenIds(req_ids, draft_token_ids) def propose_draft_token_ids( self, - input_batch: InputBatch, + scheduler_output: "SchedulerOutput", sampled_token_ids: list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, aux_hidden_states: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + common_attn_metadata: CommonAttentionMetadata, ) -> Union[list[list[int]], torch.Tensor]: - num_scheduled_tokens = input_batch.total_num_tokens + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.propose_ngram_draft_token_ids( - input_batch, sampled_token_ids) + sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) if sample_hidden_states.shape[0] == len(sampled_token_ids): @@ -1746,7 +1747,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): indices = [] offset = 0 for num_draft, tokens in zip( - input_batch.spec_decode_metadata.num_draft_tokens, + spec_decode_metadata.num_draft_tokens, sampled_token_ids): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 @@ -1760,7 +1761,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. 
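A worked example of the index computation in the Medusa branch above: with speculative decoding, each request owns a contiguous run of num_draft + 1 hidden-state rows, and the proposer needs the row of the last accepted token.

num_draft_tokens = [2, 3]
sampled_token_ids = [[5, 6], [7]]        # 2 and 1 accepted tokens
indices, offset = [], 0
for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
    indices.append(offset + len(tokens) - 1)
    offset += num_draft + 1
assert indices == [1, 3]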
- req_ids = input_batch.req_ids + req_ids = self.input_batch.req_ids next_token_ids: list[int] = [] for i, token_ids in enumerate(sampled_token_ids): if token_ids: @@ -1770,16 +1771,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Partial prefill (rare case). # Get the next token id from the request state. req_id = req_ids[i] - req_state = self.req_states[req_id] + req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + - input_batch.num_scheduled_tokens[i]) + scheduler_output.num_scheduled_tokens[req_id]) next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) - if input_batch.spec_decode_metadata is None: + if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. @@ -1792,8 +1793,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): target_hidden_states = hidden_states[:num_scheduled_tokens] else: # TODO(woosuk): Refactor this. - num_draft_tokens = ( - input_batch.spec_decode_metadata.num_draft_tokens) + num_draft_tokens = spec_decode_metadata.num_draft_tokens num_rejected_tokens = [ n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) @@ -1802,8 +1802,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=torch.int32) common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( - input_batch.spec_decode_common_attn_metadata, - num_rejected_tokens_cpu) + common_attn_metadata, num_rejected_tokens_cpu) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. @@ -1815,7 +1814,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): target_hidden_states = hidden_states[token_indices] mm_embeds = None if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(input_batch, + mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) draft_token_ids = self.drafter.propose( @@ -1831,10 +1830,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def propose_ngram_draft_token_ids( self, - input_batch: InputBatch, sampled_token_ids: list[list[int]], ) -> list[list[int]]: # TODO(woosuk): Optimize. + req_ids = self.input_batch.req_ids draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) @@ -1843,21 +1842,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): draft_token_ids.append([]) continue - # # Skip requests that require sampling parameters that are not - # # supported with speculative decoding. - # req_id = input_batch.req_ids[i] - # if req_id in self.requests.spec_decode_unsupported_reqs: - # draft_token_ids.append([]) - # continue + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in self.input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue - num_tokens = self.req_states.num_tokens.np[i] + num_tokens = self.input_batch.num_tokens_no_spec[i] if num_tokens >= self.max_model_len: # Skip requests that have already reached the max model length. 
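A worked example of the rejected-token count used in the EAGLE branch above: a request with n draft tokens can yield up to n + 1 sampled tokens (accepted drafts plus the bonus token), so the number rejected is n + 1 minus the number actually sampled.

num_draft_tokens = [3, 0, 2]
sampled_token_ids = [[11, 12, 13], [21], [31]]   # lengths 3, 1, 1
num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
assert num_rejected_tokens == [1, 0, 2]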
draft_token_ids.append([]) continue drafter_output = self.drafter.propose( - self.req_states.token_ids.np[i, :num_tokens]) + self.input_batch.token_ids_cpu[i, :num_tokens]) if drafter_output is None or len(drafter_output) == 0: draft_token_ids.append([]) else: @@ -1995,11 +1994,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): hidden_states: torch.Tensor, num_scheduled_tokens: dict[str, int], ) -> dict[str, Optional[LogprobsTensors]]: - num_prompt_logprobs_dict = self.req_states.num_prompt_logprobs + num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} - in_progress_dict = self.req_states.in_progress_prompt_logprobs_cpu + in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} # Since prompt logprobs are a rare feature, prioritize simple, @@ -2009,7 +2008,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_tokens = num_scheduled_tokens[req_id] # Get metadata for this request. - request = self.req_states[req_id] + request = self.requests[req_id] num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( self.device, non_blocking=True) @@ -2086,19 +2085,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _get_nans_in_logits( self, - req_ids: list[str], logits: Optional[torch.Tensor], ) -> dict[str, int]: try: if logits is None: - return {req_id: 0 for req_id in req_ids} + return {req_id: 0 for req_id in self.input_batch.req_ids} num_nans_in_logits = {} num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() - for i, req_id in enumerate(req_ids): - num_nans_in_logits[req_id] = (int(num_nans_for_index[i]) - if num_nans_for_index is not None - and i < logits.shape[0] else 0) + for req_id in self.input_batch.req_ids: + req_index = self.input_batch.req_id_to_index[req_id] + num_nans_in_logits[req_id] = ( + int(num_nans_for_index[req_index]) + if num_nans_for_index is not None + and req_index < logits.shape[0] else 0) return num_nans_in_logits except IndexError: return {} @@ -2255,18 +2255,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): 1], seq_lens=self.seq_lens.gpu[:num_reqs], seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - num_computed_tokens_cpu=self.req_states.num_computed_tokens. - cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, - block_table_tensor=self.block_tables. - block_tables[kv_cache_group_id][:num_reqs], - slot_mapping=self.block_tables. - slot_mappings[kv_cache_group_id][:num_tokens], - causal=True, - ) + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch. 
+ block_table[kv_cache_group_id].slot_mapping[:num_tokens], + causal=True) for attn_group in self.attn_groups[kv_cache_group_id]: attn_metadata_i = attn_group.metadata_builder\ @@ -2380,15 +2379,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): generators={}, max_num_logprobs=None, no_penalties=True, + prompt_token_ids=None, frequency_penalties=dummy_tensors(0.1), presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, logitsprocs=LogitsProcessors(), - token_ids=None, - num_tokens=None, - num_prompt_tokens=None, ) try: sampler_output = self.sampler(logits=logits, @@ -2990,6 +2988,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_caches) return kv_caches + def init_block_tables(self, kv_cache_config: KVCacheConfig) -> None: + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + self.block_tables = BlockTables( + block_sizes=block_sizes, + max_num_reqs=self.max_num_reqs, + max_num_cached_reqs=self.max_num_cached_reqs, + max_num_batched_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + pin_memory=self.pin_memory, + ) + def maybe_add_kv_sharing_layers_to_kv_cache_groups( self, kv_cache_config: KVCacheConfig) -> None: """ @@ -3019,21 +3032,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: break - def init_block_tables(self, kv_cache_config: KVCacheConfig) -> None: - block_sizes = [ - kv_cache_group.kv_cache_spec.block_size - for kv_cache_group in kv_cache_config.kv_cache_groups - ] - self.block_tables = BlockTables( - block_sizes=block_sizes, - max_num_reqs=self.max_num_reqs, - max_num_cached_reqs=2 * self.max_num_reqs, - max_num_batched_tokens=self.max_num_tokens, - max_model_len=self.max_model_len, - device=self.device, - pin_memory=self.pin_memory, - ) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
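As a standalone sketch (toy layer names, not the vLLM API) of the fast-prefill eligibility rule in maybe_add_kv_sharing_layers_to_kv_cache_groups: walk the attention layers from the end and collect the trailing run of layers that reuse another layer's KV cache; the first layer that owns its own cache stops the scan.

def trailing_shared_kv_layers(layer_order: list[str],
                              shared_kv_cache_layers: dict[str, str]) -> set[str]:
    eligible: set[str] = set()
    for layer_name in reversed(layer_order):
        if layer_name in shared_kv_cache_layers:
            eligible.add(layer_name)
        else:
            break
    return eligible

# Layers l2 and l3 reuse l0's KV cache, so only they can be skipped at prefill.
assert trailing_shared_kv_layers(
    ["l0", "l1", "l2", "l3"], {"l2": "l0", "l3": "l0"}) == {"l2", "l3"}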
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py index 1612755f669e4..dac084ac307bd 100644 --- a/vllm/v1/worker/gpu_worker_states.py +++ b/vllm/v1/worker/gpu_worker_states.py @@ -47,78 +47,6 @@ class RequestData: ] -class SamplingStates: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_cached_reqs: int, - vocab_size: int, - device: torch.device, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_cached_reqs = max_num_cached_reqs - self.vocab_size = vocab_size - self.device = device - - self.temperature = self._make_param(torch.float32) - self.greedy_req_indices: set[int] = set() - self.top_p = self._make_param(torch.float32) - self.top_p_req_indices: set[int] = set() - self.top_k = self._make_param(torch.int32) - self.top_k_req_indices: set[int] = set() - - self.frequency_penalties = self._make_param(torch.float32) - self.presence_penalties = self._make_param(torch.float32) - self.repetition_penalties = self._make_param(torch.float32) - self.penalty_req_indices: set[int] = set() - - self.generators: dict[int, torch.Generator] = {} - - def _make_param(self, dtype: torch.dtype) -> torch.Tensor: - return torch.zeros(self.max_num_reqs, dtype=dtype, device=self.device) - - def add_requests( - self, - req_indices: list[int], - sampling_params: list[SamplingParams], - ) -> None: - num_reqs = len(req_indices) - for i in range(num_reqs): - req_idx = req_indices[i] - sampling_param = sampling_params[i] - - temp = sampling_param.temperature - if temp == 0.0: - self.greedy_req_indices.add(req_idx) - - top_p = sampling_param.top_p - if top_p < 1.0: - self.top_p_req_indices.add(req_idx) - top_k = sampling_param.top_k - if 0 < top_k < self.vocab_size: - self.top_k_req_indices.add(req_idx) - else: - top_k = self.vocab_size - - if sampling_param.frequency_penalty != 0.0 or sampling_param.presence_penalty != 0.0 or sampling_param.repetition_penalty != 1.0: - self.penalty_req_indices.add(req_idx) - - if sampling_param.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_param.seed) - self.generators[req_idx] = generator - - def remove_request(self, req_idx: int) -> None: - self.greedy_req_indices.discard(req_idx) - self.top_p_req_indices.discard(req_idx) - self.top_k_req_indices.discard(req_idx) - self.penalty_req_indices.discard(req_idx) - self.generators.pop(req_idx, None) - - class RequestState: def __init__( @@ -130,7 +58,6 @@ class RequestState: device: torch.device, pin_memory: bool, vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group logitsprocs: Optional[LogitsProcessors] = None, is_spec_decode: bool = False, is_pooling_model: bool = False, @@ -144,7 +71,6 @@ class RequestState: self.vocab_size = vocab_size self.is_spec_decode = is_spec_decode self.pooling_params = None - self.block_sizes = block_sizes self.num_prompt_logprobs: dict[int, int] = {} self.req_id_to_index: dict[str, int] = {} @@ -160,36 +86,23 @@ class RequestState: dtype=torch.int32, cpu_only=True, ) - self.num_prompt_tokens = self._make_param(torch.int32) - self.num_tokens = self._make_param(torch.int32) - self.num_computed_tokens = self._make_param(torch.int32) + self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32) + self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32) + self.num_computed_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32) - self.sampling_states = SamplingStates( 
- max_num_reqs=max_num_reqs, - max_model_len=max_model_len, - max_num_cached_reqs=max_num_cached_reqs, - device=device, - ) + self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32) + self.greedy_req_indices: set[int] = set() + self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32) + self.top_p_req_indices: set[int] = set() + self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32) + self.top_k_req_indices: set[int] = set() - def _make_param( - self, - dtype: torch.dtype, - num_cols: int = 1, - cpu_only: bool = False, - ) -> Param: - return Param( - self.max_num_cached_reqs, - num_cols, - self.max_num_reqs if not cpu_only else 0, - dtype, - self.device, - self.pin_memory, - is_scalar=num_cols == 1, - ) + self.frequency_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32) + self.presence_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32) + self.repetition_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32) + self.penalty_req_indices: set[int] = set() - @property - def num_cached_reqs(self) -> int: - return len(self.req_id_to_index) + self.generators: dict[int, torch.Generator] = {} def add_request( self, @@ -204,46 +117,31 @@ class RequestState: self.index_to_req_id[req_idx] = req_id prompt_len = len(prompt_token_ids) - self.num_prompt_tokens.np[req_idx] = prompt_len - self.num_tokens.np[req_idx] = prompt_len - self.token_ids.np[req_idx, :prompt_len] = prompt_token_ids - self.num_computed_tokens.np[req_idx] = num_computed_tokens + self.num_prompt_tokens[req_idx] = prompt_len + self.num_tokens[req_idx] = prompt_len + self.token_ids[req_idx, :prompt_len] = prompt_token_ids + self.num_computed_tokens[req_idx] = num_computed_tokens - self.temperature.np[req_idx] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - # NOTE: Be careful about division by zero. 
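A compact sketch tying together the storage layout above and the gather in make_sampling_metadata below: per-request sampling parameters live in dense numpy arrays indexed by cached-request slot, and a batch is assembled by fancy-indexing with idx_mapping and summarizing the result (names and values here are illustrative).

import numpy as np

vocab_size = 32_000
max_num_cached_reqs = 4
temperature = np.zeros(max_num_cached_reqs, dtype=np.float32)
top_p = np.ones(max_num_cached_reqs, dtype=np.float32)
top_k = np.full(max_num_cached_reqs, vocab_size, dtype=np.int32)

# add_request-style writes into slots 1 and 2.
temperature[1], top_p[1], top_k[1] = 0.7, 0.9, 40
temperature[2] = 0.0                      # greedy request

idx_mapping = np.array([2, 1])            # batch order -> cached slots
batch_temp = temperature[idx_mapping]
all_greedy = bool(np.all(batch_temp == 0.0))
all_random = bool(np.all(batch_temp != 0.0))
no_top_p = bool(np.all(top_p[idx_mapping] == 1.0))
no_top_k = bool(np.all(top_k[idx_mapping] == vocab_size))

assert (all_greedy, all_random, no_top_p, no_top_k) == (False, False, False, False)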
- self.greedy_reqs.add(req_id) - elif sampling_params.sampling_type == SamplingType.RANDOM: - self.random_reqs.add(req_id) - - self.top_p.np[req_idx] = sampling_params.top_p - if sampling_params.top_p < 1.0: - self.top_p_reqs.add(req_id) - - top_k = sampling_params.top_k - if 0 < top_k < self.vocab_size: - self.top_k_reqs.add(req_id) + self.temperature[req_idx] = sampling_params.temperature + self.top_p[req_idx] = sampling_params.top_p + if 0 < sampling_params.top_k < self.vocab_size: + top_k = sampling_params.top_k else: top_k = self.vocab_size - self.top_k.np[req_idx] = top_k - - self.frequency_penalties.np[ - req_idx] = sampling_params.frequency_penalty - if sampling_params.frequency_penalty != 0.0: - self.frequency_penalties_reqs.add(req_id) - self.presence_penalties.np[req_idx] = sampling_params.presence_penalty - if sampling_params.presence_penalty != 0.0: - self.presence_penalties_reqs.add(req_id) - self.repetition_penalties.np[ - req_idx] = sampling_params.repetition_penalty - if sampling_params.repetition_penalty != 1.0: - self.repetition_penalties_reqs.add(req_id) + self.top_k[req_idx] = top_k + self.frequency_penalties[req_idx] = sampling_params.frequency_penalty + self.presence_penalties[req_idx] = sampling_params.presence_penalty + self.repetition_penalties[req_idx] = sampling_params.repetition_penalty if sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) self.generators[req_idx] = generator + @property + def num_cached_reqs(self) -> int: + return len(self.req_id_to_index) + def append_token_ids( self, req_idx: int, @@ -262,65 +160,57 @@ class RequestState: self.index_to_req_id.pop(req_idx, None) self.free_indices.append(req_idx) - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.frequency_penalties_reqs.discard(req_id) - self.presence_penalties_reqs.discard(req_id) - self.repetition_penalties_reqs.discard(req_id) - self.generators.pop(req_idx, None) - def make_sampling_metadata( self, - batch_idx_to_req_idx: torch.Tensor, + idx_mapping: np.ndarray, ) -> SamplingMetadata: - batch_size = batch_idx_to_req_idx.shape[0] - if self.top_p_reqs: - top_p_buffer = self.top_p.mirror_to_gpu() - top_p = self.top_p.gpu + temperature = self.temperature[idx_mapping] + all_greedy = np.all(temperature == 0.0) + all_random = np.all(temperature != 0.0) + temperature = self._copy_np_to_gpu(temperature) + + top_p = self.top_p[idx_mapping] + no_top_p = np.all(top_p == 1.0) + top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None + top_k = self.top_k[idx_mapping] + no_top_k = np.all(top_k == self.vocab_size) + top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None + + frequency_penalties = self.frequency_penalties[idx_mapping] + presence_penalties = self.presence_penalties[idx_mapping] + repetition_penalties = self.repetition_penalties[idx_mapping] + no_penalties = (np.all(frequency_penalties == 0.0) and + np.all(presence_penalties == 0.0) and + np.all(repetition_penalties == 1.0)) + if no_penalties: + frequency_penalties = None + presence_penalties = None + repetition_penalties = None else: - top_p_buffer = self.top_p.gpu_buffer - top_p = None - if self.top_k_reqs: - top_k_buffer = self.top_k.mirror_to_gpu() - top_k = self.top_k.gpu + frequency_penalties = self._copy_np_to_gpu(frequency_penalties) + presence_penalties = self._copy_np_to_gpu(presence_penalties) + repetition_penalties = 
self._copy_np_to_gpu(repetition_penalties) + + if self.generators: + generators = { + req_idx: self.generators[req_idx] + for req_idx in idx_mapping + if req_idx in self.generators + } else: - top_k_buffer = self.top_k.gpu_buffer - top_k = None - # TODO(woosuk): Use UVA to optimize CPU -> GPU copy. - _make_sampling_metadata_kernel[(batch_size, )]( - batch_idx_to_req_idx, - self.temperature.mirror_to_gpu(), - self.temperature.gpu, - top_p_buffer, - self.top_p.gpu, - top_k_buffer, - self.top_k.gpu, - self.frequency_penalties.mirror_to_gpu(), - self.frequency_penalties.gpu, - self.presence_penalties.mirror_to_gpu(), - self.presence_penalties.gpu, - self.repetition_penalties.mirror_to_gpu(), - self.repetition_penalties.gpu, - num_warps=1, - num_stages=1, - ) - no_penalties = not (self.frequency_penalties_reqs - or self.presence_penalties_reqs - or self.repetition_penalties_reqs) + generators = {} + return SamplingMetadata( - temperature=self.temperature.gpu[:batch_size], - all_greedy=not self.random_reqs, - all_random=not self.greedy_reqs, + temperature=temperature, + all_greedy=all_greedy, + all_random=all_random, top_p=top_p, top_k=top_k, - frequency_penalties=self.frequency_penalties.gpu[:batch_size], - presence_penalties=self.presence_penalties.gpu[:batch_size], - repetition_penalties=self.repetition_penalties.gpu[:batch_size], + frequency_penalties=frequency_penalties, + presence_penalties=presence_penalties, + repetition_penalties=repetition_penalties, no_penalties=no_penalties, - # TODO - generators={}, + generators=generators, token_ids=None, num_tokens=None, num_prompt_tokens=None, @@ -330,6 +220,10 @@ class RequestState: logitsprocs=None, ) + def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor: + cpu_tensor = torch.from_numpy(src) + return cpu_tensor.to(device=self.device, non_blocking=True) + def make_spec_decode_metadata( self, query_start_loc: torch.Tensor, @@ -369,44 +263,6 @@ class RequestState: ) -@triton.jit -def _make_sampling_metadata_kernel( - batch_idx_to_req_idx, # [batch_size] - src_temperature, - dst_temperature, - src_top_p, - dst_top_p, - src_top_k, - dst_top_k, - src_frequency_penalties, - dst_frequency_penalties, - src_presence_penalties, - dst_presence_penalties, - src_repetition_penalties, - dst_repetition_penalties, -): - batch_idx = tl.program_id(0) - req_idx = tl.load(batch_idx_to_req_idx + batch_idx) - - temperature = tl.load(src_temperature + req_idx) - tl.store(dst_temperature + batch_idx, temperature) - - top_p = tl.load(src_top_p + req_idx) - tl.store(dst_top_p + batch_idx, top_p) - - top_k = tl.load(src_top_k + req_idx) - tl.store(dst_top_k + batch_idx, top_k) - - frequency_penalties = tl.load(src_frequency_penalties + req_idx) - tl.store(dst_frequency_penalties + batch_idx, frequency_penalties) - - presence_penalties = tl.load(src_presence_penalties + req_idx) - tl.store(dst_presence_penalties + batch_idx, presence_penalties) - - repetition_penalties = tl.load(src_repetition_penalties + req_idx) - tl.store(dst_repetition_penalties + batch_idx, repetition_penalties) - - @triton.jit def _prepare_spec_decode_kernel( query_start_loc, # [B + 1]