diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 4600e6315a369..a4834ef5d975d 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -307,6 +307,7 @@ class GPUModelRunner: # Slot mappings: [num_kv_cache_groups, num_tokens] slot_mappings = self.block_tables.compute_slot_mappings( query_start_loc_gpu, positions.gpu[:num_tokens]) + logits_indices = query_start_loc_gpu[1:] - 1 num_logits_indices = logits_indices.size(0) @@ -366,6 +367,7 @@ class GPUModelRunner: # TODO(woosuk): Support DP sampler + CUDA graphs. sample_hidden_states = hidden_states[input_batch.logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) + pos = input_batch.positions[input_batch.logits_indices] sampling_metadata = self.req_states.make_sampling_metadata( input_batch.idx_mapping_np, pos) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 50480bbd3a4aa..3ee2160a42ffe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,44 +8,79 @@ from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np import torch import torch.distributed +import torch.nn as nn from tqdm import tqdm from typing_extensions import TypeAlias import vllm.envs as envs +from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config +from vllm.config 
import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import has_kv_transfer_group -from vllm.distributed.parallel_state import (get_pp_group, get_tp_group, - graph_capture, - is_global_first_rank) -from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks +from vllm.distributed.parallel_state import ( + get_pp_group, get_tp_group, graph_capture, is_global_first_rank, + prepare_communication_buffer_for_model) +from vllm.forward_context import (BatchDescriptor, DPMetadata, + set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.models.interfaces import (is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_transcription) +from vllm.model_executor.models.interfaces_base import ( + VllmModelForPooling, is_pooling_model, is_text_generation_model) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, check_use_alibi, - is_pin_memory_available, round_up) +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, + PlaceholderRange) +from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingType +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.tasks import GenerationTask, PoolingTask, SupportedTask 
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, + GiB_bytes, check_use_alibi, get_dtype_size, + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, round_up, + supports_dynamo) +from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + create_fast_prefill_custom_backend, + reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # yapf conflicts with isort for this block # yapf: disable -from vllm.v1.kv_cache_interface import (EncoderOnlyAttentionSpec, - KVCacheConfig, KVCacheSpec) +from vllm.v1.kv_cache_interface import (AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, SlidingWindowSpec) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - LogprobsLists, LogprobsTensors, ModelRunnerOutput, - SamplerOutput) -from vllm.v1.sample.logits_processor import LogitsProcessors + DraftTokenIds, LogprobsLists, LogprobsTensors, + ModelRunnerOutput, SamplerOutput) +from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -53,20 +88,24 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from 
vllm.v1.utils import record_function_or_nullcontext -from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs -from vllm.v1.worker.gpu_worker_states import RequestState +from vllm.v1.structured_output.utils import apply_grammar_bitmask +from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch +from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from .utils import AttentionGroup, MultiModalBudget, bind_kv_cache +from .utils import (AttentionGroup, MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: - + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput logger = init_logger(__name__) @@ -160,6 +199,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self.is_pooling_model = (model_config.runner_type == 'pooling') + self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( model_config.is_multimodal_raw_input_only_model) @@ -244,17 +284,34 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.rejection_sampler = RejectionSampler() # Request states. 
- self.max_num_cached_reqs = 2 * self.max_num_reqs - self.req_states = RequestState( + self.requests: dict[str, CachedRequestState] = {} + self.comm_stream = torch.cuda.Stream() + + # Input Batch + # NOTE(Chen): Ideally, we should initialize the input batch inside + # `initialize_kv_cache` based on the kv cache config. However, as in + # https://github.com/vllm-project/vllm/pull/18298, due to some unknown + # reasons, we have to initialize the input batch before `load_model`, + # quantization + weight offloading will fail otherwise. As a temporary + # solution, we initialize the input batch here, and re-initialize it + # in `initialize_kv_cache` if the block_sizes here is different from + # the block_sizes in the kv cache config. + self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoer # because of KV cache for cross-attention. max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, - max_num_cached_reqs=self.max_num_cached_reqs, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), + block_sizes=[self.cache_config.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=build_logitsprocs( + self.vllm_config, self.device, self.pin_memory, + self.is_pooling_model, + self.vllm_config.model_config.logits_processors), + is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -288,6 +345,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.hidden_size, dtype=self.dtype, numpy=False) + self.is_token_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_discarded_requests = 0 @@ -323,6 +382,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # None in the first PP rank. 
The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None + # OPTIMIZATION: Cache the tensors rather than creating them every step. + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange(max(self.max_num_reqs + 1, + self.max_model_len, + self.max_num_tokens), + dtype=np.int64) + # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it # means this layer will perform attention using the keys and values @@ -347,6 +413,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mm_registry, ) if self.supports_mm_inputs else None + self.reorder_batch_threshold: Optional[int] = None + # Attention layers that are only in the KVCacheConfig of the runner # (e.g., KV sharing, encoder-only attention), but not in the # KVCacheConfig of the scheduler. @@ -362,74 +430,647 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): device="cpu", pin_memory=self.pin_memory) - def _prepare_inputs( + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy) + + def _init_model_kwargs(self, num_tokens: int): + model_kwargs = dict[str, Any]() + + if not self.is_pooling_model: + return model_kwargs + + num_reqs = self.input_batch.num_reqs + pooling_params = self.input_batch.get_pooling_params() + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens.gpu[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, 
seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + """ + Update the order of requests in the batch based on the attention + backend's needs. For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + """ + # Attention free models have zero kv_cache_goups, however models + # like Mamba are also attention free but use the kv_cache for + # keeping its internal state. This is why we check the number + # of kv_cache groups instead of solely checking + # for self.model_config.is_attention_free. + if len(self.kv_cache_config.kv_cache_groups) == 0: + return + + if self.reorder_batch_threshold is not None: + # NOTE(lucas): currently no backend supports the custom masking + # required for DCP with q_len > 1, so we assert here. Remove this + # assert once the custom mask is support is added to FA3. + if self.dcp_world_size > 1: + assert self.reorder_batch_threshold == 1, \ + "DCP not support reorder_batch_threshold > 1 now." + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + + # Note: used for model runner override. + def _init_device_properties(self) -> None: + """Initialize attributes from torch.cuda.get_device_properties + """ + self.device_properties = torch.cuda.get_device_properties(self.device) + self.num_sms = self.device_properties.multi_processor_count + + # Note: used for model runner override. 
+ def _sync_device(self) -> None: + torch.cuda.synchronize() + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + """Update the cached states and the persistent batch with the scheduler + output. + + The updated states are used by the `_prepare_inputs` function to create + the input GPU tensors for the model. + + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + # Remove the finished requests from the persistent batch. + # NOTE(woosuk): There could be an edge case where finished_req_ids and + # scheduled_req_ids overlap. This happens when a request is aborted and + # then resubmitted with the same ID. In this case, we treat them as two + # distinct requests - clearing the cached states for the first request + # and handling the second as a new request. + for req_id in scheduler_output.finished_req_ids: + self.input_batch.remove_request(req_id) + + # Free the cached encoder outputs. + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + # Remove the unscheduled requests from the persistent batch. + # NOTE(woosuk): The unscheduled requests are either preempted requests + # or running requests that are not scheduled in this step. We remove + # them from the persistent batch but keep their cached states since + # they will be scheduled again sometime in the future. + scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() + cached_req_ids = self.input_batch.req_id_to_index.keys() + unscheduled_req_ids = cached_req_ids - scheduled_req_ids + # NOTE(woosuk): The persistent batch optimization assumes that + # consecutive batches contain mostly the same requests. 
If batches + # have low request overlap (e.g., alternating between two distinct + # sets of requests), this optimization becomes very inefficient. + for req_id in unscheduled_req_ids: + self.input_batch.remove_request(req_id) + + reqs_to_add: list[CachedRequestState] = [] + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params + pooling_params = new_req_data.pooling_params + + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + if self.is_pooling_model: + assert pooling_params is not None + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" + + model = cast(VllmModelForPooling, self.get_model()) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + + req_state = CachedRequestState( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, + mm_features=new_req_data.mm_features, + sampling_params=sampling_params, + pooling_params=pooling_params, + generator=generator, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + output_token_ids=[], + lora_request=new_req_data.lora_request, + ) + self.requests[req_id] = req_state + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._init_mrope_positions(req_state) + + reqs_to_add.append(req_state) + + # Update the states of the running/resumed requests. 
+ is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] + + # Update the cached states. + req_state.num_computed_tokens = num_computed_tokens + + if not is_last_rank: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = (num_computed_tokens + len(new_token_ids) - + req_state.num_tokens) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) + + # Update the block IDs. + if not resumed_from_preemption: + if new_block_ids is not None: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) + else: + assert new_block_ids is not None + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: + # The request is not in the persistent batch. + # The request was either preempted and resumed later, or was not + # scheduled in the previous step and needs to be added again. + reqs_to_add.append(req_state) + continue + + # Update the persistent batch. 
+ self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) + if new_block_ids is not None: + self.input_batch.block_table.append_row( + new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu + # because the sampled tokens are already cached. + if not is_last_rank: + # Add new_token_ids to token_ids_cpu. + start_token_index = num_computed_tokens + end_token_index = num_computed_tokens + len(new_token_ids) + self.input_batch.token_ids_cpu[ + req_index, + start_token_index:end_token_index] = new_token_ids + self.input_batch.num_tokens_no_spec[ + req_index] = end_token_index + self.input_batch.num_tokens[req_index] = end_token_index + + # Add spec_token_ids to token_ids_cpu. + spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + if spec_token_ids: + num_spec_tokens = len(spec_token_ids) + start_index = self.input_batch.num_tokens_no_spec[req_index] + end_token_index = start_index + num_spec_tokens + self.input_batch.token_ids_cpu[ + req_index, start_index:end_token_index] = spec_token_ids + # NOTE(woosuk): `num_tokens` here may include spec tokens. + self.input_batch.num_tokens[req_index] += num_spec_tokens + + # Add the new or resumed requests to the persistent batch. + # The smaller empty indices are filled first. + for request in reqs_to_add: + self.input_batch.add_request(request) + + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. + self.input_batch.refresh_metadata() + + def _update_states_after_model_execute( + self, output_token_ids: torch.Tensor) -> None: + """Update the cached states after model execution. + + This is used for MTP/EAGLE for hybrid models, as in linear attention, + only the last token's state is kept. 
In MTP/EAGLE, for draft tokens + the state are kept util we decide how many tokens are accepted for + each sequence, and a shifting is done during the next iteration + based on the number of accepted tokens. + """ + if not self.model_config.is_hybrid or not self.speculative_config: + return + + # Find the number of accepted tokens for each sequence. + num_accepted_tokens = (torch.cat( + [ + output_token_ids, + torch.full((output_token_ids.size(0), 1), + -1, + device=output_token_ids.device), + ], + dim=1) == -1).int().argmax(-1).cpu().numpy() + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + + def _init_mrope_positions(self, req_state: CachedRequestState): + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_feature in req_state.mm_features: + mm_item = mm_feature.data + if mm_item is None: + continue + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + if supports_mrope(self.model): + req_state.mrope_positions, req_state.mrope_position_delta = \ + self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + else: + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + 
req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", - ) -> InputBatch: + ) -> BatchedTensorInputs: + if not scheduler_output or not self.is_multimodal_raw_input_only_model: + return {} + + mm_kwargs = list[MultiModalKwargsItem]() + for req in scheduler_output.scheduled_new_reqs: + for feature in req.mm_features: + if feature.data is not None: + mm_kwargs.append(feature.data) + + # Input all modalities at once + mm_kwargs_combined: BatchedTensorInputs = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + mm_kwargs_combined.update(mm_kwargs_group) + + return mm_kwargs_combined + + def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: + if not self.is_multimodal_raw_input_only_model: + return {} + + mm_budget = self.mm_budget + assert mm_budget is not None + + dummy_modality = mm_budget.get_modality_with_max_tokens() + return self._get_mm_dummy_batch(dummy_modality, num_seqs) + + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. 
[0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def _prepare_input_ids(self, total_num_scheduled_tokens: int, + cu_num_tokens: np.ndarray) -> None: + """Prepare the input IDs for the current batch. + + Carefully handles the `prev_sampled_token_ids` which can be cached + from the previous engine iteration, in which case those tokens on the + GPU need to be copied into the corresponding slots into input_ids.""" + + if self.input_batch.prev_sampled_token_ids is None: + # Normal scheduling case + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) + return + + # Async scheduling case, where some decode requests from the previous + # iteration won't have entries in input_ids_cpu and need to be copied + # on the GPU from prev_sampled_token_ids. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None + flattened_indices = [] + prev_common_req_indices = [] + indices_match = True + max_flattened_index = -1 + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + flattened_index = cu_num_tokens[cur_index].item() - 1 + flattened_indices.append(flattened_index) + indices_match &= (prev_index == flattened_index) + max_flattened_index = max(max_flattened_index, flattened_index) + num_commmon_tokens = len(flattened_indices) + if num_commmon_tokens < total_num_scheduled_tokens: + # If not all requests are decodes from the last iteration, + # We need to copy the input_ids_cpu to the GPU first. 
+ self.input_ids.copy_to_gpu(total_num_scheduled_tokens) + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) + if num_commmon_tokens == 0: + # No requests in common with the previous iteration + # So input_ids_cpu will have all the input ids. + return + if indices_match and max_flattened_index == (num_commmon_tokens - 1): + # Common-case optimization: the batch is unchanged + # and no reordering happened. + # The indices are both the same permutation of 0..N-1 so + # we can copy directly using a single slice. + self.input_ids.gpu[:num_commmon_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, + 0], + non_blocking=True) + self.is_token_ids.gpu[:num_commmon_tokens] = True + return + # Upload the index tensors asynchronously + # so the scatter can be non-blocking. + input_ids_index_tensor = torch.tensor(flattened_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to( + self.device, + non_blocking=True) + prev_common_req_indices_tensor = torch.tensor( + prev_common_req_indices, + dtype=torch.int64, + pin_memory=self.pin_memory).to(self.device, non_blocking=True) + self.input_ids.gpu.scatter_( + dim=0, + index=input_ids_index_tensor, + src=self.input_batch.prev_sampled_token_ids[ + prev_common_req_indices_tensor, 0]) + + def _get_encoder_seq_lens( + self, + scheduler_output: "SchedulerOutput", + kv_cache_spec: KVCacheSpec, + num_reqs: int, + ) -> Optional[np.ndarray]: + if not isinstance(kv_cache_spec, CrossAttentionSpec): + return None + + # Build encoder_seq_lens array mapping request indices to + # encoder lengths for inputs scheduled in this batch + encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32) + for req_id in scheduler_output.scheduled_encoder_inputs: + req_index = self.input_batch.req_id_to_index[req_id] + encoder_seq_lens[req_index] = self.max_encoder_len + + return encoder_seq_lens + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + 
) -> tuple[PerLayerAttnMetadata, torch.Tensor, + Optional[SpecDecodeMetadata], np.ndarray, + Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], + Optional[torch.Tensor]]: + """ + :return: tuple[ + attn_metadata: layer-to-attention_metadata mapping, + logits_indices, spec_decode_metadata + ] + """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 - num_reqs = len(scheduler_output.num_scheduled_tokens) - - # batch_idx -> req_id - req_ids = sorted(scheduler_output.num_scheduled_tokens, - key=scheduler_output.num_scheduled_tokens.get) - num_scheduled_tokens = np.array( - [scheduler_output.num_scheduled_tokens[i] for i in req_ids], - dtype=np.int32) - - # batch_idx -> req_idx - idx_mapping_list = [ - self.req_states.req_id_to_index[req_id] for req_id in req_ids - ] - self.idx_mapping.np[:num_reqs] = idx_mapping_list - idx_mapping_np = self.idx_mapping.np[:num_reqs] - idx_mapping = self.idx_mapping.copy_to_gpu(num_reqs) - # req_id -> batch_idx - req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)} + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. - block_tables = self.block_tables.compute_block_tables(idx_mapping) + self.input_batch.block_table.commit_block_table(num_reqs) # Get the number of scheduled tokens for each request. 
- num_scheduled_tokens = np.array( - [scheduler_output.num_scheduled_tokens[i] for i in req_ids], - dtype=np.int32) + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = max(tokens) - prepare_inputs( - idx_mapping_np, - self.req_states.token_ids.np, - self.req_states.num_computed_tokens.np, - num_scheduled_tokens, - self.input_ids.np, - self.query_start_loc.np, - self.seq_lens.np, - self.positions.np, - ) - self.input_ids.copy_to_gpu(total_num_scheduled_tokens) - self.positions.copy_to_gpu(total_num_scheduled_tokens) + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) - # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens - # tensors from CPU to GPU, because they may include paddings needed - # for full CUDA graph mode. - self.query_start_loc.copy_to_gpu() - self.seq_lens.copy_to_gpu() - query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] - max_query_len = int(num_scheduled_tokens.max()) - seq_lens = self.seq_lens.gpu[:num_reqs] - max_seq_len = int(self.seq_lens.np[:num_reqs].max()) + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) - # Compute the slot mappings on GPUs. - slot_mappings = self.block_tables.compute_slot_mappings( - query_start_loc, self.positions.gpu[:total_num_scheduled_tokens]) + # Get positions. + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) # Calculate M-RoPE positions. 
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - self._calc_mrope_positions(req_ids, num_scheduled_tokens) + self._calc_mrope_positions(scheduler_output) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices_tensor = torch.from_numpy(token_indices) + + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens]) + is_token_ids = self.input_batch.is_token_ids.flatten() + torch.index_select( + is_token_ids, + 0, + token_indices_tensor, + out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + + # Because we did not pre-allocate a massive prompt_embeds CPU tensor on + # the InputBatch, we need to fill in the prompt embeds into the expected + # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. 
+ if self.input_batch.req_prompt_embeds: + output_idx = 0 + for req_idx in range(num_reqs): + num_sched = num_scheduled_tokens[req_idx] + + # Skip if this request doesn't have embeddings + if req_idx not in self.input_batch.req_prompt_embeds: + output_idx += num_sched + continue + + # Skip if no tokens scheduled + if num_sched <= 0: + output_idx += num_sched + continue + + req_embeds = self.input_batch.req_prompt_embeds[req_idx] + start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] + + # Skip if trying to read beyond available embeddings + if start_pos >= req_embeds.shape[0]: + output_idx += num_sched + continue + + # Copy available embeddings + end_pos = start_pos + num_sched + actual_end = min(end_pos, req_embeds.shape[0]) + actual_num_sched = actual_end - start_pos + + if actual_num_sched > 0: + self.inputs_embeds.cpu[output_idx:output_idx + + actual_num_sched].copy_( + req_embeds[start_pos:actual_end] + ) + + output_idx += num_sched + + self.input_batch.block_table.compute_slot_mapping( + req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) + + # Prepare the attention metadata. + self.query_start_loc.np[0] = 0 + self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.copy_to_gpu() + query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_tokens_padded = num_tokens_unpadded + self.get_local_padding( + num_tokens_unpadded) + ubatch_slices, num_tokens_after_padding = \ + ubatch_split(max_num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + self.vllm_config) + + self.seq_lens.np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + # Fill unused with 0 for full cuda graph mode. 
+ self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() + seq_lens = self.seq_lens.gpu[:num_reqs] + max_seq_len = self.seq_lens.np[:num_reqs].max().item() + + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[:self.num_discarded_requests] = ( + discard_request_indices) + + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + + # Copy the tensors to the GPU. + self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) + + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True) + else: + # Common case (1D positions) + self.positions.copy_to_gpu(total_num_scheduled_tokens) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -438,16 +1079,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # partial requests. While we should not sample any token # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 num_draft_tokens = None spec_decode_metadata = None else: # Get the number of draft tokens for each request. 
- spec_decode_metadata = self._prepare_spec_decode_metadata( - req_ids, - scheduler_output.scheduled_spec_decode_tokens, - query_start_loc, - ) + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices self.num_draft_tokens.np[:num_reqs] = num_draft_tokens self.num_draft_tokens.np[num_reqs:].fill(0) @@ -458,12 +1105,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits_indices_padded = self._prepare_kv_sharing_fast_prefill( logits_indices) + attn_metadata: PerLayerAttnMetadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] + # Used in the below loop. query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_np = self.req_states.num_computed_tokens.np[ - idx_mapping_np] - num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np) + num_computed_tokens_cpu = ( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) spec_decode_common_attn_metadata = None if use_spec_decode: self.num_accepted_tokens.np[:num_reqs] = ( @@ -471,7 +1121,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -495,8 +1144,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) num_common_prefix_blocks = 0 else: - blk_table_tensor = block_tables[kv_cache_group_id] - slot_mapping = slot_mappings[kv_cache_group_id] + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor(num_reqs) + slot_mapping = blk_table.slot_mapping.gpu[: + total_num_scheduled_tokens] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( + -1) num_common_prefix_blocks = ( scheduler_output. num_common_prefix_blocks[kv_cache_group_id]) @@ -509,7 +1165,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_query_len, + max_query_len=max_num_scheduled_tokens, max_seq_len=max_seq_len, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, @@ -570,20 +1226,616 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return InputBatch( - req_ids=req_ids, - num_scheduled_tokens=num_scheduled_tokens, - req_id_to_batch_idx=req_id_to_batch_idx, - idx_mapping=idx_mapping, - idx_mapping_np=idx_mapping_np, - num_reqs=num_reqs, - total_num_tokens=total_num_scheduled_tokens, - max_query_len=max_query_len, - attn_metadata=attn_metadata, - spec_decode_metadata=spec_decode_metadata, - spec_decode_common_attn_metadata=spec_decode_common_attn_metadata, + return (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens, spec_decode_common_attn_metadata, + max_num_scheduled_tokens, ubatch_slices, + num_tokens_after_padding) + + def _compute_cascade_attn_prefix_len( + self, + num_scheduled_tokens: np.ndarray, + 
num_common_prefix_blocks: int, + kv_cache_spec: KVCacheSpec, + attn_metadata_builder: AttentionMetadataBuilder, + ) -> int: + """Compute the length of the common prefix for cascade attention. + + NOTE(woosuk): The common prefix length returned by this function + represents the length used specifically for cascade attention, not the + actual number of tokens shared between requests. When cascade attention + is disabled (use_cascade=False), this function returns 0 even if + requests share common tokens. Additionally, the common prefix length is + truncated to a multiple of the block size and may be further truncated + due to implementation details explained below. + + Args: + num_scheduled_tokens: Number of tokens scheduled per request. + num_common_prefix_blocks: Number of shared KV cache blocks. + + Returns: + int: Length of common prefix in tokens. + """ + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size + if common_prefix_len == 0: + # Common case. + return 0 + + # NOTE(woosuk): Cascade attention uses two attention kernels: one + # for the common prefix and the other for the rest. For the first + # kernel, we concatenate all the query tokens (possibly from + # different requests) and treat them as if they are from the same + # request. Then, we use bi-directional attention to process the + # common prefix in the KV cache. Importantly, this means that the + # first kernel does not do any masking. + + # Consider the following example: + # Request 1's input query: [D, E, X] + # Request 1's kv cache: [A, B, C, D, E, X] + # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) + # Request 2's input query: [E, Y] + # Request 2's kv cache: [A, B, C, D, E, Y] + # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) + + # If we use [A, B, C, D, E] as the common prefix, then the + # first kernel will compute the bi-directional attention between + # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. 
+ # However, this is wrong because D in Request 1 should not attend to + # E in the common prefix (i.e., we need masking). + # To avoid this, [A, B, C, D] should be the common prefix. + # That is, the common prefix should be capped by the minimum + # num_computed_tokens among the requests, and plus one to include + # the first token of the query. + + # In practice, we use [A, B, C] as the common prefix, instead of + # [A, B, C, D] (i.e., the common prefix is capped by the minimum + # num_computed_tokens, without plus one). + # This is because of an implementation detail: We want to always + # use two kernels for cascade attention. Let's imagine: + # Request 3's input query: [D] + # Request 3's kv cache: [A, B, C, D] + # Request 3's num_computed_tokens: 3 (i.e., [A, B, C]) + # If we use [A, B, C, D] as the common prefix for Request 1-3, + # then Request 3 will be processed only by the first kernel, + # and the second kernel will get an empty input. While this is not + # a fundamental problem, our current implementation does not support + # this case. + num_reqs = len(num_scheduled_tokens) + common_prefix_len = min( + common_prefix_len, + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + # common_prefix_len should be a multiple of the block size. 
+ common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * + kv_cache_spec.block_size) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + use_local_attention = ( + isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) + or (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None)) + assert isinstance(kv_cache_spec, AttentionSpec) + use_cascade = attn_metadata_builder.use_cascade_attention( + common_prefix_len=common_prefix_len, + query_lens=num_scheduled_tokens, + num_query_heads=self.num_query_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, + use_alibi=self.use_alibi, + use_sliding_window=use_sliding_window, + use_local_attention=use_local_attention, + num_sms=self.num_sms, + ) + return common_prefix_len if use_cascade else 0 + + def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): + mrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, + num_prompt_tokens - num_computed_tokens) + completion_part_len = max( + 0, num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + prompt_part_len + src_start = num_computed_tokens + src_end = num_computed_tokens + prompt_part_len + + 
self.mrope_positions.cpu[:, dst_start:dst_end] = ( + req.mrope_positions[:, src_start:src_end]) + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + MRotaryEmbedding.get_next_input_positions_tensor( + out=self.mrope_positions.np, + out_offset=dst_start, + mrope_position_delta=req.mrope_position_delta, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) + + mrope_pos_ptr += completion_part_len + + def _calc_spec_decode_metadata( + self, + num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + # bonus_logits_indices: [ 3, 4, 7, 8, 10] + + # Compute the logits indices. + # [4, 1, 3, 1, 2] + num_sampled_tokens = num_draft_tokens + 1 + + # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] + # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( + num_sampled_tokens, cumsum_dtype=np.int32) + # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_indices = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += arange + + # Compute the bonus logits indices. + bonus_logits_indices = cu_num_sampled_tokens - 1 + + # Compute the draft logits indices. 
+ # cu_num_draft_tokens: [3, 3, 5, 5, 6] + # arange: [0, 1, 2, 0, 1, 0] + cu_num_draft_tokens, arange = self._get_cumsum_and_arange( + num_draft_tokens, cumsum_dtype=np.int32) + # [0, 0, 0, 5, 5, 9] + target_logits_indices = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + # [0, 1, 2, 5, 6, 9] + target_logits_indices += arange + + # TODO: Optimize the CPU -> GPU copy. + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( + self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).to(self.device, + non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).to( + self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( + self.device, non_blocking=True) + + # Compute the draft token ids. + # draft_token_indices: [ 1, 2, 3, 105, 106, 208] + draft_token_ids = self.input_ids.gpu[logits_indices] + draft_token_ids = draft_token_ids[target_logits_indices + 1] + + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) + return metadata + + def _prepare_kv_sharing_fast_prefill( + self, + logits_indices: torch.Tensor, + ) -> torch.Tensor: + assert self.kv_sharing_fast_prefill_logits_indices is not None + num_logits = logits_indices.shape[0] + assert num_logits > 0 + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( + logits_indices) + # There might have leftover indices in logits_indices[num_logits:] + # from previous iterations, whose values may be greater than the + # batch size in the current iteration. To ensure indices are always + # valid, we fill the padded indices with the last index. 
+ self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( + logits_indices[-1].item()) + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) + else: + num_logits_padded = num_logits + logits_indices_padded = ( + self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + return logits_indices_padded + + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. + + Returns: + A tuple of (mm_kwargs, req_ids_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return [], [] + # Batch the multi-modal inputs. + mm_kwargs = list[MultiModalKwargsItem]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + + for mm_input_id in encoder_input_ids: + mm_feature = req_state.mm_features[mm_input_id] + mm_hash = mm_feature.identifier + mm_kwargs.append(mm_feature.data) + mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) + + return mm_kwargs, mm_hashes_pos + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + # Batch the multi-modal inputs using the helper method. 
+ mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output) + + if not mm_kwargs: + return + + # Batch mm inputs as much as we can: if a request in the batch has + # multiple modalities or a different modality than the previous one, + # we process it separately to preserve item order. + # FIXME(ywang96): This is a hacky way to deal with multiple modalities + # in the same batch while still being able to benefit from batching + # multimodal inputs. The proper solution should be reordering the + # encoder outputs. + encoder_outputs = [] + for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + # Run the encoder. + # `curr_group_outputs` is either of the following: + # 1. A tensor of shape (num_items, feature_size, hidden_size) + # in case feature_size is fixed across all multimodal items. + # 2. A list or tuple (length: num_items) of tensors, each of shape + # (feature_size, hidden_size) in case the feature size is dynamic + # depending on the input multimodal items. 
+ curr_group_outputs = self.model.get_multimodal_embeddings( + **mm_kwargs_group) + + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=num_items, + ) + + for output in curr_group_outputs: + encoder_outputs.append(output) + + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( + output, + is_embed=pos_info.is_embed, + ) + + def _gather_mm_embeddings( + self, + scheduler_output: "SchedulerOutput", + shift_computed_tokens: int = 0, + ) -> list[torch.Tensor]: + mm_embeds: list[torch.Tensor] = [] + for req_id in self.input_batch.req_ids: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = \ + req_state.num_computed_tokens + shift_computed_tokens + for mm_feature in req_state.mm_features: + pos_info = mm_feature.mm_position + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens, + ) + assert start_idx < end_idx + + mm_hash = mm_feature.identifier + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." 
+ + if (is_embed := pos_info.is_embed) is not None: + is_embed = is_embed[start_idx:end_idx] + + mm_embeds_item = gather_mm_placeholders( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + mm_embeds.append(mm_embeds_item) + return mm_embeds + + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models. + + This method extracts multimodal input features from scheduled encoder + inputs and formats them for the encoder-decoder model forward pass. + """ + # Batch the multi-modal inputs using the helper method. + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + + if not mm_kwargs: + return {} + + # Group MM kwargs by modality and extract features + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + # Add the grouped features to encoder_features dict + # This allows the model to receive them as kwargs (e.g., + # input_features=...) + encoder_features.update(mm_kwargs_group) + + return encoder_features + + def get_model(self) -> nn.Module: + # get raw model out of the cudagraph wrapper. 
+ if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): + return self.model.unwrap() + return self.model + + def get_supported_generation_tasks(self) -> list[GenerationTask]: + model = self.get_model() + supported_tasks = list[GenerationTask]() + + if is_text_generation_model(model): + supported_tasks.append("generate") + + if supports_transcription(model): + if model.supports_transcription_only: + return ["transcription"] + + supported_tasks.append("transcription") + + return supported_tasks + + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + supported_tasks = list(model.pooler.get_supported_tasks()) + + if (self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks): + supported_tasks.remove("encode") + + logger.debug_once("Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it.") + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once( + "Score API is only enabled for num_labels == 1.") + + return supported_tasks + + def get_supported_tasks(self) -> tuple[SupportedTask, ...]: + tasks = list[SupportedTask]() + + if self.model_config.runner_type == "generate": + tasks.extend(self.get_supported_generation_tasks()) + if self.model_config.runner_type == "pooling": + tasks.extend(self.get_supported_pooling_tasks()) + + return tuple(tasks) + + def sync_and_slice_intermediate_tensors( + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool) -> IntermediateTensors: + + assert self.intermediate_tensors is not None + + tp = self.vllm_config.parallel_config.tensor_parallel_size + is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens) + + # When sequence parallelism is enabled, the 
"residual" tensor is sharded + # across tensor parallel ranks, so each rank only needs its own slice. + if sync_self: + assert intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + is_scattered = k == "residual" and is_rs + copy_len = num_tokens // tp if is_scattered else \ + num_tokens + self.intermediate_tensors[k][:copy_len].copy_( + v[:copy_len], non_blocking=True) + + return IntermediateTensors({ + k: + v[:num_tokens // + tp] if k == "residual" and is_rs else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + def eplb_step(self, + is_dummy: bool = False, + is_profile: bool = False) -> None: + """ + Step for the EPLB (Expert Parallelism Load Balancing) state. + """ + if not self.parallel_config.enable_eplb: + return + + assert self.eplb_state is not None + model = self.get_model() + assert is_mixture_of_experts(model) + self.eplb_state.step( + model, + is_dummy, + is_profile, + log_stats=self.parallel_config.eplb_config.log_balancedness, + ) + + def get_dp_padding(self, + num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + """ + Determines the total number of tokens that each rank will run. + All ranks will be padded out so that they run with the same number + of tokens + + Returns: tuple[ + num_pad_tokens: The number of tokens that will be added to the batch + num_tokens_after_padding: A tensor containing the total number of + tokens for each DP rank including padding. + ] + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + + # For DP: Don't pad when setting enforce_eager. + # This lets us set enforce_eager on the prefiller in a P/D setup and + # still use CUDA graphs (enabled by this padding) on the decoder. + # + # TODO(tms) : There are many cases where padding is enabled for + # prefills, causing unnecessary and excessive padding of activations. 
+ + if dp_size == 1 or self.vllm_config.model_config.enforce_eager: + # Early exit. + return 0, None + + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + num_tokens, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) + return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + + def get_local_padding(self, num_tokens_unpadded: int) -> int: + + num_tokens_padded = num_tokens_unpadded + + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_tokens_padded = self.vllm_config.pad_for_cudagraph( + num_tokens_unpadded) + else: + # Eager mode. + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.vllm_config.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: + num_tokens_padded = round_up(num_tokens_unpadded, tp_size) + + num_pad_tokens = num_tokens_padded - num_tokens_unpadded + return num_pad_tokens + + # This is where the second ubatch is adjusted to account for the padding. + # Should be called after attention metadata creation. 
This just pads + # the second ubatch slice out to the total number of tokens + # (num_tokens + padding) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, + num_total_tokens: int): + padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, + num_total_tokens) + ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, + padded_second_ubatch_slice) + + def _pool( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + num_scheduled_tokens_np: np.ndarray, + ) -> ModelRunnerOutput: + assert self.input_batch.num_reqs ==\ + len(self.input_batch.pooling_params), \ + "Either all or none of the requests in" \ + " a batch must be pooling request" + + hidden_states = hidden_states[:num_scheduled_tokens] + pooling_metadata = self.input_batch.get_pooling_metadata() + pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), + device=hidden_states.device) + seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + + # Pooling models D2H & synchronize occurs in pooler.py:build_output + raw_pooler_output = self.model.pooler( + hidden_states=hidden_states, pooling_metadata=pooling_metadata) + + pooler_output: list[Optional[torch.Tensor]] = [] + for raw_output, seq_len, prompt_len in zip( + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): + + output = raw_output.data if seq_len == prompt_len else None + pooler_output.append(output) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -607,7 +1859,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _preprocess( self, scheduler_output: "SchedulerOutput", - input_batch: InputBatch, intermediate_tensors: Optional[IntermediateTensors] = 
None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, @@ -652,6 +1903,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } + elif (self.enable_prompt_embeds and get_pp_group().is_first_rank): + # Get the input embeddings for the tokens that are not input embeds, + # then put them into the appropriate positions. + # TODO(qthequartermasterman): Since even when prompt embeds are + # enabled, (a) not all requests will use prompt embeds, and (b) + # after the initial prompt is processed, the rest of the generated + # tokens will be token ids, it is not desirable to have the + # embedding layer outside of the CUDA graph all the time. The v0 + # engine avoids this by "double compiling" the CUDA graph, once + # with input_ids and again with inputs_embeds, for all num_tokens. + # If a batch only has token ids, then including the embedding layer + # in the CUDA graph will be more performant (like in the else case + # below). + token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ + .nonzero(as_tuple=False) \ + .squeeze(1) + # Some tokens ids may need to become embeds + if token_ids_idx.numel() > 0: + token_ids = self.input_ids.gpu[token_ids_idx] + tokens_to_embeds = self.model.get_input_embeddings( + input_ids=token_ids) + self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + model_kwargs = self._init_model_kwargs(num_input_tokens) + input_ids = None else: # For text-only models, we use token ids as input. 
# While it is possible to use embeddings as input just like the @@ -688,13 +1965,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) def _sample( - self, - logits: Optional[torch.Tensor], - input_batch: InputBatch, + self, logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata] ) -> SamplerOutput: # Sample the next token and get logprobs if needed. - sampling_metadata = input_batch.sampling_metadata - spec_decode_metadata = input_batch.spec_decode_metadata + sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: sampler_output = self.sampler( logits=logits, @@ -730,12 +2005,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return sampler_output def _bookkeeping_sync( - self, - scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, - logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, - num_scheduled_tokens: int, + self, scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, num_scheduled_tokens: int ) -> tuple[ dict[str, int], Optional[LogprobsLists], @@ -835,6 +2107,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -895,8 +2168,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, input_batch, - intermediate_tensors) + ) = self._preprocess(scheduler_output, intermediate_tensors, + ubatch_slices, num_tokens_after_padding) + + if ubatch_slices is not None: + num_input_tokens = num_input_tokens // 2 uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( 
@@ -990,7 +2266,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits, self.device) with record_function_or_nullcontext("Sample"): - sampler_output = self._sample(logits, input_batch) + sampler_output = self._sample(logits, spec_decode_metadata) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None @@ -1027,19 +2303,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits, hidden_states, num_scheduled_tokens) - if self.speculative_config: - assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - input_batch.sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, - ) + if self.speculative_config and not use_padded_batch_for_eagle: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
+ propose_draft_token_ids(valid_sampled_token_ids) with record_function_or_nullcontext("EPLB"): self.eplb_step() @@ -1065,6 +2332,309 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): async_output_copy_stream=self.async_output_copy_stream, ) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + req_ids = self.input_batch.req_ids + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = self._draft_token_ids + self._draft_token_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + + def propose_draft_token_ids( + self, + scheduler_output: "SchedulerOutput", + sampled_token_ids: Union[torch.Tensor, list[list[int]]], + sampling_metadata: SamplingMetadata, + hidden_states: torch.Tensor, + sample_hidden_states: torch.Tensor, + aux_hidden_states: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], + common_attn_metadata: CommonAttentionMetadata, + ) -> Union[list[list[int]], torch.Tensor]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, NgramProposer) + draft_token_ids = self.propose_ngram_draft_token_ids( + sampled_token_ids) + elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, MedusaProposer) + + if sample_hidden_states.shape[0] == len(sampled_token_ids): + # The input to the target model does not include draft tokens. 
+ hidden_states = sample_hidden_states + else: + indices = [] + offset = 0 + for num_draft, tokens in zip( + spec_decode_metadata.num_draft_tokens, + sampled_token_ids): + indices.append(offset + len(tokens) - 1) + offset += num_draft + 1 + indices = torch.tensor(indices, device=self.device) + hidden_states = sample_hidden_states[indices] + + draft_token_ids = self.drafter.propose( + target_hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + elif self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids, self.requests, self.input_batch, + scheduler_output.num_scheduled_tokens) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." + next_token_ids, valid_sampled_tokens_count = \ + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests + ) + + if spec_decode_metadata is None: + token_indices_to_sample = None + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] + # TODO(woosuk): Support M-RoPE. 
+ target_positions = self.positions.gpu[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + else: + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) + else: + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + + target_token_ids = self.input_ids.gpu[token_indices] + # TODO(woosuk): Support M-RoPE. + target_positions = self.positions.gpu[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + mm_embeds = None + if self.supports_mm_inputs: + mm_embeds = self._gather_mm_embeddings(scheduler_output, + shift_computed_tokens=1) + + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + mm_embeds=mm_embeds, + ) + return draft_token_ids + + def propose_ngram_draft_token_ids( + self, + sampled_token_ids: list[list[int]], + ) -> list[list[int]]: + # TODO(woosuk): Optimize. + req_ids = self.input_batch.req_ids + draft_token_ids: list[list[int]] = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. 
+ draft_token_ids.append([]) + continue + + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in self.input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue + + num_tokens = self.input_batch.num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + + drafter_output = self.drafter.propose( + self.input_batch.token_ids_cpu[i, :num_tokens]) + if drafter_output is None or len(drafter_output) == 0: + draft_token_ids.append([]) + else: + draft_token_ids.append(drafter_output.tolist()) + return draft_token_ids + + def update_config(self, overrides: dict[str, Any]) -> None: + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + + def load_model(self, eep_scale_up: bool = False) -> None: + """ + Args: + eep_scale_up: the model loading is for elastic EP scale up. 
+ """ + logger.info("Starting to load model %s...", self.model_config.model) + if eep_scale_up: + from vllm.distributed.parallel_state import get_ep_group + num_local_physical_experts = torch.empty(1, + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = int(num_local_physical_experts.item()) + new_ep_size = get_ep_group().world_size + global_expert_load, old_global_expert_indices = ( + EplbState.recv_state()) + num_logical_experts = global_expert_load.shape[1] + self.parallel_config.eplb_config.num_redundant_experts = ( + num_local_physical_experts * new_ep_size - num_logical_experts) + assert old_global_expert_indices.shape[ + 1] % num_local_physical_experts == 0 + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + else: + global_expert_load = None + old_global_expert_indices = None + rank_mapping = None + + with DeviceMemoryProfiler() as m: + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) + if hasattr(self, "drafter"): + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) + if self.use_aux_hidden_state_outputs: + if supports_eagle3(self.model): + self.model.set_aux_hidden_state_layers( + self.model.get_eagle3_aux_hidden_state_layers()) + else: + raise RuntimeError( + "Model does not support EAGLE3 interface but " + "aux_hidden_state_outputs was requested") + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + 
logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) + prepare_communication_buffer_for_model(self.model) + + if is_mixture_of_experts( + self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", + self.model_config.model) + self.eplb_state = EplbState.build( + self.model, + self.device, + self.parallel_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + + if ( + self.vllm_config.compilation_config.level == \ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + ): + backend = self.vllm_config.compilation_config.init_backend( + self.vllm_config) + compilation_counter.dynamo_as_is_count += 1 + self.model.compile( + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + return + # for other compilation levels, cudagraph behavior is controlled by + # CudagraphWrapper and CudagraphDispatcher of vllm. + + # wrap the model with full cudagraph wrapper if needed. + if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ + and not self.parallel_config.enable_dbo: + self.model = CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + elif self.parallel_config.enable_dbo: + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.FULL, self.device) + else: + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.NONE, self.device) + + def reload_weights(self) -> None: + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." 
+ model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) + + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + model = self.get_model() + TensorizerLoader.save_model( + model, + tensorizer_config=tensorizer_config, + model_config=self.model_config, + ) + def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, @@ -1085,6 +2655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get metadata for this request. request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue + num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( self.device, non_blocking=True) @@ -1159,6 +2733,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return prompt_logprobs_dict + def _get_nans_in_logits( + self, + logits: Optional[torch.Tensor], + ) -> dict[str, int]: + try: + if logits is None: + return {req_id: 0 for req_id in self.input_batch.req_ids} + + num_nans_in_logits = {} + num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() + for req_id in self.input_batch.req_ids: + req_index = self.input_batch.req_id_to_index[req_id] + num_nans_in_logits[req_id] = ( + int(num_nans_for_index[req_index]) + if num_nans_for_index is not None + and req_index < logits.shape[0] else 0) + return num_nans_in_logits + except IndexError: + return {} + + @contextmanager + def maybe_randomize_inputs(self, input_ids: torch.Tensor): + """ + Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. 
+ This is to help balance expert-selection + - during profile_run + - during DP rank dummy run + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 + if not randomize_inputs: + yield + else: + import functools + + @functools.cache + def rand_input_ids() -> torch.Tensor: + return torch.randint_like( + self.input_ids.gpu, + low=0, + high=self.model_config.get_vocab_size(), + dtype=input_ids.dtype) + + logger.debug_once("Randomizing dummy data for DP Rank") + input_ids.copy_(rand_input_ids()[:input_ids.size(0)], + non_blocking=True) + yield + input_ids.fill_(0) + + def _get_mm_dummy_batch( + self, + modality: str, + max_items_per_batch: int, + ) -> BatchedTensorInputs: + """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_counts={modality: 1}, + cache=self.mm_budget.cache, + ) + dummy_mm_data = dummy_decoder_data.multi_modal_data + + # Result in the maximum GPU consumption of the model + dummy_mm_item = dummy_mm_data[modality][0] + dummy_mm_items = [dummy_mm_item] * max_items_per_batch + + return next(mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + )) + @torch.inference_mode() def _dummy_run( self, @@ -1361,6 +3011,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + model_kwargs = self._init_model_kwargs(num_tokens) else: input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None @@ -1510,12 +3164,135 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) return sampler_output + def 
_dummy_pooler_run_task( + self, + hidden_states: torch.Tensor, + task: PoolingTask, + ) -> PoolerOutput: + num_tokens = hidden_states.shape[0] + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + req_num_tokens = num_tokens // num_reqs + + dummy_prompt_lens = torch.tensor( + num_scheduled_tokens_list, + device="cpu", + ) + dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), + dtype=torch.int32, + device=self.device) + + model = cast(VllmModelForPooling, self.get_model()) + dummy_pooling_params = PoolingParams(task=task) + dummy_pooling_params.verify(task=task, model_config=self.model_config) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(dummy_pooling_params) + + dummy_metadata = PoolingMetadata( + prompt_lens=dummy_prompt_lens, + prompt_token_ids=dummy_token_ids, + pooling_params=[dummy_pooling_params] * num_reqs, + ) + + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, + device=hidden_states.device) + + try: + return model.pooler(hidden_states=hidden_states, + pooling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up pooler " + f"({task=}) with {num_reqs} dummy requests. 
Please try " + "lowering `max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e + + @torch.inference_mode() + def _dummy_pooler_run( + self, + hidden_states: torch.Tensor, + ) -> PoolerOutput: + # Find the task that has the largest output for subsequent steps + output_size = dict[PoolingTask, float]() + for task in self.get_supported_pooling_tasks(): + # Run a full batch with each task to ensure none of them OOMs + output = self._dummy_pooler_run_task(hidden_states, task) + output_size[task] = output.get_data_nbytes() + del output # Allow GC + + max_task = max(output_size.items(), key=lambda x: x[1])[0] + return self._dummy_pooler_run_task(hidden_states, max_task) + def profile_run(self) -> None: + # Profile with multimodal encoder & encoder cache. + if self.supports_mm_inputs: + if self.model_config.multimodal_config.skip_mm_profiling: + logger.info( + "Skipping memory profiling for multimodal encoder and " + "encoder cache.") + else: + mm_budget = self.mm_budget + assert mm_budget is not None + + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] + + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the " + "maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) + + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) + + # Run multimodal encoder. 
+ dummy_encoder_outputs = \ + self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) + # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ = self._dummy_run(self.max_num_tokens, is_profile=True) - output = self._dummy_sampler_run(last_hidden_states) + if get_pp_group().is_last_rank: + if self.is_pooling_model: + output = self._dummy_pooler_run(hidden_states) + else: + output = self._dummy_sampler_run(last_hidden_states) + else: + output = None + self._sync_device() del hidden_states, output + self.encoder_cache.clear() gc.collect() def capture_model(self) -> int: @@ -1687,6 +3464,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for layer_name in layer_names: attn_backend = layers[layer_name].get_attn_backend() + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_fast_prefill_custom_backend( + "FastPrefill", + attn_backend, + ) + key = attn_backend.full_cls_name() attn_backends[key] = attn_backend attn_backend_layers[key].append(layer_name) @@ -1729,6 +3512,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.attn_groups.append( create_attn_groups(attn_backends, kv_cache_spec)) + # Calculate reorder batch threshold (if needed) + self.calculate_reorder_batch_threshold() + def initialize_cudagraph_capture(self) -> None: min_cg_support = AttentionCGSupport.ALWAYS min_cg_builder_name = None @@ -1797,6 +3583,65 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.compilation_config.cudagraph_mode, self.uniform_decode_query_len) + def calculate_reorder_batch_threshold(self) -> None: + """ + Check that if any backends reorder batches; that the reordering + is compatible (e.g., 
decode threshold is the same) + """ + for group in self._attn_group_iterator(): + attn_metadata_builder_i = group.get_metadata_builder() + + # check that if any backends reorder batches; that the reordering + # is compatible (e.g., decode threshold is the same) + reorder_batch_threshold_i = ( + attn_metadata_builder_i.reorder_batch_threshold) + if reorder_batch_threshold_i is not None: + if self.reorder_batch_threshold is not None: + if reorder_batch_threshold_i != \ + self.reorder_batch_threshold: + raise ValueError( + f"Attention backend reorders decodes with " + f"threshold {reorder_batch_threshold_i} but other " + f"backend uses threshold " + f"{self.reorder_batch_threshold}") + else: + self.reorder_batch_threshold = reorder_batch_threshold_i + + def may_reinitialize_input_batch(self, + kv_cache_config: KVCacheConfig) -> None: + """ + Re-initialize the input batch if the block sizes are different from + `[self.cache_config.block_size]`. This usually happens when there + are multiple KV cache groups. + + Args: + kv_cache_config: The KV cache configuration. + """ + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + if block_sizes != [self.cache_config.block_size]: + assert self.cache_config.cpu_offload_gb == 0, ( + "Cannot re-initialize the input batch when CPU weight " + "offloading is enabled. 
See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 + "for more details.") + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=max(self.max_model_len, self.max_encoder_len), + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + block_sizes=block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), + logitsprocs=self.input_batch.logitsprocs, + is_pooling_model=self.is_pooling_model, + num_speculative_tokens=( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config else 0), + ) + def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ @@ -1856,6 +3701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): corresponding memory buffer for KV cache. """ kv_caches: dict[str, torch.Tensor] = {} + has_attn, has_mamba = False, False for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): attn_backend = group.backend for layer_name in group.layer_names: @@ -1865,34 +3711,91 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) + if isinstance(kv_cache_spec, AttentionSpec): + has_attn = True + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + dtype = kv_cache_spec.dtype + try: + kv_cache_stride_order = \ + attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len( + kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple( + range(len(kv_cache_shape))) + # The allocation respects the backend-defined stride order + # to ensure the semantic remains consistent for each + # backend. 
We first obtain the generic kv cache shape and + # then permute it according to the stride order which could + # result in a non-contiguous tensor. + kv_cache_shape = tuple(kv_cache_shape[i] + for i in kv_cache_stride_order) + # Maintain original KV shape view. + inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape).permute( + *inv_order) + elif isinstance(kv_cache_spec, MambaSpec): + has_mamba = True + raw_tensor = kv_cache_raw_tensors[layer_name] + state_tensors = [] + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset_bytes // dtype_size, + ) + state_tensors.append(tensor) + storage_offset_bytes += stride[0] * dtype_size - kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - dtype = kv_cache_spec.dtype - try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len(kv_cache_shape) - except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple(range(len(kv_cache_shape))) - # The allocation respects the backend-defined stride order - # to ensure the semantic remains consistent for each - # backend. We first obtain the generic kv cache shape and - # then permute it according to the stride order which could - # result in a non-contiguous tensor. 
- kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) - # Maintain original KV shape view. - inv_order = [ - kv_cache_stride_order.index(i) - for i in range(len(kv_cache_stride_order)) - ] - kv_caches[layer_name] = kv_cache_raw_tensors[layer_name].view( - dtype).view(kv_cache_shape).permute(*inv_order) + kv_caches[layer_name] = state_tensors + else: + raise NotImplementedError + + if has_attn and has_mamba: + self._update_hybrid_attention_mamba_layout(kv_caches) return kv_caches + def _update_hybrid_attention_mamba_layout( + self, kv_caches: dict[str, torch.Tensor]) -> None: + """ + Re-stride attention KV caches of shape (2, num_blocks, ...) in place so + that memory is laid out as (num_blocks, 2, ...); shapes are unchanged. + + Args: + kv_caches: Maps layer name to its KV cache tensor; updated in place. + """ + + for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for layer_name in group.layer_names: + kv_cache = kv_caches[layer_name] + if (isinstance(kv_cache_spec, AttentionSpec) + and kv_cache.shape[0] == 2): + assert kv_cache.shape[1] != 2, \ + "Fail to determine whether the layout is " \ + "(2, num_blocks, ...) or (num_blocks, 2, ...)
for " \ + f"a tensor of shape {kv_cache.shape}" + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_(size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, + *kv_cache.stride()[2:])) + def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ @@ -1910,11 +3813,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) + # Cross-layer KV sharing: each sharing layer aliases its target's cache + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) return kv_caches + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig) -> None: + """ + Add each layer that re-uses a KV cache to the group of its target layer. + The tensors themselves are shared in `initialize_kv_cache_tensors()`. + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + self.runner_only_attn_layers, + ) + + if self.cache_config.kv_sharing_fast_prefill: + # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other + # similar KV sharing setups, only the layers that generate KV caches + # are involved in the prefill phase, enabling prefill to early exit. + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) + for layer_name in reversed(attn_layers): + if layer_name in self.shared_kv_cache_layers: + self.kv_sharing_fast_prefill_eligible_layers.add( + layer_name) + else: + break + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`.
@@ -1924,9 +3863,162 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + # Validate that all draft model layers belong to the same KV + # cache group + self.drafter.validate_same_kv_cache_group(kv_cache_config) + + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + if self.device.type == 'xpu': + get_kv_transfer_group().set_host_xfer_buffer_ops( + copy_kv_blocks) + + if self.dcp_world_size > 1: + layer_names = self.attn_groups[0][0].layer_names + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) + for layer in layers.values(): + assert layer.impl.need_to_return_lse_for_decode, ( + "DCP requires attention impls to return" + " the softmax lse for decode, but the impl " + f"{layer.impl.__class__.__name__} " + "does not return the softmax lse for decode.") + + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Append encoder-only layers to the KV cache config as one extra group.
+ """ + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context. + Returns: + dict[str, KVCacheSpec]: Maps attention and Mamba layer names to their + KV cache format. Layers that do not need KV cache are not included. + """ + + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + kv_cache_spec: dict[str, KVCacheSpec] = {} + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as if this layer does + # not exist, and doesn't allocate KV cache for the layer.
This + # enables the memory saving of cross-layer KV sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + + # TODO(lucas): move the attention specs into the model layers like + # the attention backends + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + use_mla=use_mla) + elif self.attention_chunk_size is not None \ + and isinstance(attn_module, ChunkedLocalAttention): + kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + attention_chunk_size=self.attention_chunk_size, + use_mla=use_mla) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # Encoder and encoder-only attention get no KVCacheSpec here.
+ continue + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) + if len(mamba_layers) > 0: + if (self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"]): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for Mamba yet.") + max_model_len = self.vllm_config.model_config.max_model_len + + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) + + # Set block_size to max_model_len, so that a Mamba model will always + # have only one block in the KV cache. + for layer_name, mamba_module in mamba_layers.items(): + kv_cache_spec[layer_name] = MambaSpec( + shapes=mamba_module.get_state_shape(), + dtypes=mamba_module.get_state_dtype(), + block_size=max_model_len, + page_size_padded=page_size_padded, + mamba_type=mamba_module.mamba_type, + num_speculative_blocks=( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0), + ) + + return kv_cache_spec + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # This is a short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/22754.