import gc
import time
from typing import TYPE_CHECKING, Dict, List, Tuple, cast

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                   FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput

logger = init_logger(__name__)


class GPUModelRunner:
    """Runs the model on a single GPU for each scheduler step.

    Keeps per-request state (`self.requests`), a persistent batch
    (`self.input_batch`), an encoder-output cache, and preallocated CPU/GPU
    buffers that are reused every step (and by CUDA graph capture).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Cache config objects, derive sizes, and allocate persistent buffers.

        Args:
            vllm_config: Aggregate engine configuration; its sub-configs are
                stored as attributes.
            device: Device on which GPU-side buffers are allocated.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the model dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY

        # NOTE: Initialized input mapper is only used for processing dummy
        # multimodal data into multimodal kwargs for GPU memory profiling.
        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
        self.mm_input_mapper_profiling.use_cache = False

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(self.vllm_config.compilation_config.capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)

        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.model_config.uses_mrope:
            # NOTE: `mrope_positions` is implemented as a permuted tensor to
            # satisfy the following properties to allow `torch.compile` to work
            # properly:
            # - shape: (3, max_num_tokens)
            # - stride: (1, 3)
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
                                                   dtype=torch.int64,
                                                   device="cpu",
                                                   pin_memory=self.pin_memory)

            # Allocated as (max_num_tokens, 3) and viewed as
            # (3, max_num_tokens), which yields the (1, 3) stride noted above.
            self.mrope_positions = self.mrope_positions.permute((1, 0))
            self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each *_np attribute is a NumPy view sharing memory with the
        # corresponding *_cpu tensor, so writes through either are visible
        # in both.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                             dtype=torch.int32,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Sync cached request states and the persistent batch with the
        scheduler's decisions for this step.

        In order:
        1. Drop finished requests.
        2. Free released encoder outputs.
        3. Remove preempted and finished requests from the batch.
        4. Update running requests' computed-token counts and block tables.
        5. Register new requests.
        6. Refresh resumed requests.
        7. Insert new and resumed requests into vacated slots
           (smallest index first), then condense any slots that remain.
        """
        # Remove stopped requests from the cached states.
        # Keep the states of the pre-empted requests.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                # Drop the per-request dict once it is empty.
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the requests from the persistent batch.
        stopped_req_ids = set().union(
            scheduler_output.preempted_req_ids,
            scheduler_output.finished_req_ids,
        )
        removed_req_indices: List[int] = []
        for req_id in stopped_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Update the states of the running requests.
        for req_data in scheduler_output.scheduled_running_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]
            req_index = self.input_batch.req_id_to_index[req_id]

            # Update the num_computed_tokens.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)

            # Update the block table.
            num_new_blocks = len(req_data.new_block_ids)
            if num_new_blocks == 0:
                continue
            # New block ids are appended after the existing ones, both in
            # the cached state and in the batch's block table row.
            start_index = len(req_state.block_ids)
            req_state.block_ids.extend(req_data.new_block_ids)
            self.input_batch.block_table.append_row(req_index, start_index,
                                                    req_data.new_block_ids)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            # Seeded requests get their own device-side generator.
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
            )

            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
            if self.model_config.uses_mrope:
                # Gather the image/video grid sizes from all of the request's
                # multimodal inputs, then precompute the prompt's M-RoPE
                # positions once per request.
                image_grid_thw = []
                video_grid_thw = []
                for mm_input in self.requests[req_id].mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())

                hf_config = self.model_config.hf_config

                self.requests[req_id].mrope_positions, \
                    self.requests[req_id].mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        self.requests[req_id].prompt_token_ids,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        image_token_id=hf_config.image_token_id,
                        video_token_id=hf_config.video_token_id,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        vision_end_token_id=hf_config.vision_end_token_id,
                        spatial_merge_size=hf_config.vision_config.
                        spatial_merge_size,
                    )

            req_ids_to_add.append(req_id)

        # Update the cached states of the resumed requests.
        for res_req_data in scheduler_output.scheduled_resumed_reqs:
            req_id = res_req_data.req_id
            req_state = self.requests[req_id]

            # Resumed requests get a replacement block list, not an append.
            req_state.block_ids = res_req_data.block_ids
            req_state.num_computed_tokens = res_req_data.num_computed_tokens
            req_ids_to_add.append(req_id)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        # (Sorted descending, so pop() returns the smallest remaining index.)
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

    def _prepare_inputs(
        self, scheduler_output: "SchedulerOutput"
    ) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
        """Build flattened model inputs and attention metadata for this step.

        Fills the persistent CPU buffers (token ids, positions, slot mapping,
        query/sequence start offsets) for all scheduled tokens and launches
        non-blocking copies to the GPU. It also decides whether to use cascade
        attention for a shared prefix.

        Returns:
            A tuple `(attn_metadata, logits_indices)`. `logits_indices` is the
            flattened index of each request's last scheduled token
            (`query_start_loc[1:] - 1`), on the device.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        # TODO: The Python loop can be slow. Optimize.
        num_scheduled_tokens = []
        max_num_scheduled_tokens = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens.append(num_tokens)
            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                           num_tokens)
        num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
        assert max_num_scheduled_tokens > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens])

        # Get positions.
        # position = tokens already computed for the request + offset within
        # this step's chunk; written in place into the persistent buffer.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.model_config.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the row width of token_ids_cpu (presumably
        # max_model_len).
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # Look up the physical block number for each token with plain tensor
        # indexing on the flattened CPU block table, then convert to NumPy.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        # slot = physical_block * block_size + offset_within_block
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        # query_start_loc: exclusive prefix sum of per-request query lengths.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens,
                  out=self.query_start_loc_np[1:num_reqs + 1])

        # seq_start_loc: exclusive prefix sum of per-request total lengths
        # (computed + scheduled).
        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
                    num_scheduled_tokens)
        max_seq_len = seq_lens.max()
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.model_config.uses_mrope:
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)
        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
            self.device, non_blocking=True)
        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
            self.device, non_blocking=True)
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True).long()

        # Prepare for cascade attention if needed.
        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
                             self.block_size)
        if common_prefix_len == 0:
            # Common case.
            use_cascade = False
        else:
            # NOTE(woosuk): Cascade attention uses two attention kernels: one
            # for the common prefix and the other for the rest. For the first
            # kernel, we concatenate all the query tokens (possibly from
            # different requests) and treat them as if they are from the same
            # request. Then, we use bi-directional attention to process the
            # common prefix in the KV cache. Importantly, this means that the
            # first kernel does not do any masking.

            # Consider the following example:
            # Request 1's input query: [D, E, X]
            # Request 1's kv cache: [A, B, C, D, E, X]
            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
            # Request 2's input query: [E, Y]
            # Request 2's kv cache: [A, B, C, D, E, Y]
            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

            # If we use [A, B, C, D, E] as the common prefix, then the
            # first kernel will compute the bi-directional attention between
            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
            # However, this is wrong because D in Request 1 should not attend to
            # E in the common prefix (i.e., we need masking).
            # To avoid this, [A, B, C, D] should be the common prefix.
            # That is, the common prefix should be capped by the minimum
            # num_computed_tokens among the requests, and plus one to include
            # the first token of the query.

            # In practice, we use [A, B, C] as the common prefix, instead of
            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
            # num_computed_tokens, without plus one).
            # This is because of an implementation detail: We want to always
            # use two kernels for cascade attention. Let's imagine:
            # Request 3's input query: [D]
            # Request 3's kv cache: [A, B, C, D]
            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
            # If we use [A, B, C, D] as the common prefix for Request 1-3,
            # then Request 3 will be processed only by the first kernel,
            # and the second kernel will get an empty input. While this is not
            # a fundamental problem, our current implementation does not support
            # this case.
            common_prefix_len = min(
                common_prefix_len,
                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
            # common_prefix_len should be a multiple of the block size.
            common_prefix_len = (common_prefix_len // self.block_size *
                                 self.block_size)
            use_cascade = FlashAttentionBackend.use_cascade_attention(
                common_prefix_len=common_prefix_len,
                query_lens=num_scheduled_tokens,
                num_query_heads=self.num_query_heads,
                num_kv_heads=self.num_kv_heads,
                use_alibi=False,  # FIXME
                use_sliding_window=self.sliding_window is not None,
                num_sms=self.num_sms,
            )

        if use_cascade:
            # TODO: Optimize.
            # The prefix kernel sees all query tokens as one sequence attending
            # to `common_prefix_len` KV tokens; each request's suffix KV
            # offsets are its seq_start_loc minus the prefixes removed so far.
            cu_prefix_query_lens = torch.tensor(
                [0, total_num_scheduled_tokens],
                dtype=torch.int32,
                device=self.device)
            cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                             dtype=torch.int32,
                                             device=self.device)
            cu_suffix_kv_lens = (
                self.seq_start_loc_np[:num_reqs + 1] -
                self.arange_np[:num_reqs + 1] * common_prefix_len)
            cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
                self.device)
        else:
            cu_prefix_query_lens = None
            cu_prefix_kv_lens = None
            cu_suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_start_loc=seq_start_loc,
            block_table=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            cu_prefix_kv_lens=cu_prefix_kv_lens,
            cu_suffix_kv_lens=cu_suffix_kv_lens,
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        logits_indices = query_start_loc[1:] - 1
        return attn_metadata, logits_indices

    def _calc_mrope_positions(self,
                              scheduler_output: "SchedulerOutput") -> None:
        """Fill `mrope_positions_cpu` for all scheduled tokens, in batch order.

        For each request, the prompt portion of this step is copied from the
        request's precomputed `mrope_positions`. The completion portion, if
        any, is generated with
        `MRotaryEmbedding.get_next_input_positions_tensor`.
        """
        # Write cursor into the flattened (all-requests) token dimension.
        mrope_pos_ptr = 0
        num_reqs = self.input_batch.num_reqs
        for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None

            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            # Split this step's tokens into the part still inside the prompt
            # and the part past the end of the prompt (generated tokens).
            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

    def _prepare_sampling(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> SamplingMetadata:
        """Build the SamplingMetadata for the current persistent batch.

        `skip_copy` is True only when no request was finished, preempted,
        added, or resumed this step (i.e., batch membership is unchanged).
        NOTE(review): presumably this lets InputBatch reuse previously copied
        sampling tensors; confirm in `make_sampling_metadata`.
        """
        skip_copy = True
        if (scheduler_output.finished_req_ids
                or scheduler_output.preempted_req_ids):
            skip_copy = False
        if (scheduler_output.scheduled_new_reqs
                or scheduler_output.scheduled_resumed_reqs):
            skip_copy = False
        # Create the sampling metadata.
        req_id_output_token_ids: Dict[str, List[int]] = \
            {req_id: req.output_token_ids \
                for req_id, req in self.requests.items()}

        sampling_metadata = self.input_batch.make_sampling_metadata(
            req_id_output_token_ids, skip_copy)
        return sampling_metadata

    def _execute_encoder(self, scheduler_output: "SchedulerOutput") -> None:
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        Inputs are grouped by modality and batched per group. Each output is
        stored in `self.encoder_cache[req_id][input_id]`. Returns early when
        nothing is scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs: List[MultiModalKwargs] = []
        # Parallel to mm_inputs: which (request, input) each item belongs to.
        req_input_ids: List[Tuple[str, int]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            for output in curr_group_outputs:
                encoder_outputs.append(output)

        # Cache the encoder outputs.
        # Relies on encoder_outputs being in the same order as req_input_ids.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            self.encoder_cache[req_id][input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> List[torch.Tensor]:
        """Collect the cached encoder-output slices needed in this step.

        For each request in batch order and each of its multimodal
        placeholders, appends the slice of the cached encoder output that
        overlaps this step's token range
        `[num_computed_tokens, num_computed_tokens + num_scheduled_tokens)`.
        NOTE(review): the early `break` assumes `mm_positions` is sorted by
        offset; confirm with the producer of `mm_positions`.
        """
        encoder_outputs: List[torch.Tensor] = []
        num_reqs = self.input_batch.num_reqs
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                # Clip the overlap to coordinates local to this encoder output.
                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

    def get_model(self) -> nn.Module:
        """Return the model assigned by `load_model`."""
        return self.model

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput:
        """Run one engine step and return the sampled tokens.

        Steps:
        1. Sync request state.
        2. Run the encoder and gather its outputs (multimodal models only).
        3. Prepare inputs and pad to a CUDA-graph size when eligible.
        4. Run the decoder and compute logits at `logits_indices`.
        5. Sample.
        6. Append the sampled token for every request whose scheduled tokens
           reach its current `num_tokens`. For partial (chunked-prefill)
           requests, discard the token and rewind the generator.
        """
        self._update_states(scheduler_output)

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        else:
            encoder_outputs = []

        # Prepare the decoder inputs.
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        # cudagraph_batch_sizes is ascending, so [-1] is the largest
        # captured size.
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens
        attn_metadata.num_input_tokens = num_input_tokens

        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if encoder_outputs:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, encoder_outputs)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(attn_metadata, self.vllm_config):
            positions = self.mrope_positions[:, :num_input_tokens] \
                if self.model_config.uses_mrope \
                else self.positions[:num_input_tokens]
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=self.kv_caches,
                attn_metadata=None,
                inputs_embeds=inputs_embeds,
            )
        # Drop padding rows, then keep only each request's last token.
        hidden_states = hidden_states[:num_scheduled_tokens]
        hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self._prepare_sampling(scheduler_output)
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        sampled_token_ids = sampler_output.sampled_token_ids
        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        num_reqs = self.input_batch.num_reqs
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            assert seq_len <= req_state.num_tokens
            if seq_len == req_state.num_tokens:
                # Append the sampled token to the output token ids.
                token_id = sampled_token_ids[i]
                # NOTE(review): writes column `seq_len` of token_ids_cpu;
                # presumably the scheduler guarantees seq_len < its row width
                # here. Confirm.
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                self.input_batch.num_tokens[i] += 1
                req_state.output_token_ids.append(token_id)
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

        if sampler_output.logprob_token_ids is None:
            logprob_token_ids = None
        else:
            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
        if sampler_output.logprobs is None:
            logprobs = None
        else:
            logprobs = sampler_output.logprobs.cpu()

        # num_reqs entries should be non-None
        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=sampled_token_ids,
            logprob_token_ids_cpu=logprob_token_ids,
            logprobs_cpu=logprobs,
        )
        return model_runner_output

    def load_model(self) -> None:
        """Load the model and record its device memory usage.

        The usage (`self.model_memory_usage`) is measured with
        DeviceMemoryProfiler.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

    @torch.inference_mode()
    def _dummy_run(
        self,
        model: nn.Module,
        num_tokens: int,
        kv_caches: List[torch.Tensor],
    ) -> torch.Tensor:
        """Run `model` on the first `num_tokens` entries of the persistent
        input buffers, with no attention metadata.

        Used by `profile_run` and `capture_model`. Buffer contents are
        whatever they currently hold. Returns the model output.
        """
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None
        with set_forward_context(None, self.vllm_config):
            positions = self.mrope_positions[:, :num_tokens] \
                if self.model_config.uses_mrope \
                else self.positions[:num_tokens]
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=None,
                inputs_embeds=inputs_embeds,
            )
        return hidden_states

    def profile_run(self) -> None:
        """Exercise the model at its maximum shapes for memory profiling.

        For multimodal models, first runs the encoder on a dummy batch sized
        by the encoder and decoder budgets and stores the outputs in the
        encoder cache. Then runs a `max_num_tokens` forward pass plus logits
        with empty placeholder KV caches. Cleans up afterwards.
        """
        # use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to prevent Dynamo from treating them as
        # aliases of one tensor.
        dummy_kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]

        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
                self.model_config)
            # Pick the modality with the largest per-item token count.
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                   max_tokens_per_mm_item)

            # Check how many items of this modality can be supported by
            # the decoder budget.
            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data

            # Dummy data definition in V0 may contain multiple multimodal items
            # (e.g, multiple images) for a single request, therefore here we
            # always replicate first item by max_num_mm_items times since in V1
            # they are scheduled to be processed separately.

            # Case when models have a merged processor, their dummy data is
            # already batched `MultiModalKwargs`, therefore we take the first
            # `MultiModalKwargsItem` from the desired modality to profile on.
            if isinstance(dummy_mm_data, MultiModalKwargs):
                dummy_mm_item = dummy_mm_data.get_item(
                    modality=dummy_data_modality, item_index=0)
                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
            # Case when models have dummy data explicitly defined as
            # `MultiModalDataDict`, so they need to be processed through input
            # mapper.
            # TODO (ywang96): deprecate this path once merged processor is
            # supported on all models.
            else:
                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
                    mm_data=dummy_mm_data,
                    mm_hashes=None,
                    mm_processor_kwargs=None,
                    precomputed_mm_inputs=None)
                dummy_mm_kwargs = mm_kwargs_list[0]

            batched_dummy_mm_inputs = MultiModalKwargs.batch(
                [dummy_mm_kwargs] * max_num_mm_items)
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs, device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs.
            # Kept alive until the clear() below so they count toward the
            # profiled peak memory.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        # Trigger compilation for general shape.
        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                        dummy_kv_caches)
        logits = self.model.compute_logits(hidden_states, None)
        logits = logits[:self.max_num_tokens]
        # TODO(woosuk): Consider the memory usage of the sampler.
        torch.cuda.synchronize()
        del hidden_states, logits
        self.encoder_cache.clear()
        gc.collect()

    def capture_model(self) -> None:
        """Capture piecewise CUDA graphs for every configured batch size.

        Does nothing (with a warning) when CUDA graphs are disabled. For
        each size (largest first) it runs the configured number of warmup
        passes followed by one more pass inside `graph_capture`. It then
        logs the elapsed time and the drop in free GPU memory.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
                    self._dummy_run(self.model, num_tokens, self.kv_caches)
                self._dummy_run(self.model, num_tokens, self.kv_caches)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
Args: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ if len(kv_cache_config.groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") kv_caches: Dict[str, torch.Tensor] = {} for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): tensor_config = kv_cache_config.tensors[layer_name] assert tensor_config.size % layer_spec.page_size_bytes == 0 num_blocks = tensor_config.size // layer_spec.page_size_bytes if isinstance(layer_spec, FullAttentionSpec): kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, layer_spec.head_size) dtype = layer_spec.dtype kv_caches[layer_name] = torch.zeros(kv_cache_shape, dtype=dtype, device=self.device) else: raise NotImplementedError bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches) def get_kv_cache_spec(self) -> KVCacheSpec: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. Returns: KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """ forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size kv_cache_spec: KVCacheSpec = {} for layer_name, attn_module in forward_ctx.items(): # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. assert isinstance(attn_module, Attention) if attn_module.attn_type == AttentionType.DECODER: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=attn_module.dtype, ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec