diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6e178bb690c5..0e834c057c40 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -35,6 +35,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/v1/kv_cache_interface.py @heheda12345 /vllm/v1/offloading @ApostaC +# Model runner V2 +/vllm/v1/worker/gpu @WoosukKwon + # Test ownership /.buildkite/lm-eval-harness @mgoin /tests/distributed/test_multi_node_assignment.py @youkaichao diff --git a/vllm/envs.py b/vllm/envs.py index 888a09cf6d3e..d2d691740342 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -231,6 +231,7 @@ if TYPE_CHECKING: VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" + VLLM_USE_V2_MODEL_RUNNER: bool = False def get_default_cache_root(): @@ -1522,6 +1523,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] ), + # Flag to enable v2 model runner. + "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( + int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) + ), } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3ad7e8c52fc1..e3f499216d7f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -593,6 +593,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ) return self._workspace_buffer + def set_workspace_buffer(self, workspace_buffer: torch.Tensor): + self._workspace_buffer = workspace_buffer + def _get_prefill_wrapper( self, ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 20fdb3446404..7902513dce49 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -44,11 +44,15 @@ class NewRequestData: lora_request: LoRARequest | None prompt_embeds: "torch.Tensor | None" = None + # Only used for v2 model runner. + prefill_token_ids: list[int] | None = None + @classmethod def from_request( cls, request: Request, block_ids: tuple[list[int], ...], + prefill_token_ids: list[int] | None = None, ) -> "NewRequestData": return cls( req_id=request.request_id, @@ -60,6 +64,7 @@ class NewRequestData: num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, prompt_embeds=request.prompt_embeds, + prefill_token_ids=prefill_token_ids, ) def __repr__(self) -> str: @@ -68,6 +73,7 @@ class NewRequestData: f"NewRequestData(" f"req_id={self.req_id}," f"prompt_token_ids={self.prompt_token_ids}," + f"prefill_token_ids={self.prefill_token_ids}," f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," f"block_ids={self.block_ids}," @@ -183,6 +189,10 @@ class SchedulerOutput: # freed from the encoder cache. free_encoder_mm_hashes: list[str] + # Request IDs that are preempted in this step. + # Only used for v2 model runner. + preempted_req_ids: set[str] | None = None + # Whether the scheduled requests have all the output tokens they # need to perform grammar bitmask computation. 
pending_structured_output_tokens: bool = False @@ -193,6 +203,20 @@ class SchedulerOutput: # EC Cache Connector metadata ec_connector_metadata: ECConnectorMetadata | None = None + @classmethod + def make_empty(cls) -> "SchedulerOutput": + return cls( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) + @dataclass class GrammarOutput: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1ac8520a8ed2..9195b112d869 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -6,6 +6,7 @@ from collections import defaultdict from collections.abc import Iterable from typing import Any +from vllm import envs from vllm.config import VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ( ECConnectorMetadata, @@ -187,6 +188,7 @@ class Scheduler(SchedulerInterface): pcp_world_size=self.pcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -658,12 +660,25 @@ class Scheduler(SchedulerInterface): ) # Construct the scheduler output. - new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids() - ) - for req in scheduled_new_reqs - ] + if self.use_v2_model_runner: + scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs + scheduled_resumed_reqs = [] + new_reqs_data = [ + NewRequestData.from_request( + req, + req_to_new_blocks[req.request_id].get_block_ids(), + req._all_token_ids, + ) + for req in scheduled_new_reqs + ] + else: + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids() + ) + for req in scheduled_new_reqs + ] + with record_function_or_nullcontext("schedule: make_cached_request_data"): cached_reqs_data = self._make_cached_request_data( scheduled_running_reqs, @@ -685,6 +700,7 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, + preempted_req_ids={req.request_id for req in preempted_reqs}, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between diff --git a/vllm/v1/worker/gpu/README.md b/vllm/v1/worker/gpu/README.md new file mode 100644 index 000000000000..093f524b3250 --- /dev/null +++ b/vllm/v1/worker/gpu/README.md @@ -0,0 +1,4 @@ +# [Experimental] Model Runner V2 + +This directory contains the new model runner which is under active development. +Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes. 
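+
+The runner is opt-in and gated behind the `VLLM_USE_V2_MODEL_RUNNER` environment
+variable (disabled by default). A minimal sketch of how to try it, assuming the
+flag stays as defined in this PR; the model name below is a placeholder:
+
+```python
+# Sketch: enable the experimental V2 model runner before building the engine.
+import os
+
+os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"  # defaults to "0" (off)
+
+from vllm import LLM
+
+llm = LLM(model="<placeholder-model>")
+print(llm.generate("Hello, world!"))
+```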
diff --git a/vllm/v1/worker/gpu/__init__.py b/vllm/v1/worker/gpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py new file mode 100644 index 000000000000..638ec6fb0b08 --- /dev/null +++ b/vllm/v1/worker/gpu/async_utils.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager + +import numpy as np +import torch + +from vllm.v1.outputs import ( + AsyncModelRunnerOutput, + ModelRunnerOutput, + SamplerOutput, +) + + +class AsyncOutput(AsyncModelRunnerOutput): + def __init__( + self, + model_runner_output: ModelRunnerOutput, + sampler_output: SamplerOutput, + num_sampled_tokens: np.ndarray, + copy_stream: torch.cuda.Stream, + copy_event: torch.cuda.Event, + ): + self.model_runner_output = model_runner_output + self.sampler_output = sampler_output + self.num_sampled_tokens = num_sampled_tokens + self.copy_stream = copy_stream + self.copy_event = copy_event + + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.copy_stream): + self.copy_stream.wait_stream(default_stream) + + # NOTE(woosuk): We must ensure that CPU tensors are not freed + # before the device-to-host copy is fully completed. For instance, + # operations like + # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy() + # are unsafe because the underlying CPU tensor can be prematurely freed and + # reused by other tensors before the asynchronous copy finishes, potentially + # causing race conditions. To prevent this, we delay freeing by holding + # references until the copy event signals completion. + # Likewise, we also need to keep the reference to the GPU tensors. + # This is done by keeping the reference to sampler_output and + # model_runner_output. + self.sampled_token_ids = sampler_output.sampled_token_ids.to( + "cpu", non_blocking=True + ) + if sampler_output.logprobs_tensors is not None: + self.logprobs_tensors = ( + sampler_output.logprobs_tensors.to_cpu_nonblocking() + ) + else: + self.logprobs_tensors = None + self.prompt_logprobs_dict = {} + if self.model_runner_output.prompt_logprobs_dict: + for k, v in self.model_runner_output.prompt_logprobs_dict.items(): + self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking() + self.copy_event.record(self.copy_stream) + + def get_output(self) -> ModelRunnerOutput: + self.copy_event.synchronize() + + # NOTE(woosuk): The following code is to ensure compatibility with + # the existing model runner. + # Going forward, we should keep the data structures as NumPy arrays + # rather than Python lists. 
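+        # For example, with num_sampled_tokens == [1, 0, 1], requests 0 and 2
+        # each yield one new token from the row slice below, while request 1
+        # (still in the middle of a chunked prefill) yields an empty array.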
+ sampled_token_ids_np = self.sampled_token_ids.numpy() + num_reqs = sampled_token_ids_np.shape[0] + sampled_token_ids: list[np.ndarray] = [ + sampled_token_ids_np[i, : self.num_sampled_tokens[i]] + for i in range(num_reqs) + ] + self.model_runner_output.sampled_token_ids = sampled_token_ids + + if self.logprobs_tensors is not None: + self.model_runner_output.logprobs = self.logprobs_tensors.tolists() + self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict + return self.model_runner_output + + +@contextmanager +def async_barrier(event: torch.cuda.Event | None): + if event is not None: + event.synchronize() + try: + yield + finally: + if event is not None: + event.record() diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py new file mode 100644 index 000000000000..8850c1809229 --- /dev/null +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheSpec, +) +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.utils import bind_kv_cache + + +def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: + kv_cache_spec: dict[str, KVCacheSpec] = {} + attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase) + for layer_name, attn_module in attn_layers.items(): + # Skip modules that don't need KV cache (eg encoder-only attention) + if spec := attn_module.get_kv_cache_spec(vllm_config): + kv_cache_spec[layer_name] = spec + return kv_cache_spec + + +def init_attn_backend( + kv_cache_config: KVCacheConfig, + vllm_config: VllmConfig, + device: torch.device, +): + attn_backends: dict[str, AttentionBackend] = {} + attn_metadata_builders: list[AttentionMetadataBuilder] = [] + flashinfer_workspace: torch.Tensor | None = None + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + layer_names = kv_cache_group_spec.layer_names + any_layer_name = next(iter(layer_names)) + + attn_layers = get_layers_from_vllm_config( + vllm_config, AttentionLayerBase, layer_names + ) + attn_backend = attn_layers[any_layer_name].get_attn_backend() + for layer_name in layer_names: + attn_backends[layer_name] = attn_backend + + attn_metadata_builder = attn_backend.get_builder_cls()( + kv_cache_group_spec.kv_cache_spec, + layer_names, + vllm_config, + device, + ) + attn_metadata_builders.append(attn_metadata_builder) # type: ignore + + if "FLASHINFER" in attn_backend.get_name(): + if flashinfer_workspace is None: + flashinfer_workspace = attn_metadata_builder._get_workspace_buffer() + else: + attn_metadata_builder.set_workspace_buffer(flashinfer_workspace) + return attn_backends, attn_metadata_builders + + +def _allocate_kv_cache( + kv_cache_config: KVCacheConfig, + device: torch.device, +): + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device) + for layer_name in kv_cache_tensor.shared_by: + kv_cache_raw_tensors[layer_name] = tensor + + 
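+    # Sanity check: every layer referenced by the KV cache groups must have
+    # received a backing tensor above, and no extra tensors were allocated.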
layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in group.layer_names: + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) + return kv_cache_raw_tensors + + +def _reshape_kv_cache( + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + attn_backends: dict[str, AttentionBackend], +) -> dict[str, torch.Tensor]: + kv_caches: dict[str, torch.Tensor] = {} + for kv_cache_group_spec in kv_cache_config.kv_cache_groups: + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes + + attn_backend = attn_backends[layer_name] + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) + + # FIXME(woosuk): Add kv_cache_stride_order to all attention backends. + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) + ] + + dtype = kv_cache_spec.dtype + raw_tensor = raw_tensor.view(dtype) + raw_tensor = raw_tensor.view(kv_cache_shape) + kv_caches[layer_name] = raw_tensor.permute(*inv_order) + return kv_caches + + +def init_kv_cache( + runner_kv_caches: list[torch.Tensor], + forward_context: dict[str, Any], + kv_cache_config: KVCacheConfig, + attn_backends: dict[str, AttentionBackend], + device: torch.device, +) -> None: + kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) + kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends) + bind_kv_cache(kv_caches, forward_context, runner_kv_caches) + + +def build_attn_metadata( + attn_metadata_builders: list[AttentionMetadataBuilder], + num_reqs: int, + num_tokens: int, + query_start_loc: CpuGpuBuffer, + seq_lens: CpuGpuBuffer, + num_computed_tokens_cpu: torch.Tensor, + block_tables: Sequence[torch.Tensor], + slot_mappings: torch.Tensor, + kv_cache_config: KVCacheConfig, +) -> dict[str, Any]: + query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] + query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1] + max_query_len = int(query_start_loc.np[: num_reqs + 1].max()) + seq_lens_gpu = seq_lens.gpu[:num_reqs] + seq_lens_cpu = seq_lens.cpu[:num_reqs] + max_seq_len = int(seq_lens.np[:num_reqs].max()) + + attn_metadata: dict[str, Any] = {} + kv_cache_groups = kv_cache_config.kv_cache_groups + for i, kv_cache_spec in enumerate(kv_cache_groups): + block_table = block_tables[i] + slot_mapping = slot_mappings[i] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc_gpu, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens_gpu, + seq_lens_cpu=seq_lens_cpu, + max_seq_len=max_seq_len, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + block_table_tensor=block_table, + slot_mapping=slot_mapping, + causal=True, + ) + + attn_metadata_builder = attn_metadata_builders[i] + metadata = 
attn_metadata_builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + for layer_name in kv_cache_spec.layer_names: + attn_metadata[layer_name] = metadata + return attn_metadata diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py new file mode 100644 index 000000000000..ff24e88ede2c --- /dev/null +++ b/vllm/v1/worker/gpu/block_table.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable + +import torch +import triton +import triton.language as tl + +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.utils.math_utils import cdiv +from vllm.v1.utils import CpuGpuBuffer + + +class BlockTables: + def __init__( + self, + block_sizes: list[int], + max_num_reqs: int, + max_num_batched_tokens: int, + max_model_len: int, + device: torch.device, + pin_memory: bool, + ): + self.block_sizes = block_sizes + self.max_num_reqs = max_num_reqs + self.max_num_batched_tokens = max_num_batched_tokens + self.max_model_len = max_model_len + self.device = device + self.pin_memory = pin_memory + + self.num_kv_cache_groups = len(self.block_sizes) + # num_kv_cache_groups x [max_num_reqs, max_num_blocks] + self.block_tables: list[torch.Tensor] = [] + for i in range(self.num_kv_cache_groups): + block_size = self.block_sizes[i] + max_num_blocks = cdiv(self.max_model_len, block_size) + block_table = torch.zeros( + self.max_num_reqs, + max_num_blocks, + dtype=torch.int32, + device=self.device, + ) + self.block_tables.append(block_table) + self.block_table_ptrs = self._make_ptr_tensor(self.block_tables) + + # Block tables used for model's forward pass. + # num_kv_cache_groups x [max_num_reqs, max_num_blocks] + self.input_block_tables: list[torch.Tensor] = [ + torch.zeros_like(block_table) for block_table in self.block_tables + ] + self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables) + + self.block_table_strides = torch.tensor( + [b.stride(0) for b in self.block_tables], + dtype=torch.int64, + device=self.device, + ) + self.block_sizes_tensor = torch.tensor( + self.block_sizes, dtype=torch.int32, device=self.device + ) + self.num_blocks = torch.zeros( + self.num_kv_cache_groups, + self.max_num_reqs, + dtype=torch.int32, + device=self.device, + ) + self.slot_mappings = torch.zeros( + self.num_kv_cache_groups, + self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device, + ) + + # Misc buffers. + self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool) + self.cu_num_new_blocks = self._make_buffer( + self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32 + ) + + def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer( + *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device + ) + + def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor: + # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses. 
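+        # The per-group block tables can have different widths (their block
+        # sizes differ), so the Triton kernels receive an array of raw data
+        # pointers and rebuild typed pointers via _load_ptr instead of using
+        # a single stacked tensor.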
+ ptrs_tensor_cpu = torch.tensor( + [t.data_ptr() for t in x], + dtype=torch.uint64, + device="cpu", + pin_memory=self.pin_memory, + ) + return ptrs_tensor_cpu.to(self.device, non_blocking=True) + + def append_block_ids( + self, + # [num_reqs] + req_indices: list[int], + # [num_kv_cache_groups, num_reqs + 1] + cu_num_new_blocks: tuple[list[int], ...], + # [num_kv_cache_groups, num_new_blocks] + new_block_ids: tuple[list[int], ...], + # [num_reqs] + overwrite: list[bool], + ) -> None: + num_reqs = len(req_indices) + self.req_indices.np[:num_reqs] = req_indices + self.overwrite.np[:num_reqs] = overwrite + for i in range(self.num_kv_cache_groups): + self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i] + + # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's + # no clear upper bound to the number of new blocks in a single step. + # NOTE(woosuk): The buffer has to be cached, because otherwise we cannot + # guarantee that the buffer is not freed before the copy is completed. + self.new_block_ids_cpu = torch.empty( + self.num_kv_cache_groups, + max(len(x) for x in new_block_ids), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) + new_block_ids_np = self.new_block_ids_cpu.numpy() + for i in range(self.num_kv_cache_groups): + new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i] + new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True) + + _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)]( + self.req_indices.copy_to_gpu(num_reqs), + self.cu_num_new_blocks.copy_to_gpu(), + self.cu_num_new_blocks.gpu.stride(0), + new_block_ids_gpu, + new_block_ids_gpu.stride(0), + self.overwrite.copy_to_gpu(num_reqs), + self.block_table_strides, + self.block_table_ptrs, + self.num_blocks, + self.num_blocks.stride(0), + BLOCK_SIZE=1024, # type: ignore + ) + + def gather_block_tables( + self, + idx_mapping: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + num_reqs = idx_mapping.shape[0] + _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)]( + idx_mapping, + self.block_table_ptrs, + self.input_block_table_ptrs, + self.block_table_strides, + self.num_blocks, + self.num_blocks.stride(0), + BLOCK_SIZE=1024, # type: ignore + ) + return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) + + def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]: + return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) + + def compute_slot_mappings( + self, + query_start_loc: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + num_reqs = query_start_loc.shape[0] - 1 + num_tokens = positions.shape[0] + num_groups = self.num_kv_cache_groups + _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)]( + num_tokens, + self.max_num_batched_tokens, + query_start_loc, + positions, + self.input_block_table_ptrs, + self.block_table_strides, + self.block_sizes_tensor, + self.slot_mappings, + self.slot_mappings.stride(0), + PAD_ID=PAD_SLOT_ID, + BLOCK_SIZE=1024, # type: ignore + ) + return self.slot_mappings[:, :num_tokens] + + def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor: + self.slot_mappings.fill_(PAD_SLOT_ID) + return self.slot_mappings[:, :num_tokens] + + +@triton.jit +def _append_block_ids_kernel( + # Inputs + req_indices, # [num_reqs] + cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1] + cu_num_new_blocks_stride, + new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks] + new_block_ids_stride, + overwrite, # [num_reqs] + 
block_table_strides, # [num_kv_cache_groups] + # Outputs + block_table_ptrs, # [num_kv_cache_groups] + num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] + num_blocks_stride, + # Constants + BLOCK_SIZE: tl.constexpr, +): + group_id = tl.program_id(0) + batch_idx = tl.program_id(1) + req_idx = tl.load(req_indices + batch_idx) + do_overwrite = tl.load(overwrite + batch_idx) + + group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride + start_idx = tl.load(group_new_blocks_ptr + batch_idx) + end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1) + num_new_blocks = end_idx - start_idx + + group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride + dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0 + dst_end_idx = dst_start_idx + num_new_blocks + tl.store(group_num_blocks_ptr + req_idx, dst_end_idx) + + # Destination + block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) + block_table_stride = tl.load(block_table_strides + group_id) + row_ptr = block_table_ptr + req_idx * block_table_stride + + group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride + for i in range(0, num_new_blocks, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + block_ids = tl.load( + group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks + ) + tl.store( + row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks + ) + + +@triton.jit +def _gather_block_tables_kernel( + batch_idx_to_req_idx, # [batch_size] + src_block_table_ptrs, # [num_kv_cache_groups] + dst_block_table_ptrs, # [num_kv_cache_groups] + block_table_strides, # [num_kv_cache_groups] + num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] + num_blocks_stride, + BLOCK_SIZE: tl.constexpr, +): + # kv cache group id + group_id = tl.program_id(0) + batch_idx = tl.program_id(1) + req_idx = tl.load(batch_idx_to_req_idx + batch_idx) + + group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride + num_blocks = tl.load(group_num_blocks_ptr + req_idx) + + stride = tl.load(block_table_strides + group_id) + src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32) + src_row_ptr = src_block_table_ptr + req_idx * stride + dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32) + dst_row_ptr = dst_block_table_ptr + batch_idx * stride + + for i in tl.range(0, num_blocks, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks) + tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks) + + +@triton.jit +def _compute_slot_mappings_kernel( + num_tokens, + max_num_tokens, + cu_num_tokens, # [num_reqs + 1] + pos, # [num_tokens] + block_table_ptrs, # [num_kv_cache_groups] + block_table_strides, # [num_kv_cache_groups] + page_sizes, # [num_kv_cache_groups] + slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens] + slot_mappings_stride, + PAD_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + # kv cache group id + group_id = tl.program_id(0) + req_idx = tl.program_id(1) + slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride + + if req_idx == tl.num_programs(1) - 1: + # Pad remaining slots to -1. This is needed for CUDA graphs. 
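+        # The kernel is launched with grid (num_groups, num_reqs + 1); this
+        # extra last program only pads the tail [num_tokens, max_num_tokens).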
+ for i in range(num_tokens, max_num_tokens, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens) + return + + block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) + block_table_stride = tl.load(block_table_strides + group_id) + page_size = tl.load(page_sizes + group_id) + + start_idx = tl.load(cu_num_tokens + req_idx) + end_idx = tl.load(cu_num_tokens + req_idx + 1) + for i in range(start_idx, end_idx, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + positions = tl.load(pos + offset, mask=offset < end_idx, other=0) + block_indices = positions // page_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices + ) + slot_ids = block_numbers * page_size + positions % page_size + tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) + + +@triton.jit +def _load_ptr(ptr_to_ptr, elem_dtype): + ptr = tl.load(ptr_to_ptr) + ptr = tl.cast(ptr, tl.pointer_type(elem_dtype)) + return tl.multiple_of(ptr, 16) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py new file mode 100644 index 000000000000..7fd1f76669f4 --- /dev/null +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +from contextlib import contextmanager + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode +from vllm.distributed.parallel_state import graph_capture, is_global_first_rank +from vllm.forward_context import set_forward_context +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.attn_utils import build_attn_metadata +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.input_batch import InputBuffers + + +class CudaGraphManager: + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.device = device + + self.max_model_len = vllm_config.model_config.max_model_len + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.compilation_config = vllm_config.compilation_config + assert self.compilation_config is not None + + self.cudagraph_mode = self.compilation_config.cudagraph_mode + self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) + self.padded_sizes = self._init_padded_sizes() + + self.graphs: dict[int, torch.cuda.CUDAGraph] = {} + self.pool = torch.cuda.graph_pool_handle() + self.hidden_states: torch.Tensor | None = None + + def _init_padded_sizes(self) -> dict[int, int]: + if not self.cudagraph_mode.has_full_cudagraphs(): + # Full cuda graphs are not used. 
+ return {} + + padded_sizes: dict[int, int] = {} + assert len(self.cudagraph_sizes) > 0 + for i in range(1, self.cudagraph_sizes[-1] + 1): + for x in self.cudagraph_sizes: + if i <= x: + padded_sizes[i] = x + break + return padded_sizes + + def needs_capture(self) -> bool: + return len(self.padded_sizes) > 0 + + def get_cudagraph_size( + self, + scheduler_output: SchedulerOutput, + num_tokens_after_padding: int, + ) -> int | None: + if not self.cudagraph_mode.has_full_cudagraphs(): + return None + if self.cudagraph_mode != CUDAGraphMode.FULL: + # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding). + all_decode = all( + x == 1 for x in scheduler_output.num_scheduled_tokens.values() + ) + if not all_decode: + # Prefill is included. + return None + return self.padded_sizes.get(num_tokens_after_padding) + + def capture_graph( + self, + batch_size: int, + model: nn.Module, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + assert batch_size not in self.graphs + + # Prepare dummy inputs. + input_ids = input_buffers.input_ids.gpu[:batch_size] + positions = input_buffers.positions.gpu[:batch_size] + + input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1) + input_buffers.query_start_loc.np[batch_size:] = batch_size + input_buffers.query_start_loc.copy_to_gpu() + input_buffers.seq_lens.np[:batch_size] = self.max_model_len + input_buffers.seq_lens.np[batch_size:] = 0 + input_buffers.seq_lens.copy_to_gpu() + + input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables] + slot_mappings = block_tables.slot_mappings[:, :batch_size] + + attn_metadata = build_attn_metadata( + attn_metadata_builders=attn_metadata_builders, + num_reqs=batch_size, + num_tokens=batch_size, + query_start_loc=input_buffers.query_start_loc, + seq_lens=input_buffers.seq_lens, + num_computed_tokens_cpu=None, # FIXME + block_tables=input_block_tables, + slot_mappings=slot_mappings, + kv_cache_config=kv_cache_config, + ) + if self.dp_size > 1: + num_tokens_across_dp = torch.full( + (self.dp_size,), + batch_size, + dtype=torch.int32, + device="cpu", + ) + else: + num_tokens_across_dp = None + + # Warm up. + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=batch_size, + num_tokens_across_dp=num_tokens_across_dp, + ): + hidden_states = model( + input_ids=input_ids, + positions=positions, + ) + if self.hidden_states is None: + self.hidden_states = torch.empty_like(hidden_states) + torch.cuda.synchronize() + + # Capture the graph. + graph = torch.cuda.CUDAGraph() + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=batch_size, + num_tokens_across_dp=num_tokens_across_dp, + ), + torch.cuda.graph(graph, self.pool), + ): + hidden_states = model( + input_ids=input_ids, + positions=positions, + ) + self.hidden_states[:batch_size] = hidden_states + self.graphs[batch_size] = graph + + @torch.inference_mode() + def capture( + self, + model: nn.Module, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + assert self.needs_capture() + # Capture larger graphs first. 
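+        # Capturing in descending size order lets the smaller graphs reuse
+        # memory from the shared graph pool allocated for the larger ones.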
+ sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True) + if is_global_first_rank(): + sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs") + + with freeze_gc(), graph_capture(device=self.device): + for batch_size in sizes_to_capture: + self.capture_graph( + batch_size, + model, + input_buffers, + block_tables, + attn_metadata_builders, + kv_cache_config, + ) + + def run(self, batch_size: int) -> torch.Tensor: + assert batch_size in self.graphs + self.graphs[batch_size].replay() + assert self.hidden_states is not None + return self.hidden_states[:batch_size] + + +@contextmanager +def freeze_gc(): + gc.collect() + gc.freeze() + try: + yield + finally: + gc.unfreeze() diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py new file mode 100644 index 000000000000..9bfc7f25bef3 --- /dev/null +++ b/vllm/v1/worker/gpu/dp_utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +import torch.distributed as dist + +from vllm.distributed.parallel_state import get_dp_group + + +def get_batch_metadata_across_dp( + num_tokens: int, + cudagraph_size: int, + dp_size: int, + dp_rank: int, +) -> tuple[torch.Tensor, torch.Tensor]: + assert dp_size > 1 + # Use CPU group to avoid CPU-GPU synchronization. + group = get_dp_group().cpu_group + tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu") + tensor[0][dp_rank] = num_tokens + tensor[1][dp_rank] = cudagraph_size + dist.all_reduce(tensor, group=group) + return tensor[0], tensor[1] diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py new file mode 100644 index 000000000000..89f375649146 --- /dev/null +++ b/vllm/v1/worker/gpu/input_batch.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any + +import numba +import numba.types as types +import numpy as np +import torch +import triton +import triton.language as tl + +from vllm.utils import random_uuid +from vllm.utils.math_utils import cdiv +from vllm.v1.utils import CpuGpuBuffer + + +class InputBuffers: + def __init__( + self, + max_num_reqs: int, + max_num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_num_tokens = max_num_tokens + self.device = device + self.pin_memory = pin_memory + + self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) + self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) + self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32) + + # Structured outputs. 
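+        # grammar_bitmask packs one bit per vocab token into int32 words,
+        # hence cdiv(vocab_size, 32) columns per request.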
+ self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32) + self.grammar_bitmask = self._make_buffer( + max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32 + ) + + def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer( + *args, dtype=dtype, pin_memory=self.pin_memory, device=self.device + ) + + +@dataclass +class InputBatch: + # batch_idx -> req_id + req_ids: list[str] + num_reqs: int + + # batch_idx -> req_state_idx + idx_mapping: torch.Tensor + idx_mapping_np: np.ndarray + + # [num_reqs] + # batch_idx -> num_scheduled_tokens + num_scheduled_tokens: np.ndarray + # sum(num_scheduled_tokens) + num_tokens: int + num_tokens_after_padding: int + + # [num_reqs + 1] + query_start_loc: torch.Tensor + query_start_loc_np: np.ndarray + # [num_reqs] + seq_lens: torch.Tensor + seq_lens_np: np.ndarray + + # [num_tokens_after_padding] + input_ids: torch.Tensor + # [num_tokens_after_padding] + positions: torch.Tensor + + # layer_name -> Metadata + attn_metadata: dict[str, Any] + + # [num_reqs] + logits_indices: torch.Tensor + + @classmethod + def make_dummy( + cls, + num_reqs: int, + num_tokens: int, + input_buffers: InputBuffers, + device: torch.device, + ) -> "InputBatch": + assert 0 < num_reqs <= num_tokens + req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] + idx_mapping_np = np.arange(num_reqs, dtype=np.int32) + idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) + num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) + num_scheduled_tokens[-1] += num_tokens % num_reqs + assert int(num_scheduled_tokens.sum()) == num_tokens + + input_buffers.query_start_loc.np[0] = 0 + input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum( + num_scheduled_tokens + ) + input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens + query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1] + query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1] + # seq_len equals to query_len + input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens + input_buffers.seq_lens.np[num_reqs:] = 0 + seq_lens_np = input_buffers.seq_lens.np[:num_reqs] + seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs] + + input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens) + positions = input_buffers.positions.copy_to_gpu(num_tokens) + # attn_metadata = defaultdict(lambda: None) + logits_indices = query_start_loc[1:] - 1 + return cls( + req_ids=req_ids, + num_reqs=num_reqs, + idx_mapping=idx_mapping, + idx_mapping_np=idx_mapping_np, + num_scheduled_tokens=num_scheduled_tokens, + num_tokens=num_tokens, + num_tokens_after_padding=num_tokens, + query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, + seq_lens=seq_lens, + seq_lens_np=seq_lens_np, + input_ids=input_ids, + positions=positions, + attn_metadata=None, # type: ignore + logits_indices=logits_indices, + ) + + +# NOTE: With the type annotations, this function is pre-compiled +# before the first call. 
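+# The explicit signature must match the dtypes of the buffers passed in from
+# prepare_inputs: int32 token/length arrays and int64 positions.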
+@numba.jit( + [ + types.none( + types.int32[:], # idx_mapping + types.int32[:, :], # token_ids + types.int32[:], # num_computed_tokens + types.int32[:], # num_scheduled_tokens + types.int32[:], # input_ids + types.int64[:], # positions + types.int32[:], # query_start_loc + types.int32[:], # seq_lens + ) + ], + nopython=True, + cache=True, +) +def _prepare_inputs( + idx_mapping: np.ndarray, # batch_idx -> req_idx + token_ids: np.ndarray, # [N, max_model_len] + num_computed_tokens: np.ndarray, # [N] + num_scheduled_tokens: np.ndarray, # [B] + input_ids: np.ndarray, # [num_input_tokens] + positions: np.ndarray, # [num_input_tokens] + query_start_loc: np.ndarray, # [B + 1] + seq_lens: np.ndarray, # [B] +) -> None: + num_reqs = num_scheduled_tokens.shape[0] + query_start_loc[0] = 0 + + cu_num_tokens = 0 + for i in range(num_reqs): + req_idx = idx_mapping[i] + query_len = num_scheduled_tokens[i] + start = num_computed_tokens[req_idx] + end = start + query_len + seq_lens[i] = end + + start_idx = cu_num_tokens + end_idx = start_idx + query_len + input_ids[start_idx:end_idx] = token_ids[req_idx, start:end] + positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64) + + cu_num_tokens = end_idx + query_start_loc[i + 1] = cu_num_tokens + + # Pad the inputs for CUDA graphs. + # Note: pad query_start_loc to be non-decreasing, as kernels + # like FlashAttention requires that + query_start_loc[num_reqs + 1 :].fill(cu_num_tokens) + # Fill unused with 0 for full cuda graph mode. + seq_lens[num_reqs:].fill(0) + + +def prepare_inputs( + idx_mapping: np.ndarray, + prefill_token_ids: np.ndarray, + num_computed_tokens: np.ndarray, + num_scheduled_tokens: np.ndarray, + input_ids: CpuGpuBuffer, + positions: CpuGpuBuffer, + query_start_loc: CpuGpuBuffer, + seq_lens: CpuGpuBuffer, + num_tokens: int, +) -> None: + _prepare_inputs( + idx_mapping, + prefill_token_ids, + num_computed_tokens, + num_scheduled_tokens, + input_ids.np, + positions.np, + query_start_loc.np, + seq_lens.np, + ) + input_ids.copy_to_gpu(num_tokens) + positions.copy_to_gpu(num_tokens) + # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens + # tensors from CPU to GPU, because they may include paddings needed + # for full CUDA graph mode. + query_start_loc.copy_to_gpu() + seq_lens.copy_to_gpu() + return + + +@triton.jit +def _combine_last_token_ids_kernel( + input_ids_ptr, + idx_mapping_ptr, + last_token_ids_ptr, + query_start_loc_ptr, + seq_lens_ptr, + prefill_len_ptr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + seq_len = tl.load(seq_lens_ptr + batch_idx) + prefill_len = tl.load(prefill_len_ptr + req_state_idx) + if seq_len <= prefill_len: + # Handling prefill tokens. 
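+        # All scheduled tokens for this request are prefill tokens already
+        # written by prepare_inputs, so there is no sampled token to write back.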
+ return + + last_token_id = tl.load(last_token_ids_ptr + req_state_idx) + end = tl.load(query_start_loc_ptr + batch_idx + 1) + tl.store(input_ids_ptr + end - 1, last_token_id) + + +def combine_last_token_ids( + input_ids: torch.Tensor, + idx_mapping: torch.Tensor, + last_token_ids: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + prefill_len: torch.Tensor, +) -> torch.Tensor: + num_reqs = seq_lens.shape[0] + _combine_last_token_ids_kernel[(num_reqs,)]( + input_ids, + idx_mapping, + last_token_ids, + query_start_loc, + seq_lens, + prefill_len, + ) + return input_ids diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py new file mode 100644 index 000000000000..08aad9ddd06b --- /dev/null +++ b/vllm/v1/worker/gpu/model_runner.py @@ -0,0 +1,814 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +import time +from copy import deepcopy +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_loader +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + LogprobsTensors, + ModelRunnerOutput, +) +from vllm.v1.sample.sampler import SamplerOutput +from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier +from vllm.v1.worker.gpu.attn_utils import ( + build_attn_metadata, + get_kv_cache_spec, + init_attn_backend, + init_kv_cache, +) +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager +from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp +from vllm.v1.worker.gpu.input_batch import ( + InputBatch, + InputBuffers, + combine_last_token_ids, + prepare_inputs, +) +from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs +from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata +from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +logger = init_logger(__name__) + + +class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + + self.device = device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + self.kv_cache_dtype = self.dtype + if 
self.cache_config.cache_dtype != "auto": + # Quantized KV cache. + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype + ] + self.is_pooling_model = False + + self.vocab_size = self.model_config.get_vocab_size() + self.max_model_len = self.model_config.max_model_len + self.max_num_tokens = self.scheduler_config.max_num_batched_tokens + self.max_num_reqs = self.scheduler_config.max_num_seqs + self.hidden_size = self.model_config.get_hidden_size() + + self.dp_size = self.parallel_config.data_parallel_size + self.dp_rank = self.parallel_config.data_parallel_rank + + self.use_async_scheduling = self.scheduler_config.async_scheduling + self.output_copy_stream = torch.cuda.Stream(self.device) + self.output_copy_event = torch.cuda.Event() + if self.use_async_scheduling: + self.input_prep_event = torch.cuda.Event() + self.structured_outputs_event = torch.cuda.Event() + else: + self.input_prep_event = None + self.structured_outputs_event = None + + self.req_states = RequestState( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + vocab_size=self.vocab_size, + device=self.device, + pin_memory=self.pin_memory, + ) + self.input_buffers = InputBuffers( + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + hidden_size=self.hidden_size, + vocab_size=self.vocab_size, + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ) + self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) + + # CUDA graphs. + self.cudagraph_manager = CudaGraphManager( + vllm_config=self.vllm_config, + device=self.device, + ) + + def get_supported_tasks(self) -> tuple[str]: + return ("generate",) + + def load_model(self, *args, **kwargs) -> None: + time_before_load = time.perf_counter() + with DeviceMemoryProfiler() as m: + model_loader = get_model_loader(self.vllm_config.load_config) + logger.info("Loading model from scratch...") + + self.model = model_loader.load_model( + vllm_config=self.vllm_config, + model_config=self.vllm_config.model_config, + ) + if self.lora_config: + self.model = self.load_lora_model( + self.model, + self.vllm_config, + self.device, + ) + time_after_load = time.perf_counter() + + self.model_memory_usage = m.consumed_memory + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + m.consumed_memory / GiB_bytes, + time_after_load - time_before_load, + ) + + def get_model(self) -> nn.Module: + return self.model + + def get_kv_cache_spec(self): + return get_kv_cache_spec(self.vllm_config) + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + kv_cache_config = deepcopy(kv_cache_config) + self.kv_cache_config = kv_cache_config + block_sizes = [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in kv_cache_config.kv_cache_groups + ] + + self.block_tables = BlockTables( + block_sizes=block_sizes, + max_num_reqs=self.max_num_reqs, + max_num_batched_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + pin_memory=self.pin_memory, + ) + + self.attn_backends, self.attn_metadata_builders = init_attn_backend( + self.kv_cache_config, + self.vllm_config, + self.device, + ) + + self.kv_caches: list[torch.Tensor] = [] + init_kv_cache( + self.kv_caches, + self.compilation_config.static_forward_context, + self.kv_cache_config, + self.attn_backends, + self.device, + ) + # Attention groups are not supported. 
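+        # self.attn_groups is kept as an empty list, presumably so code paths
+        # that expect this attribute on the model runner keep working.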
+ self.attn_groups = [] # type: ignore + + def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None: + block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs) + slot_mappings = self.block_tables.get_dummy_slot_mappings( + input_batch.num_tokens + ) + num_computed_tokens_cpu = torch.zeros( + input_batch.num_reqs, dtype=torch.int32, device="cpu" + ) + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=input_batch.num_reqs, + num_tokens=input_batch.num_tokens, + query_start_loc=self.input_buffers.query_start_loc, + seq_lens=self.input_buffers.seq_lens, + num_computed_tokens_cpu=num_computed_tokens_cpu, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + input_batch.attn_metadata = attn_metadata + + @torch.inference_mode() + def _dummy_run( + self, + num_tokens: int, + *args, + skip_attn: bool = True, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_reqs = min(num_tokens, self.max_num_reqs) + input_batch = InputBatch.make_dummy( + num_reqs=num_reqs, + num_tokens=num_tokens, + input_buffers=self.input_buffers, + device=self.device, + ) + if not skip_attn: + self.prepare_dummy_attn_metadata(input_batch) + + if self.dp_size == 1: + num_tokens_across_dp: torch.Tensor | None = None + else: + num_tokens_across_dp = torch.full( + (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu" + ) + num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32) + with ( + self.maybe_dummy_run_with_lora( + self.lora_config, + input_batch.num_scheduled_tokens, + num_sampled_tokens, + ), + set_forward_context( + input_batch.attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ), + ): + hidden_states = self.model( + input_ids=input_batch.input_ids, + positions=input_batch.positions, + ) + sample_hidden_states = hidden_states[input_batch.logits_indices] + return hidden_states, sample_hidden_states + + @torch.inference_mode() + def _dummy_sampler_run( + self, + hidden_states: torch.Tensor, + ) -> None: + num_reqs = hidden_states.shape[0] + sampling_metadata = SamplingMetadata.make_dummy( + num_reqs=num_reqs, + device=self.device, + ) + logits = self.model.compute_logits(hidden_states) + self.sampler(logits, sampling_metadata) + + @torch.inference_mode() + def profile_run(self) -> None: + hidden_states, sample_hidden_states = self._dummy_run( + self.max_num_tokens, + skip_attn=True, + ) + self._dummy_sampler_run(sample_hidden_states) + torch.cuda.synchronize() + del hidden_states, sample_hidden_states + gc.collect() + + def reset_mm_cache(self) -> None: + pass + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + # SP is not supported yet. + return num_scheduled_tokens + + @torch.inference_mode() + def capture_model(self) -> int: + if not self.cudagraph_manager.needs_capture(): + logger.warning( + "Skipping CUDA graph capture. 
To turn on CUDA graph capture, " + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) + return 0 + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + with self.maybe_setup_dummy_loras(self.lora_config): + self.cudagraph_manager.capture( + model=self.model, + input_buffers=self.input_buffers, + block_tables=self.block_tables, + attn_metadata_builders=self.attn_metadata_builders, + kv_cache_config=self.kv_cache_config, + ) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) + return cuda_graph_size + + def warmup_for_prefill(self) -> None: + # For FlashInfer, we would like to execute a dummy prefill run + # to trigger JIT compilation. + if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()): + self._dummy_run(self.max_num_tokens, skip_attn=False) + torch.cuda.synchronize() + + def update_states(self, scheduler_output: SchedulerOutput) -> None: + for req_id in scheduler_output.preempted_req_ids: + self.req_states.remove_request(req_id) + for req_id in scheduler_output.finished_req_ids: + self.req_states.remove_request(req_id) + + # TODO(woosuk): Change SchedulerOutput. + req_indices: list[int] = [] + cu_num_new_blocks = tuple( + [0] for _ in range(self.block_tables.num_kv_cache_groups) + ) + new_block_ids: tuple[list[int], ...] = tuple( + [] for _ in range(self.block_tables.num_kv_cache_groups) + ) + overwrite: list[bool] = [] + + # Add new requests. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + self.req_states.add_request( + req_id=req_id, + prompt_len=len(new_req_data.prompt_token_ids), + prefill_token_ids=new_req_data.prefill_token_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + sampling_params=new_req_data.sampling_params, + lora_request=new_req_data.lora_request, + ) + + req_index = self.req_states.req_id_to_index[req_id] + req_indices.append(req_index) + for i, block_ids in enumerate(new_req_data.block_ids): + x = cu_num_new_blocks[i][-1] + cu_num_new_blocks[i].append(x + len(block_ids)) + new_block_ids[i].extend(block_ids) + overwrite.append(True) + + # Add new blocks for the existing requests. + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + req_index = self.req_states.req_id_to_index[req_id] + + req_new_block_ids = cached_reqs.new_block_ids[i] + if req_new_block_ids is not None: + req_indices.append(req_index) + for group_id, block_ids in enumerate(req_new_block_ids): + x = cu_num_new_blocks[group_id][-1] + cu_num_new_blocks[group_id].append(x + len(block_ids)) + new_block_ids[group_id].extend(block_ids) + overwrite.append(False) + + if req_indices: + self.block_tables.append_block_ids( + req_indices=req_indices, + cu_num_new_blocks=cu_num_new_blocks, + new_block_ids=new_block_ids, + overwrite=overwrite, + ) + + def prepare_inputs( + self, + scheduler_output: SchedulerOutput, + num_tokens_after_padding: int, + ) -> InputBatch: + num_tokens = scheduler_output.total_num_scheduled_tokens + assert num_tokens > 0 + num_reqs = len(scheduler_output.num_scheduled_tokens) + + # Decode first, then prefill. 
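+        # Sorting by ascending num_scheduled_tokens places 1-token decode
+        # requests at the front of the batch and prefill chunks at the end.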
+ # batch_idx -> req_id + req_ids = sorted( + scheduler_output.num_scheduled_tokens, + key=scheduler_output.num_scheduled_tokens.get, + ) + num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32 + ) + + idx_mapping_list = [ + self.req_states.req_id_to_index[req_id] for req_id in req_ids + ] + idx_mapping = self.input_buffers.idx_mapping + idx_mapping.np[:num_reqs] = idx_mapping_list + idx_mapping_np = idx_mapping.np[:num_reqs] + idx_mapping = idx_mapping.copy_to_gpu(num_reqs) + + # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] + block_tables = self.block_tables.gather_block_tables(idx_mapping) + + prepare_inputs( + idx_mapping_np, + self.req_states.prefill_token_ids, + self.req_states.num_computed_tokens, + num_scheduled_tokens, + self.input_buffers.input_ids, + self.input_buffers.positions, + self.input_buffers.query_start_loc, + self.input_buffers.seq_lens, + num_tokens, + ) + + query_start_loc = self.input_buffers.query_start_loc + query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] + query_start_loc_np = query_start_loc.np[: num_reqs + 1] + seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs] + seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs] + + # Some input token ids are directly read from the last sampled tokens. + combine_last_token_ids( + self.input_buffers.input_ids.gpu, + idx_mapping, + self.req_states.last_sampled_tokens, + query_start_loc_gpu, + seq_lens_gpu, + self.req_states.prefill_len.copy_to_gpu(), + ) + + # Compute slot mappings: [num_kv_cache_groups, num_tokens] + slot_mappings = self.block_tables.compute_slot_mappings( + query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens] + ) + + num_computed_tokens_cpu = torch.from_numpy( + self.req_states.num_computed_tokens[idx_mapping_np] + ) + + # Logits indices to sample next token from. + logits_indices = query_start_loc_gpu[1:] - 1 + + # Layer name -> attention metadata. + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_tokens, + query_start_loc=self.input_buffers.query_start_loc, + seq_lens=self.input_buffers.seq_lens, + num_computed_tokens_cpu=num_computed_tokens_cpu, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + + input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] + positions = self.input_buffers.positions.gpu[:num_tokens_after_padding] + return InputBatch( + req_ids=req_ids, + num_reqs=num_reqs, + idx_mapping=idx_mapping, + idx_mapping_np=idx_mapping_np, + num_scheduled_tokens=num_scheduled_tokens, + num_tokens=num_tokens, + num_tokens_after_padding=num_tokens_after_padding, + query_start_loc=query_start_loc_gpu, + query_start_loc_np=query_start_loc_np, + seq_lens=seq_lens_gpu, + seq_lens_np=seq_lens_np, + input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata, + logits_indices=logits_indices, + ) + + def sample( + self, + hidden_states: torch.Tensor, + input_batch: InputBatch, + sampling_metadata: SamplingMetadata, + grammar_output: GrammarOutput | None, + ) -> SamplerOutput: + sample_hidden_states = hidden_states[input_batch.logits_indices] + logits = self.model.compute_logits(sample_hidden_states) + if grammar_output is not None: + # Apply grammar bitmask to the logits in-place. 
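+            # async_barrier waits on the event recorded in the previous step
+            # before these buffers are reused, then records it again; the event
+            # is None (a no-op) when async scheduling is disabled.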
+ with async_barrier(self.structured_outputs_event): + apply_grammar_bitmask( + logits, + input_batch.req_ids, + grammar_output.structured_output_request_ids, + grammar_output.grammar_bitmask, + self.input_buffers, + ) + sampler_output = self.sampler(logits, sampling_metadata) + return sampler_output + + def compute_prompt_logprobs( + self, + hidden_states: torch.Tensor, + input_batch: InputBatch, + ) -> dict[str, LogprobsTensors]: + idx_mapping_np = input_batch.idx_mapping_np + needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np] + if not np.any(needs_prompt_logprobs): + # No request asks for prompt logprobs. + return {} + + num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np] + prompt_lens = self.req_states.prompt_len[idx_mapping_np] + # NOTE(woosuk): -1 because the last prompt token's hidden state is not + # needed for prompt logprobs. + includes_prompt = num_computed_tokens < prompt_lens - 1 + # NOTE(woosuk): If the request was resumed after preemption, its prompt + # logprobs must have been computed before preemption. Skip. + resumed_after_prompt = ( + prompt_lens < self.req_states.prefill_len.np[idx_mapping_np] + ) + needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt + if not np.any(needs_prompt_logprobs): + return {} + + # Just to be safe, clone the input ids. + n = input_batch.num_tokens + # Shift the input ids by one. + token_ids = torch.empty_like(input_batch.input_ids[:n]) + token_ids[: n - 1] = input_batch.input_ids[1:n] + # To avoid out-of-bound access, set the last token id to 0. + token_ids[n - 1] = 0 + + # Handle chunked prompts. + seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs] + is_prompt_chunked = seq_lens < prompt_lens + prefill_token_ids = self.req_states.prefill_token_ids + query_start_loc = self.input_buffers.query_start_loc.np + for i, req_id in enumerate(input_batch.req_ids): + if not needs_prompt_logprobs[i]: + continue + if not is_prompt_chunked[i]: + continue + # The prompt is chunked. Get the next prompt token. + req_idx = input_batch.idx_mapping_np[i] + next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]]) + idx = int(query_start_loc[i + 1] - 1) + # Set the next prompt token. + # NOTE(woosuk): This triggers a GPU operation. + token_ids[idx] = next_prompt_token + + # NOTE(woosuk): We mask out logprobs for negative tokens. + prompt_logprobs, prompt_ranks = compute_prompt_logprobs( + token_ids, + hidden_states[:n], + self.model.compute_logits, + ) + + prompt_token_ids = token_ids.unsqueeze(-1) + prompt_logprobs_dict: dict[str, LogprobsTensors] = {} + for i, req_id in enumerate(input_batch.req_ids): + if not needs_prompt_logprobs[i]: + continue + + start_idx = query_start_loc[i] + end_idx = query_start_loc[i + 1] + assert start_idx < end_idx, ( + f"start_idx ({start_idx}) >= end_idx ({end_idx})" + ) + logprobs = LogprobsTensors( + logprob_token_ids=prompt_token_ids[start_idx:end_idx], + logprobs=prompt_logprobs[start_idx:end_idx], + selected_token_ranks=prompt_ranks[start_idx:end_idx], + ) + + req_extra_data = self.req_states.extra_data[req_id] + prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs + if is_prompt_chunked[i]: + # Prompt is chunked. Do not return the logprobs yet. + prompt_logprobs_list.append(logprobs) + continue + + if prompt_logprobs_list: + # Merge the in-progress logprobs. 
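+                # The prompt was processed over multiple chunked-prefill
+                # steps; earlier chunks were stashed in
+                # in_progress_prompt_logprobs, so concatenate them with this
+                # final chunk to cover the full prompt.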
+ prompt_logprobs_list.append(logprobs) + logprobs = LogprobsTensors( + logprob_token_ids=torch.cat( + [x.logprob_token_ids for x in prompt_logprobs_list] + ), + logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]), + selected_token_ranks=torch.cat( + [x.selected_token_ranks for x in prompt_logprobs_list] + ), + ) + prompt_logprobs_list.clear() + + prompt_logprobs_dict[req_id] = logprobs + return prompt_logprobs_dict + + def postprocess( + self, + sampler_output: SamplerOutput, + prompt_logprobs_dict: dict[str, LogprobsTensors], + input_batch: InputBatch, + ) -> AsyncOutput | ModelRunnerOutput: + # Store the last sampled token ids. + self.req_states.last_sampled_tokens[input_batch.idx_mapping] = ( + sampler_output.sampled_token_ids + ) + # Get the number of sampled tokens. + # 0 if chunked-prefilling, 1 if not. + idx_mapping_np = input_batch.idx_mapping_np + is_chunked_prefilling = ( + input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np] + ) + num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) + # Increment the number of tokens. + self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens + # Increment the number of computed tokens. + self.req_states.num_computed_tokens[idx_mapping_np] += ( + input_batch.num_scheduled_tokens + ) + + model_runner_output = ModelRunnerOutput( + req_ids=input_batch.req_ids, + req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)}, + sampled_token_ids=None, + logprobs=None, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=None, + num_nans_in_logits=None, + ) + async_output = AsyncOutput( + model_runner_output=model_runner_output, + sampler_output=sampler_output, + num_sampled_tokens=num_sampled_tokens, + copy_stream=self.output_copy_stream, + copy_event=self.output_copy_event, + ) + if self.use_async_scheduling: + return async_output + return async_output.get_output() + + def get_cudagraph_and_dp_padding( + self, + scheduler_output: SchedulerOutput, + ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if self.dp_size == 1: + # No DP. Only consider CUDA graphs. + if total_num_scheduled_tokens == 0: + # Special case: no tokens to run. + return CUDAGraphMode.NONE, 0, None + + cudagraph_size = self.cudagraph_manager.get_cudagraph_size( + scheduler_output, total_num_scheduled_tokens + ) + if cudagraph_size is not None: + # Use full CUDA graph. + return CUDAGraphMode.FULL, cudagraph_size, None + # Fall back to eager mode. + # TODO(woosuk): Support piecewise CUDA graphs. + return CUDAGraphMode.NONE, total_num_scheduled_tokens, None + + # Consider DP padding and CUDA graph. + if total_num_scheduled_tokens == 0: + # Special handling is needed for 0. + cudagraph_size_before_dp: int | None = 0 + else: + cudagraph_size_before_dp = self.cudagraph_manager.get_cudagraph_size( + scheduler_output, total_num_scheduled_tokens + ) + if cudagraph_size_before_dp is None: + cudagraph_size_before_dp = -1 + + assert cudagraph_size_before_dp is not None + num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp( + total_num_scheduled_tokens, + cudagraph_size_before_dp, + self.dp_size, + self.dp_rank, + ) + if all(cudagraph_size_across_dp >= 0): + # If all ranks can use CUDA graph, pad to the maximum number of tokens + # across DP and use CUDA graph. 
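+            # Illustrative example (assuming 4 and 128 are captured graph
+            # sizes): with dp_size=2, rank 0 scheduling 3 tokens (padded to 4)
+            # and rank 1 scheduling 120 tokens (padded to 128), both ranks
+            # run the 128-token graph so the batch sizes match across DP.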
+ num_tokens_after_padding = int(cudagraph_size_across_dp.max().item()) + cudagraph_mode = CUDAGraphMode.FULL + else: + # If any of the ranks cannot use CUDA graph, use eager mode for all ranks. + # No padding is needed except for ranks that have no tokens to run. + num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1) + num_tokens_after_padding = num_tokens_across_dp[self.dp_rank] + cudagraph_mode = CUDAGraphMode.NONE + return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + intermediate_tensors: Any | None = None, + dummy_run: bool = False, + ) -> ModelRunnerOutput | None: + assert intermediate_tensors is None + if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run: + # No need to run the model. + with async_barrier(self.input_prep_event): + self.update_states(scheduler_output) + return EMPTY_MODEL_RUNNER_OUTPUT + + # NOTE: Call this before the async barrier so CPU all-reduce and + # GPU execution can overlap. + cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = ( + self.get_cudagraph_and_dp_padding(scheduler_output) + ) + with async_barrier(self.input_prep_event): + self.update_states(scheduler_output) + if num_tokens_after_padding == 0: + # All DP ranks have zero tokens to run. + return EMPTY_MODEL_RUNNER_OUTPUT + + if not dummy_run: + # Common case. + # Prepare all the inputs and copy to the input buffers. + input_batch = self.prepare_inputs( + scheduler_output, + num_tokens_after_padding, + ) + + # NOTE(woosuk): Sampling metadata should be built under the async + # barrier to avoid race conditions. + pos = input_batch.positions[input_batch.logits_indices] + sampling_metadata = self.req_states.make_sampling_metadata( + input_batch.idx_mapping_np, pos + ) + + if self.lora_config: + # Activate LoRA adapters. + lora_inputs = self.req_states.make_lora_inputs( + input_batch.req_ids, + input_batch.idx_mapping_np, + input_batch.num_scheduled_tokens, + ) + self._set_active_loras(*lora_inputs) + else: + # No actual tokens to run. A dummy run for DP. + num_reqs = min(num_tokens_after_padding, self.max_num_reqs) + input_batch = InputBatch.make_dummy( + num_reqs=num_reqs, + num_tokens=num_tokens_after_padding, + input_buffers=self.input_buffers, + device=self.device, + ) + self.prepare_dummy_attn_metadata(input_batch) + sampling_metadata = None + + # Run model. + if cudagraph_mode == CUDAGraphMode.FULL: + # Run CUDA graph. + # NOTE(woosuk): Here, we don't need to pass the input tensors, + # because they are already copied to the CUDA graph input buffers. + hidden_states = self.cudagraph_manager.run( + input_batch.num_tokens_after_padding + ) + else: + # Run PyTorch model in eager mode. 
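+            # Eager execution still runs under set_forward_context so that
+            # the attention layers can read the per-layer metadata and the
+            # DP token counts during the forward pass.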
+ with set_forward_context( + input_batch.attn_metadata, + self.vllm_config, + num_tokens=input_batch.num_tokens_after_padding, + cudagraph_runtime_mode=cudagraph_mode, + num_tokens_across_dp=num_tokens_across_dp, + ): + hidden_states = self.model( + input_ids=input_batch.input_ids, + positions=input_batch.positions, + ) + + self.execute_model_state = hidden_states, input_batch, sampling_metadata + return None + + @torch.inference_mode() + def sample_tokens( + self, + grammar_output: GrammarOutput | None, + ) -> AsyncOutput | ModelRunnerOutput: + assert self.execute_model_state is not None + hidden_states, input_batch, sampling_metadata = self.execute_model_state + self.execute_model_state = None # type: ignore + assert sampling_metadata is not None + + sampler_output = self.sample( + hidden_states, input_batch, sampling_metadata, grammar_output + ) + prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch) + output = self.postprocess( + sampler_output, + prompt_logprobs_dict, + input_batch, + ) + return output diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py new file mode 100644 index 000000000000..e916aadb6b5a --- /dev/null +++ b/vllm/v1/worker/gpu/sampler.py @@ -0,0 +1,327 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch +import triton +import triton.language as tl + +from vllm.config.model import LogprobsMode +from vllm.v1.outputs import LogprobsTensors, SamplerOutput +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p + + +class Sampler: + def __init__( + self, + logprobs_mode: LogprobsMode = "raw_logprobs", + ): + if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: + raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") + self.logprobs_mode = logprobs_mode + + def __call__( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == "processed_logprobs": + sampled, logits = self.sample( + logits, sampling_metadata, return_logits=True + ) + else: + assert self.logprobs_mode == "raw_logprobs" + sampled, _ = self.sample(logits, sampling_metadata, return_logits=False) + + logprobs_tensors = compute_topk_logprobs( + logits, + sampling_metadata.max_num_logprobs, + sampled, + ) + else: + sampled, _ = self.sample(logits, sampling_metadata, return_logits=False) + logprobs_tensors = None + + # These are GPU tensors. + sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. 
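+            # postprocess() writes this tensor directly into
+            # req_states.last_sampled_tokens, which is kept with the same
+            # [num_reqs, 1] layout.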
+ sampled_token_ids=sampled.view(-1, 1), + logprobs_tensors=logprobs_tensors, + ) + return sampler_output + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + is_greedy = sampling_metadata.temperature == 0 + temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits = logits / temp.view(-1, 1) + logits = apply_top_k_top_p( + logits, sampling_metadata.top_k, sampling_metadata.top_p + ) + + sampled = gumbel_sample( + logits, + is_greedy, + sampling_metadata.seeds, + sampling_metadata.pos, + ) + return sampled, logits if return_logits else None + + +@triton.jit +def _gumbel_sample_kernel( + sampled_ptr, + logits_ptr, + logits_stride, + seeds_ptr, + pos_ptr, + is_greedy_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + + if is_greedy: + # Greedy sampling. Don't apply gumbel noise. + max_val = float("-inf") + max_idx = 0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, + mask=mask, + other=float("-inf"), + ) + + idx = tl.argmax(logits, axis=0) + value = tl.max(logits, axis=0) + is_greater = value > max_val + max_val = tl.where(is_greater, value, max_val) + max_idx = tl.where(is_greater, i + idx, max_idx) + tl.store(sampled_ptr + req_idx, max_idx) + return + + # Random sampling. + # Calculate gumbel seed. + seed = tl.load(seeds_ptr + req_idx) + pos = tl.load(pos_ptr + req_idx) + gumbel_seed = tl.randint(seed, pos) + + max_val = float("-inf") + max_idx = 0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + + # Generate gumbel noise. + r = tl.rand(gumbel_seed, block).to(tl.float64) + gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) + gumbel_noise = gumbel_noise.to(tl.float32) + + # Apply gumbel noise. + logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask) + logits = tl.where(mask, logits + gumbel_noise, float("-inf")) + + # Argmax to get the sampled token. + idx = tl.argmax(logits, axis=0) + value = tl.max(logits, axis=0) + is_greater = value > max_val + max_val = tl.where(is_greater, value, max_val) + max_idx = tl.where(is_greater, i + idx, max_idx) + tl.store(sampled_ptr + req_idx, max_idx) + + +def gumbel_sample( + logits: torch.Tensor, # [num_reqs, vocab_size] + is_greedy: torch.Tensor, # [num_reqs] + seed: torch.Tensor, # [num_reqs] + pos: torch.Tensor, # [num_reqs] +) -> torch.Tensor: + num_reqs, vocab_size = logits.shape + # NOTE(woosuk): Use int64 for later indexing. 
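+    # The kernel implements the Gumbel-max trick: argmax(logits + g), with
+    # g ~ Gumbel(0, 1), is a sample from softmax(logits). The logits here are
+    # already temperature-scaled and top-k/top-p filtered by the caller, and
+    # the noise is seeded per (request seed, position) so seeded requests
+    # sample reproducibly.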
+ sampled = torch.empty( + num_reqs, + dtype=torch.int64, + device=logits.device, + ) + _gumbel_sample_kernel[(num_reqs,)]( + sampled, + logits, + logits.stride(0), + seed, + pos, + is_greedy, + vocab_size, + num_warps=8, + BLOCK_SIZE=16384, # type: ignore + ) + return sampled + + +@triton.jit +def _topk_log_softmax_kernel( + output_ptr, + logits_ptr, + logits_stride, + topk_ids_ptr, + topk, + vocab_size, + BLOCK_SIZE: tl.constexpr, + PADDED_TOPK: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = logits_ptr + req_idx * logits_stride + + max_val = float("-inf") + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) + max_val = tl.max(tl.maximum(logits, max_val)) + max_val = max_val.to(tl.float32) # type: ignore + + se = 0.0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) + # NOTE(woosuk): Make sure that logits and all following operations use FP32. + logits = logits.to(tl.float32) + e = tl.exp(logits - max_val) + e = tl.where(block < vocab_size, e, 0.0) + se += tl.sum(e) + lse = tl.log(se) + + k_offset = tl.arange(0, PADDED_TOPK) + k_mask = k_offset < topk + topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0) + + logits = tl.load(row_ptr + topk_ids, mask=k_mask) + logits = logits.to(tl.float32) + o = logits - max_val - lse + tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) + + +@triton.jit +def _ranks_kernel( + output_ptr, + logits_ptr, + logits_stride, + token_ids_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = logits_ptr + req_idx * logits_stride + + token_id = tl.load(token_ids_ptr + req_idx) + x = tl.load(row_ptr + token_id) + + n = 0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) + n += tl.sum((logits > x).to(tl.int32)) + tl.store(output_ptr + req_idx, n) + + +def compute_token_logprobs( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: + batch_size = logits.shape[0] + vocab_size = logits.shape[1] + token_ids = token_ids.to(torch.int64) + num_logprobs = token_ids.shape[1] + logprobs = torch.empty( + batch_size, + num_logprobs, + dtype=torch.float32, + device=logits.device, + ) + _topk_log_softmax_kernel[(batch_size,)]( + logprobs, + logits, + logits.stride(0), + token_ids, + num_logprobs, + vocab_size, + BLOCK_SIZE=1024, # type: ignore + PADDED_TOPK=triton.next_power_of_2(num_logprobs), + ) + return logprobs + + +def compute_topk_logprobs( + logits: torch.Tensor, + num_logprobs: int, + sampled_token_ids: torch.Tensor, +) -> LogprobsTensors: + assert num_logprobs >= 0 + batch_size, vocab_size = logits.shape + if num_logprobs == 0: + logprob_token_ids = sampled_token_ids.unsqueeze(-1) + else: + topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices + logprob_token_ids = torch.cat( + (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1 + ) + + # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full + # logprobs tensor. Instead, we only compute and return the logprobs of + # the topk + 1 tokens. 
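+    # For example, with num_logprobs=5, logprob_token_ids and logprobs both
+    # have shape [batch_size, 6]: column 0 is the sampled token and columns
+    # 1-5 are the top-5 tokens by logit.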
+ logprobs = compute_token_logprobs(logits, logprob_token_ids) + token_ranks = torch.empty( + batch_size, + dtype=torch.int64, + device=logits.device, + ) + _ranks_kernel[(batch_size,)]( + token_ranks, + logits, + logits.stride(0), + sampled_token_ids, + vocab_size, + BLOCK_SIZE=8192, # type: ignore + ) + return LogprobsTensors( + logprob_token_ids=logprob_token_ids, + logprobs=logprobs, + selected_token_ranks=token_ranks, + ) + + +def compute_prompt_logprobs( + prompt_token_ids: torch.Tensor, + prompt_hidden_states: torch.Tensor, + logits_fn: Callable[[torch.Tensor], torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + # Since materializing the full prompt logits can take too much memory, + # we compute it in chunks. + CHUNK_SIZE = 1024 + logprobs = [] + ranks = [] + prompt_token_ids = prompt_token_ids.to(torch.int64) + for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE): + end_idx = start_idx + CHUNK_SIZE + # NOTE(woosuk): logits_fn can be slow because it involves all-gather. + prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx]) + prompt_logprobs = compute_topk_logprobs( + prompt_logits, + 0, # num_logprobs + prompt_token_ids[start_idx:end_idx], + ) + logprobs.append(prompt_logprobs.logprobs) + ranks.append(prompt_logprobs.selected_token_ranks) + + logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0] + ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0] + return logprobs, ranks diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py new file mode 100644 index 000000000000..5d05c3f57790 --- /dev/null +++ b/vllm/v1/worker/gpu/states.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass, field + +import numpy as np +import torch + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.v1.outputs import LogprobsTensors +from vllm.v1.utils import CpuGpuBuffer + +_NP_INT64_MIN = np.iinfo(np.int64).min +_NP_INT64_MAX = np.iinfo(np.int64).max +NO_LORA_ID = 0 + + +@dataclass +class SamplingMetadata: + temperature: torch.Tensor + + top_p: torch.Tensor | None + top_k: torch.Tensor | None + + seeds: torch.Tensor + pos: torch.Tensor + + # None means no logprobs, 0 means sampled token logprobs only + max_num_logprobs: int | None + + @classmethod + def make_dummy( + cls, + num_reqs: int, + device: torch.device, + ) -> "SamplingMetadata": + assert num_reqs > 0 + temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) + temperature[0] = 0.5 + # TODO(woosuk): Use top-p and top-k for dummy sampler. + # Currently, they are disabled because of memory usage. 
+ # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device) + # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device) + top_p = None + top_k = None + seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) + pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) + max_num_logprobs = 20 + + return cls( + temperature=temperature, + top_p=top_p, + top_k=top_k, + seeds=seeds, + pos=pos, + max_num_logprobs=max_num_logprobs, + ) + + +class RequestState: + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + vocab_size: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.pin_memory = pin_memory + + self.req_id_to_index: dict[str, int] = {} + self.index_to_req_id: dict[int, str] = {} + self.free_indices = list(range(max_num_reqs)) + self.extra_data: dict[str, ExtraData] = {} + + self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32) + self.prefill_token_ids = np.zeros( + (self.max_num_reqs, self.max_model_len), + dtype=np.int32, + ) + self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) + self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) + + # Last sampled tokens. + self.last_sampled_tokens = torch.zeros( + self.max_num_reqs, + 1, + dtype=torch.int64, + device=device, + ) + + # LoRA. + self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) + self.lora_ids.fill(NO_LORA_ID) + + # Sampling parameters. + self.temperature = self._make_param(self.max_num_reqs, torch.float32) + self.top_p = self._make_param(self.max_num_reqs, torch.float32) + self.top_k = self._make_param(self.max_num_reqs, torch.int32) + self.seeds = self._make_param(self.max_num_reqs, torch.int64) + + self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32) + # -1 means no logprobs are requested. 
+ self.num_logprobs.fill(-1) + self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) + + def _make_param(self, size: int, dtype: torch.dtype) -> "Param": + return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory) + + def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer( + size, dtype=dtype, device=self.device, pin_memory=self.pin_memory + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + def add_request( + self, + req_id: str, + prompt_len: int, + prefill_token_ids: list[int], + num_computed_tokens: int, + sampling_params: SamplingParams, + lora_request: LoRARequest | None, + ) -> None: + assert len(self.free_indices) > 0, "No free indices" + req_idx = self.free_indices.pop() + self.req_id_to_index[req_id] = req_idx + self.index_to_req_id[req_idx] = req_id + self.extra_data[req_id] = ExtraData(lora_request) + + self.prompt_len[req_idx] = prompt_len + prefill_len = len(prefill_token_ids) + assert prefill_len >= prompt_len, ( + f"prefill_len {prefill_len} < prompt_len {prompt_len}" + ) + self.prefill_len.np[req_idx] = prefill_len + self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids + self.num_tokens[req_idx] = prefill_len + self.num_computed_tokens[req_idx] = num_computed_tokens + + if lora_request is not None: + self.lora_ids[req_idx] = lora_request.lora_int_id + else: + self.lora_ids[req_idx] = NO_LORA_ID + + self.temperature.np[req_idx] = sampling_params.temperature + self.top_p.np[req_idx] = sampling_params.top_p + if 0 < sampling_params.top_k < self.vocab_size: + top_k = sampling_params.top_k + else: + top_k = self.vocab_size + self.top_k.np[req_idx] = top_k + + if sampling_params.seed is not None: + seed = sampling_params.seed + else: + seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX) + self.seeds.np[req_idx] = seed + + if sampling_params.logprobs is not None: + num_logprobs = sampling_params.logprobs + else: + num_logprobs = -1 + self.num_logprobs[req_idx] = num_logprobs + + # For now, only support prompt logprobs for the prompt tokens. + needs_prompt_logprobs = sampling_params.prompt_logprobs is not None + self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs + + def remove_request(self, req_id: str) -> None: + self.extra_data.pop(req_id, None) + req_idx = self.req_id_to_index.pop(req_id, None) + if req_idx is None: + # Request not found. 
+ return + self.index_to_req_id.pop(req_idx, None) + self.free_indices.append(req_idx) + + def make_sampling_metadata( + self, + idx_mapping: np.ndarray, + pos: torch.Tensor, + ) -> SamplingMetadata: + temperature = self.temperature.np[idx_mapping] + temperature = self.temperature.copy_np_to_gpu(temperature) + + top_p = self.top_p.np[idx_mapping] + no_top_p = np.all(top_p == 1.0) + top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None + + top_k = self.top_k.np[idx_mapping] + no_top_k = np.all(top_k == self.vocab_size) + top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None + + seeds = self.seeds.np[idx_mapping] + seeds = self.seeds.copy_np_to_gpu(seeds) + + num_logprobs = self.num_logprobs[idx_mapping] + max_num_logprobs: int | None = int(np.max(num_logprobs)) + if max_num_logprobs == -1: + max_num_logprobs = None + + return SamplingMetadata( + temperature=temperature, + top_p=top_p, + top_k=top_k, + seeds=seeds, + pos=pos, + max_num_logprobs=max_num_logprobs, + ) + + def make_lora_inputs( + self, + req_ids: list[str], + idx_mapping: np.ndarray, + num_scheduled_tokens: np.ndarray, + ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]: + lora_ids = self.lora_ids[idx_mapping] + prompt_lora_mapping = tuple(lora_ids) + token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens)) + + active_lora_requests: set[LoRARequest] = set() + for req_id in req_ids: + lora_request = self.extra_data[req_id].lora_request + if lora_request is not None: + active_lora_requests.add(lora_request) + return prompt_lora_mapping, token_lora_mapping, active_lora_requests + + +class Param: + def __init__( + self, + size: int, + dtype: torch.dtype, + device: torch.device, + pin_memory: bool, + ): + self.buffer = CpuGpuBuffer( + size, + dtype=dtype, + device=device, + pin_memory=pin_memory, + ) + self.np = np.zeros_like(self.buffer.np) + + def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor: + n = x.shape[0] + self.buffer.np[:n] = x + return self.buffer.copy_to_gpu(n) + + +@dataclass +class ExtraData: + lora_request: LoRARequest | None + in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) diff --git a/vllm/v1/worker/gpu/structured_outputs.py b/vllm/v1/worker/gpu/structured_outputs.py new file mode 100644 index 000000000000..83051b0ed33f --- /dev/null +++ b/vllm/v1/worker/gpu/structured_outputs.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.input_batch import InputBuffers + + +def apply_grammar_bitmask( + logits: torch.Tensor, + req_ids: list[str], + grammar_req_ids: list[str], + grammar_bitmask: np.ndarray, + input_buffers: InputBuffers, +) -> None: + input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask + input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0]) + + batch_size = logits.shape[0] + grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)} + # logits -> bitmask mapping + mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids] + input_buffers.bitmask_indices.np[:batch_size] = mapping + input_buffers.bitmask_indices.copy_to_gpu(batch_size) + + vocab_size = logits.shape[-1] + BLOCK_SIZE = 8192 + grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE)) + _apply_grammar_bitmask_kernel[grid]( + logits, + logits.stride(0), + input_buffers.grammar_bitmask.gpu, + 
input_buffers.grammar_bitmask.gpu.stride(0), + input_buffers.bitmask_indices.gpu, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +# Adapted from +# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py +@triton.jit +def _apply_grammar_bitmask_kernel( + logits_ptr, + logits_stride, + bitmask_ptr, + bitmask_stride, + bitmask_indices_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + logits_idx = tl.program_id(0) + bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx) + if bitmask_idx == -1: + # No bitmask to apply. + return + + # Load the bitmask. + block_id = tl.program_id(1) + bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32) + packed_bitmask = tl.load( + bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset, + mask=bitmask_offset < bitmask_stride, + ) + # Unpack the bitmask. + bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0 + bitmask = bitmask.reshape(BLOCK_SIZE) + + # Apply the bitmask to the logits. + block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + logits_ptr + logits_idx * logits_stride + block_offset, + -float("inf"), + mask=bitmask & (block_offset < vocab_size), + ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index f1fd5be966c3..6a4bfde5f972 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import MemorySnapshot, memory_profiling -from vllm.v1.core.sched.output import GrammarOutput +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( @@ -58,7 +58,6 @@ logger = init_logger(__name__) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig - from vllm.v1.core.sched.output import SchedulerOutput class Worker(WorkerBase): @@ -101,6 +100,8 @@ class Worker(WorkerBase): else: self.profiler = None + self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + def sleep(self, level: int = 1) -> None: from vllm.device_allocator.cumem import CuMemAllocator @@ -237,9 +238,17 @@ class Worker(WorkerBase): raise RuntimeError(f"Not support device type: {self.device_config.device}") # Construct the model runner - self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device - ) + if self.use_v2_model_runner: + from vllm.v1.worker.gpu.model_runner import ( + GPUModelRunner as GPUModelRunnerV2, + ) + + # HACK(woosuk): This is a temporary fix to avoid type errors. + self.model_runner: GPUModelRunner = GPUModelRunnerV2( # type: ignore + self.vllm_config, self.device + ) + else: + self.model_runner = GPUModelRunner(self.vllm_config, self.device) if self.rank == 0: # If usage stat is enabled, collect relevant info. @@ -573,7 +582,12 @@ class Worker(WorkerBase): self.profiler.stop() def execute_dummy_batch(self) -> None: - self.model_runner._dummy_run(1, uniform_decode=True) + if self.use_v2_model_runner: + self.model_runner.execute_model( + SchedulerOutput.make_empty(), dummy_run=True + ) + else: + self.model_runner._dummy_run(1, uniform_decode=True) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request)