GPU Model Runner V2 (#25266)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-11-21 08:20:55 -08:00 committed by GitHub
parent 1f400c58b8
commit 30b44a1598
18 changed files with 2639 additions and 12 deletions

.github/CODEOWNERS

@@ -35,6 +35,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/offloading @ApostaC
# Model runner V2
/vllm/v1/worker/gpu @WoosukKwon
# Test ownership
/.buildkite/lm-eval-harness @mgoin
/tests/distributed/test_multi_node_assignment.py @youkaichao


@@ -231,6 +231,7 @@ if TYPE_CHECKING:
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False
def get_default_cache_root():
@@ -1522,6 +1523,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
),
# Flag to enable v2 model runner.
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
),
}
# --8<-- [end:env-vars-definition]
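The new flag is parsed like the other boolean knobs in `envs.py`: an integer string where any non-zero value opts in. A minimal sketch of the expected behavior (the surrounding launch context is illustrative, not part of this diff):

```python
# Sketch: VLLM_USE_V2_MODEL_RUNNER is read as an integer string; any non-zero
# value enables the experimental v2 runner, while unset or "0" keeps the default.
import os

os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"  # opt in (illustrative)
use_v2 = bool(int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")))
assert use_v2 is True
```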


@@ -593,6 +593,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._workspace_buffer
def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
self._workspace_buffer = workspace_buffer
def _get_prefill_wrapper(
self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:


@@ -44,11 +44,15 @@ class NewRequestData:
lora_request: LoRARequest | None
prompt_embeds: "torch.Tensor | None" = None
# Only used for v2 model runner.
prefill_token_ids: list[int] | None = None
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
prefill_token_ids: list[int] | None = None,
) -> "NewRequestData": ) -> "NewRequestData":
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
@ -60,6 +64,7 @@ class NewRequestData:
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request, lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
prefill_token_ids=prefill_token_ids,
)
def __repr__(self) -> str:
@@ -68,6 +73,7 @@ class NewRequestData:
f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"prefill_token_ids={self.prefill_token_ids},"
f"mm_features={self.mm_features}," f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
@ -183,6 +189,10 @@ class SchedulerOutput:
# freed from the encoder cache. # freed from the encoder cache.
free_encoder_mm_hashes: list[str] free_encoder_mm_hashes: list[str]
# Request IDs that are preempted in this step.
# Only used for v2 model runner.
preempted_req_ids: set[str] | None = None
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False
@@ -193,6 +203,20 @@ class SchedulerOutput:
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
@dataclass
class GrammarOutput:


@@ -6,6 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Any
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
@@ -187,6 +188,7 @@ class Scheduler(SchedulerInterface):
pcp_world_size=self.pcp_world_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
@@ -658,12 +660,25 @@ class Scheduler(SchedulerInterface):
)
# Construct the scheduler output.
if self.use_v2_model_runner:
scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
scheduled_resumed_reqs = []
new_reqs_data = [
NewRequestData.from_request(
req,
req_to_new_blocks[req.request_id].get_block_ids(),
req._all_token_ids,
)
for req in scheduled_new_reqs
]
else:
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids()
)
for req in scheduled_new_reqs
]
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
@@ -685,6 +700,7 @@
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
preempted_req_ids={req.request_id for req in preempted_reqs},
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
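The hunk above changes how resumed requests are surfaced to the runner: with `VLLM_USE_V2_MODEL_RUNNER` set, they are folded into `scheduled_new_reqs` and `NewRequestData` carries the full `prefill_token_ids` (prompt plus previously generated tokens), so the v2 runner rebuilds the request state instead of patching a cached entry. A standalone sketch of that branching, with illustrative stand-in names rather than the real scheduler code:

```python
# Sketch only: mirrors the v1/v2 branch in Scheduler.schedule(), not the real code.
def split_new_and_resumed(new_reqs: list, resumed_reqs: list, use_v2: bool):
    if use_v2:
        # v2: resumed requests are re-sent as "new" requests; the runner receives
        # their full token history via prefill_token_ids.
        return new_reqs + resumed_reqs, []
    # v1: resumed requests go through the cached-request-data path instead.
    return new_reqs, resumed_reqs
```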


@@ -0,0 +1,4 @@
# [Experimental] Model Runner V2
This directory contains the new model runner, which is under active development.
Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.


@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
ModelRunnerOutput,
SamplerOutput,
)
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: np.ndarray,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
self.model_runner_output = model_runner_output
self.sampler_output = sampler_output
self.num_sampled_tokens = num_sampled_tokens
self.copy_stream = copy_stream
self.copy_event = copy_event
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.copy_stream):
self.copy_stream.wait_stream(default_stream)
# NOTE(woosuk): We must ensure that CPU tensors are not freed
# before the device-to-host copy is fully completed. For instance,
# operations like
# self.sampled_token_np = ...to("cpu", non_blocking=True).numpy()
# are unsafe because the underlying CPU tensor can be prematurely freed and
# reused by other tensors before the asynchronous copy finishes, potentially
# causing race conditions. To prevent this, we delay freeing by holding
# references until the copy event signals completion.
# Likewise, we also need to keep the reference to the GPU tensors.
# This is done by keeping the reference to sampler_output and
# model_runner_output.
self.sampled_token_ids = sampler_output.sampled_token_ids.to(
"cpu", non_blocking=True
)
if sampler_output.logprobs_tensors is not None:
self.logprobs_tensors = (
sampler_output.logprobs_tensors.to_cpu_nonblocking()
)
else:
self.logprobs_tensors = None
self.prompt_logprobs_dict = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
self.copy_event.record(self.copy_stream)
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
# Going forward, we should keep the data structures as NumPy arrays
# rather than Python lists.
sampled_token_ids_np = self.sampled_token_ids.numpy()
num_reqs = sampled_token_ids_np.shape[0]
sampled_token_ids: list[np.ndarray] = [
sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
for i in range(num_reqs)
]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.logprobs_tensors is not None:
self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
return self.model_runner_output
@contextmanager
def async_barrier(event: torch.cuda.Event | None):
if event is not None:
event.synchronize()
try:
yield
finally:
if event is not None:
event.record()
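The safety concern in the NOTE above boils down to one pattern: run the device-to-host copy on a side stream, keep references to both the GPU and CPU tensors, and only touch the CPU data after the recorded event has completed. A minimal standalone sketch of that pattern (tensor shapes and names are illustrative, not part of this file):

```python
# Minimal sketch of the copy pattern AsyncOutput relies on. Holding references to
# `gpu_tensor` and `cpu_tensor` until the event fires is what prevents the
# premature-free race described above.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream(device)
    copy_event = torch.cuda.Event()

    default_stream = torch.cuda.current_stream()
    gpu_tensor = torch.randint(0, 100, (8, 4), device=device)
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(default_stream)
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        copy_event.record(copy_stream)

    copy_event.synchronize()      # wait for the async copy to finish
    result = cpu_tensor.numpy()   # only now is the CPU data safe to read
```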


@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import (
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.utils import bind_kv_cache
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
# Skip modules that don't need KV cache (e.g., encoder-only attention)
if spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
return kv_cache_spec
def init_attn_backend(
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
):
attn_backends: dict[str, AttentionBackend] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
flashinfer_workspace: torch.Tensor | None = None
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
layer_names = kv_cache_group_spec.layer_names
any_layer_name = next(iter(layer_names))
attn_layers = get_layers_from_vllm_config(
vllm_config, AttentionLayerBase, layer_names
)
attn_backend = attn_layers[any_layer_name].get_attn_backend()
for layer_name in layer_names:
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
)
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
if "FLASHINFER" in attn_backend.get_name():
if flashinfer_workspace is None:
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
else:
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
kv_cache_config: KVCacheConfig,
device: torch.device,
):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys()), (
"Some layers are not correctly initialized"
)
return kv_cache_raw_tensors
def _reshape_kv_cache(
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend],
) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
attn_backend = attn_backends[layer_name]
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
dtype = kv_cache_spec.dtype
raw_tensor = raw_tensor.view(dtype)
raw_tensor = raw_tensor.view(kv_cache_shape)
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
return kv_caches
def init_kv_cache(
runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend],
device: torch.device,
) -> None:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_computed_tokens_cpu: torch.Tensor,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
seq_lens_cpu = seq_lens.cpu[:num_reqs]
max_seq_len = int(seq_lens.np[:num_reqs].max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
causal=True,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata
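In `_reshape_kv_cache`, the raw byte buffer is first viewed in the backend's preferred physical layout and then permuted back so the logical shape matches the canonical one. A small self-contained illustration of the inverse-permutation step (the shapes and stride order below are made up for the example):

```python
# Sketch of the stride-order handling in _reshape_kv_cache: view the raw buffer
# in the backend's physical order, then permute back to the canonical shape.
import torch

kv_cache_shape = (16, 32, 8, 64)          # canonical (num_blocks, block_size, heads, head_size)
stride_order = (0, 2, 1, 3)               # hypothetical backend-preferred layout
physical_shape = tuple(kv_cache_shape[i] for i in stride_order)
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

raw = torch.zeros(physical_shape)
canonical_view = raw.permute(*inv_order)  # logical shape matches kv_cache_shape
assert canonical_view.shape == kv_cache_shape
```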


@@ -0,0 +1,315 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import triton
import triton.language as tl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
)
self.num_blocks = torch.zeros(
self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
self.slot_mappings = torch.zeros(
self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device,
)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.uint64,
device="cpu",
pin_memory=self.pin_memory,
)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: tuple[list[int], ...],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: tuple[list[int], ...],
# [num_reqs]
overwrite: list[bool],
) -> None:
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.block_table_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.input_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
self.slot_mappings.fill_(PAD_SLOT_ID)
return self.slot_mappings[:, :num_tokens]
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
for i in range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
)
tl.store(
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices
)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)
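For reference, the per-token arithmetic performed by `_compute_slot_mappings_kernel` can be written in plain NumPy as follows (page size, block ids, and positions are illustrative):

```python
# Reference (pure NumPy) version of the slot mapping math for one request and
# one KV cache group: slot = physical_block_id * page_size + in-block offset.
import numpy as np

page_size = 16
block_table_row = np.array([7, 3, 9], dtype=np.int64)  # blocks assigned to one request
positions = np.arange(30, 38)                          # token positions scheduled this step

block_indices = positions // page_size                 # logical block each position lands in
block_numbers = block_table_row[block_indices]         # physical block ids
slot_mapping = block_numbers * page_size + positions % page_size
# positions 30,31 -> logical block 1 (id 3); positions 32..37 -> logical block 2 (id 9)
assert slot_mapping[0] == 3 * 16 + 14
assert slot_mapping[2] == 9 * 16 + 0
```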


@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
self.padded_sizes = self._init_padded_sizes()
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def _init_padded_sizes(self) -> dict[int, int]:
if not self.cudagraph_mode.has_full_cudagraphs():
# Full cuda graphs are not used.
return {}
padded_sizes: dict[int, int] = {}
assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
padded_sizes[i] = x
break
return padded_sizes
def needs_capture(self) -> bool:
return len(self.padded_sizes) > 0
def get_cudagraph_size(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> int | None:
if not self.cudagraph_mode.has_full_cudagraphs():
return None
if self.cudagraph_mode != CUDAGraphMode.FULL:
# TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
all_decode = all(
x == 1 for x in scheduler_output.num_scheduled_tokens.values()
)
if not all_decode:
# Prefill is included.
return None
return self.padded_sizes.get(num_tokens_after_padding)
def capture_graph(
self,
batch_size: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert batch_size not in self.graphs
# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=batch_size,
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
if self.dp_size > 1:
num_tokens_across_dp = torch.full(
(self.dp_size,),
batch_size,
dtype=torch.int32,
device="cpu",
)
else:
num_tokens_across_dp = None
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
num_tokens_across_dp=num_tokens_across_dp,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with (
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
num_tokens_across_dp=num_tokens_across_dp,
),
torch.cuda.graph(graph, self.pool),
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
self.hidden_states[:batch_size] = hidden_states
self.graphs[batch_size] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert self.needs_capture()
# Capture larger graphs first.
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with freeze_gc(), graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)
def run(self, batch_size: int) -> torch.Tensor:
assert batch_size in self.graphs
self.graphs[batch_size].replay()
assert self.hidden_states is not None
return self.hidden_states[:batch_size]
@contextmanager
def freeze_gc():
gc.collect()
gc.freeze()
try:
yield
finally:
gc.unfreeze()
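The `padded_sizes` table built by `_init_padded_sizes` simply rounds every batch size up to the nearest captured graph size. A tiny reproduction of that mapping with assumed capture sizes:

```python
# Sketch of what _init_padded_sizes produces: every batch size up to the largest
# capture size maps to the smallest captured size that can hold it.
capture_sizes = [1, 2, 4, 8]   # assumed cudagraph_capture_sizes, sorted
padded = {}
for i in range(1, capture_sizes[-1] + 1):
    for x in capture_sizes:
        if i <= x:
            padded[i] = x
            break

assert padded[3] == 4 and padded[5] == 8 and padded[8] == 8
```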


@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
def get_batch_metadata_across_dp(
num_tokens: int,
cudagraph_size: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group
tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1]
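The exchange in `get_batch_metadata_across_dp` works because each rank writes its values into its own column of a zero-initialized `2 x dp_size` tensor, so a sum all-reduce reconstructs every rank's metadata everywhere. A single-process simulation of that layout, without `torch.distributed` (values are made up):

```python
# Simulates the sum all-reduce in get_batch_metadata_across_dp on one process.
import torch

dp_size = 4
per_rank = [torch.zeros(2, dp_size, dtype=torch.int32) for _ in range(dp_size)]
for rank, t in enumerate(per_rank):
    t[0][rank] = 128 + rank  # num_tokens on this rank (illustrative)
    t[1][rank] = 128         # padded CUDA graph size on this rank (illustrative)

reduced = sum(per_rank)      # what every rank sees after dist.all_reduce(..., SUM)
num_tokens_across_dp, cudagraph_sizes = reduced[0], reduced[1]
assert num_tokens_across_dp.tolist() == [128, 129, 130, 131]
```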


@@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
import numba
import numba.types as types
import numpy as np
import torch
import triton
import triton.language as tl
from vllm.utils import random_uuid
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
hidden_size: int,
vocab_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens
self.device = device
self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer(
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
@dataclass
class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# [num_reqs]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
query_start_loc_np: np.ndarray
# [num_reqs]
seq_lens: torch.Tensor
seq_lens_np: np.ndarray
# [num_tokens_after_padding]
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# [num_reqs]
logits_indices: torch.Tensor
@classmethod
def make_dummy(
cls,
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch":
assert 0 < num_reqs <= num_tokens
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens
input_buffers.query_start_loc.np[0] = 0
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
num_scheduled_tokens
)
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals to query_len
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
input_buffers.seq_lens.np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
positions = input_buffers.positions.copy_to_gpu(num_tokens)
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
)
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int64[:], # positions
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
)
],
nopython=True,
cache=True,
)
def _prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require that.
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
def prepare_inputs(
idx_mapping: np.ndarray,
prefill_token_ids: np.ndarray,
num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
input_ids: CpuGpuBuffer,
positions: CpuGpuBuffer,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_tokens: int,
) -> None:
_prepare_inputs(
idx_mapping,
prefill_token_ids,
num_computed_tokens,
num_scheduled_tokens,
input_ids.np,
positions.np,
query_start_loc.np,
seq_lens.np,
)
input_ids.copy_to_gpu(num_tokens)
positions.copy_to_gpu(num_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu()
return
@triton.jit
def _combine_last_token_ids_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_token_ids_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens.
return
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)
def combine_last_token_ids(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_token_ids: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_token_ids,
query_start_loc,
seq_lens,
prefill_len,
)
return input_ids
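To make the `_prepare_inputs` contract concrete, here is a plain-Python walk-through of the same bookkeeping for two requests with an identity `idx_mapping` and made-up token ids: request 0 has 4 computed tokens and 2 scheduled, request 1 has 0 computed and 3 scheduled.

```python
# Plain-Python reference for _prepare_inputs (identity idx_mapping, 2 requests).
import numpy as np

token_ids = np.array([[10, 11, 12, 13, 14, 15, 0, 0],
                      [20, 21, 22, 23, 0, 0, 0, 0]], dtype=np.int32)
num_computed = np.array([4, 0], dtype=np.int32)   # tokens already in the KV cache
num_scheduled = np.array([2, 3], dtype=np.int32)  # tokens to run this step

input_ids, positions, query_start_loc, seq_lens = [], [], [0], []
for req_idx in range(2):
    start = int(num_computed[req_idx])
    end = start + int(num_scheduled[req_idx])
    seq_lens.append(end)
    input_ids.extend(token_ids[req_idx, start:end].tolist())
    positions.extend(range(start, end))
    query_start_loc.append(query_start_loc[-1] + int(num_scheduled[req_idx]))

assert input_ids == [14, 15, 20, 21, 22]
assert positions == [4, 5, 0, 1, 2]
assert query_start_loc == [0, 2, 5]
assert seq_lens == [6, 3]
```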


@@ -0,0 +1,814 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from copy import deepcopy
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
get_kv_cache_spec,
init_attn_backend,
init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_last_token_ids,
prepare_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = self.dtype
if self.cache_config.cache_dtype != "auto":
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype
]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.hidden_size = self.model_config.get_hidden_size()
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
self.output_copy_event = torch.cuda.Event()
if self.use_async_scheduling:
self.input_prep_event = torch.cuda.Event()
self.structured_outputs_event = torch.cuda.Event()
else:
self.input_prep_event = None
self.structured_outputs_event = None
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
vllm_config=self.vllm_config,
device=self.device,
)
def get_supported_tasks(self) -> tuple[str]:
return ("generate",)
def load_model(self, *args, **kwargs) -> None:
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
)
if self.lora_config:
self.model = self.load_lora_model(
self.model,
self.vllm_config,
self.device,
)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info(
"Model loading took %.4f GiB and %.6f seconds",
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load,
)
def get_model(self) -> nn.Module:
return self.model
def get_kv_cache_spec(self):
return get_kv_cache_spec(self.vllm_config)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
)
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config,
self.vllm_config,
self.device,
)
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
self.device,
)
# Attention groups are not supported.
self.attn_groups = [] # type: ignore
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
num_computed_tokens_cpu = torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device="cpu"
)
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=num_computed_tokens_cpu,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_batch.attn_metadata = attn_metadata
@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
*args,
skip_attn: bool = True,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
if not skip_attn:
self.prepare_dummy_attn_metadata(input_batch)
if self.dp_size == 1:
num_tokens_across_dp: torch.Tensor | None = None
else:
num_tokens_across_dp = torch.full(
(self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
)
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
with (
self.maybe_dummy_run_with_lora(
self.lora_config,
input_batch.num_scheduled_tokens,
num_sampled_tokens,
),
set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
),
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@torch.inference_mode()
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
) -> None:
num_reqs = hidden_states.shape[0]
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens,
skip_attn=True,
)
self._dummy_sampler_run(sample_hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def reset_mm_cache(self) -> None:
pass
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet.
return num_scheduled_tokens
@torch.inference_mode()
def capture_model(self) -> int:
if not self.cudagraph_manager.needs_capture():
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture(
model=self.model,
input_buffers=self.input_buffers,
block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info(
"Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time,
cuda_graph_size / (1 << 30),
)
return cuda_graph_size
def warmup_for_prefill(self) -> None:
# For FlashInfer, we would like to execute a dummy prefill run
# to trigger JIT compilation.
if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
self._dummy_run(self.max_num_tokens, skip_attn=False)
torch.cuda.synchronize()
def update_states(self, scheduler_output: SchedulerOutput) -> None:
for req_id in scheduler_output.preempted_req_ids:
self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = []
cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups)
)
new_block_ids: tuple[list[int], ...] = tuple(
[] for _ in range(self.block_tables.num_kv_cache_groups)
)
overwrite: list[bool] = []
# Add new requests.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
self.req_states.add_request(
req_id=req_id,
prompt_len=len(new_req_data.prompt_token_ids),
prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params,
lora_request=new_req_data.lora_request,
)
req_index = self.req_states.req_id_to_index[req_id]
req_indices.append(req_index)
for i, block_ids in enumerate(new_req_data.block_ids):
x = cu_num_new_blocks[i][-1]
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
req_index = self.req_states.req_id_to_index[req_id]
req_new_block_ids = cached_reqs.new_block_ids[i]
if req_new_block_ids is not None:
req_indices.append(req_index)
for group_id, block_ids in enumerate(req_new_block_ids):
x = cu_num_new_blocks[group_id][-1]
cu_num_new_blocks[group_id].append(x + len(block_ids))
new_block_ids[group_id].extend(block_ids)
overwrite.append(False)
if req_indices:
self.block_tables.append_block_ids(
req_indices=req_indices,
cu_num_new_blocks=cu_num_new_blocks,
new_block_ids=new_block_ids,
overwrite=overwrite,
)
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(
scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get,
)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
)
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
prepare_inputs(
idx_mapping_np,
self.req_states.prefill_token_ids,
self.req_states.num_computed_tokens,
num_scheduled_tokens,
self.input_buffers.input_ids,
self.input_buffers.positions,
self.input_buffers.query_start_loc,
self.input_buffers.seq_lens,
num_tokens,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_np = query_start_loc.np[: num_reqs + 1]
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
# Some input token ids are directly read from the last sampled tokens.
combine_last_token_ids(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens_gpu,
self.req_states.prefill_len.copy_to_gpu(),
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
)
num_computed_tokens_cpu = torch.from_numpy(
self.req_states.num_computed_tokens[idx_mapping_np]
)
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=num_computed_tokens_cpu,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens_gpu,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
)
def sample(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
) -> SamplerOutput:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
with async_barrier(self.structured_outputs_event):
apply_grammar_bitmask(
logits,
input_batch.req_ids,
grammar_output.structured_output_request_ids,
grammar_output.grammar_bitmask,
self.input_buffers,
)
sampler_output = self.sampler(logits, sampling_metadata)
return sampler_output
def compute_prompt_logprobs(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> dict[str, LogprobsTensors]:
idx_mapping_np = input_batch.idx_mapping_np
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
if not np.any(needs_prompt_logprobs):
# No request asks for prompt logprobs.
return {}
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
includes_prompt = num_computed_tokens < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = (
prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
)
needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
if not np.any(needs_prompt_logprobs):
return {}
# Just to be safe, clone the input ids.
n = input_batch.num_tokens
# Shift the input ids by one.
token_ids = torch.empty_like(input_batch.input_ids[:n])
token_ids[: n - 1] = input_batch.input_ids[1:n]
# To avoid out-of-bound access, set the last token id to 0.
token_ids[n - 1] = 0
# Handle chunked prompts.
seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
is_prompt_chunked = seq_lens < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
if not is_prompt_chunked[i]:
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
token_ids[idx] = next_prompt_token
# NOTE(woosuk): We mask out logprobs for negative tokens.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
token_ids,
hidden_states[:n],
self.model.compute_logits,
)
prompt_token_ids = token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
start_idx = query_start_loc[i]
end_idx = query_start_loc[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
req_extra_data = self.req_states.extra_data[req_id]
prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
prompt_logprobs_list.append(logprobs)
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
),
logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
selected_token_ranks=torch.cat(
[x.selected_token_ranks for x in prompt_logprobs_list]
),
)
prompt_logprobs_list.clear()
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
def postprocess(
self,
sampler_output: SamplerOutput,
prompt_logprobs_dict: dict[str, LogprobsTensors],
input_batch: InputBatch,
) -> AsyncOutput | ModelRunnerOutput:
# Store the last sampled token ids.
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
sampler_output.sampled_token_ids
)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
idx_mapping_np = input_batch.idx_mapping_np
is_chunked_prefilling = (
input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
)
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens.
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
# Increment the number of computed tokens.
self.req_states.num_computed_tokens[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None,
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def get_cudagraph_and_dp_padding(
self,
scheduler_output: SchedulerOutput,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.dp_size == 1:
# No DP. Only consider CUDA graphs.
if total_num_scheduled_tokens == 0:
# Special case: no tokens to run.
return CUDAGraphMode.NONE, 0, None
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
scheduler_output, total_num_scheduled_tokens
)
if cudagraph_size is not None:
# Use full CUDA graph.
return CUDAGraphMode.FULL, cudagraph_size, None
# Fall back to eager mode.
# TODO(woosuk): Support piecewise CUDA graphs.
return CUDAGraphMode.NONE, total_num_scheduled_tokens, None
# Consider DP padding and CUDA graph.
if total_num_scheduled_tokens == 0:
# Special handling is needed for 0.
cudagraph_size_before_dp: int | None = 0
else:
cudagraph_size_before_dp = self.cudagraph_manager.get_cudagraph_size(
scheduler_output, total_num_scheduled_tokens
)
if cudagraph_size_before_dp is None:
cudagraph_size_before_dp = -1
assert cudagraph_size_before_dp is not None
num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
total_num_scheduled_tokens,
cudagraph_size_before_dp,
self.dp_size,
self.dp_rank,
)
if all(cudagraph_size_across_dp >= 0):
# If all ranks can use CUDA graph, pad to the maximum number of tokens
# across DP and use CUDA graph.
num_tokens_after_padding = int(cudagraph_size_across_dp.max().item())
cudagraph_mode = CUDAGraphMode.FULL
else:
# If any of the ranks cannot use CUDA graph, use eager mode for all ranks.
# No padding is needed except for ranks that have no tokens to run.
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = num_tokens_across_dp[self.dp_rank]
cudagraph_mode = CUDAGraphMode.NONE
return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp
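# Worked example (illustrative, not part of the original code): with dp_size=4
# and per-rank CUDA graph sizes [256, 512, 512, 128], every rank pads its batch
# to 512 tokens and replays the full CUDA graph. If any rank reports -1 (no
# matching graph size), all ranks fall back to eager mode, and only ranks with
# zero scheduled tokens are padded up to a single token.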
@torch.inference_mode()
def execute_model(
self,
scheduler_output: SchedulerOutput,
intermediate_tensors: Any | None = None,
dummy_run: bool = False,
) -> ModelRunnerOutput | None:
assert intermediate_tensors is None
if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
# No need to run the model.
with async_barrier(self.input_prep_event):
self.update_states(scheduler_output)
return EMPTY_MODEL_RUNNER_OUTPUT
# NOTE: Call this before the async barrier so CPU all-reduce and
# GPU execution can overlap.
cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
self.get_cudagraph_and_dp_padding(scheduler_output)
)
with async_barrier(self.input_prep_event):
self.update_states(scheduler_output)
if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run.
return EMPTY_MODEL_RUNNER_OUTPUT
if not dummy_run:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs(
scheduler_output,
num_tokens_after_padding,
)
# NOTE(woosuk): Sampling metadata should be built under the async
# barrier to avoid race conditions.
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np, pos
)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
else:
# No actual tokens to run. A dummy run for DP.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens_after_padding,
input_buffers=self.input_buffers,
device=self.device,
)
self.prepare_dummy_attn_metadata(input_batch)
sampling_metadata = None
# Run model.
if cudagraph_mode == CUDAGraphMode.FULL:
# Run CUDA graph.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
hidden_states = self.cudagraph_manager.run(
input_batch.num_tokens_after_padding
)
else:
# Run PyTorch model in eager mode.
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_mode,
num_tokens_across_dp=num_tokens_across_dp,
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
)
self.execute_model_state = hidden_states, input_batch, sampling_metadata
return None
@torch.inference_mode()
def sample_tokens(
self,
grammar_output: GrammarOutput | None,
) -> AsyncOutput | ModelRunnerOutput:
assert self.execute_model_state is not None
hidden_states, input_batch, sampling_metadata = self.execute_model_state
self.execute_model_state = None # type: ignore
assert sampling_metadata is not None
sampler_output = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output
)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
output = self.postprocess(
sampler_output,
prompt_logprobs_dict,
input_batch,
)
return output
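# Illustrative call sequence (not part of this file; the variable names below
# are assumptions): execute_model() launches the forward pass and stashes
# (hidden_states, input_batch, sampling_metadata) in execute_model_state, and
# sample_tokens() consumes that state once the grammar output, if any, is ready.
#
#     runner.execute_model(scheduler_output)          # forward pass, returns None
#     output = runner.sample_tokens(grammar_output)   # sample + postprocess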


@@ -0,0 +1,327 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
import triton
import triton.language as tl
from vllm.config.model import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
class Sampler:
def __init__(
self,
logprobs_mode: LogprobsMode = "raw_logprobs",
):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
def __call__(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
sampled, logits = self.sample(
logits, sampling_metadata, return_logits=True
)
else:
assert self.logprobs_mode == "raw_logprobs"
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
logprobs_tensors = compute_topk_logprobs(
logits,
sampling_metadata.max_num_logprobs,
sampled,
)
else:
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
logprobs_tensors = None
# These are GPU tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
return_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
is_greedy = sampling_metadata.temperature == 0
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits = logits / temp.view(-1, 1)
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
)
sampled = gumbel_sample(
logits,
is_greedy,
sampling_metadata.seeds,
sampling_metadata.pos,
)
return sampled, logits if return_logits else None
@triton.jit
def _gumbel_sample_kernel(
sampled_ptr,
logits_ptr,
logits_stride,
seeds_ptr,
pos_ptr,
is_greedy_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Greedy sampling. Don't apply gumbel noise.
max_val = float("-inf")
max_idx = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
idx = tl.argmax(logits, axis=0)
value = tl.max(logits, axis=0)
is_greater = value > max_val
max_val = tl.where(is_greater, value, max_val)
max_idx = tl.where(is_greater, i + idx, max_idx)
tl.store(sampled_ptr + req_idx, max_idx)
return
# Random sampling.
# Calculate gumbel seed.
seed = tl.load(seeds_ptr + req_idx)
pos = tl.load(pos_ptr + req_idx)
gumbel_seed = tl.randint(seed, pos)
max_val = float("-inf")
max_idx = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
# Generate gumbel noise.
r = tl.rand(gumbel_seed, block).to(tl.float64)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply gumbel noise.
logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
# Argmax to get the sampled token.
idx = tl.argmax(logits, axis=0)
value = tl.max(logits, axis=0)
is_greater = value > max_val
max_val = tl.where(is_greater, value, max_val)
max_idx = tl.where(is_greater, i + idx, max_idx)
tl.store(sampled_ptr + req_idx, max_idx)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
is_greedy: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
) -> torch.Tensor:
num_reqs, vocab_size = logits.shape
# NOTE(woosuk): Use int64 for later indexing.
sampled = torch.empty(
num_reqs,
dtype=torch.int64,
device=logits.device,
)
_gumbel_sample_kernel[(num_reqs,)](
sampled,
logits,
logits.stride(0),
seed,
pos,
is_greedy,
vocab_size,
num_warps=8,
BLOCK_SIZE=16384, # type: ignore
)
return sampled
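# Reference semantics (illustrative PyTorch sketch; the Triton kernel above is
# the actual implementation and this helper is not part of the module): adding
# Gumbel(0, 1) noise to the logits and taking the argmax is equivalent to
# sampling from softmax(logits), while greedy rows skip the noise entirely.
def _gumbel_sample_reference(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    is_greedy: torch.Tensor,  # [num_reqs], bool
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    u = torch.rand(logits.shape, generator=generator, device=logits.device)
    gumbel_noise = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    noisy = torch.where(is_greedy.view(-1, 1), logits, logits + gumbel_noise)
    return torch.argmax(noisy, dim=-1)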
@triton.jit
def _topk_log_softmax_kernel(
output_ptr,
logits_ptr,
logits_stride,
topk_ids_ptr,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
max_val = float("-inf")
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore
se = 0.0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
logits = logits.to(tl.float32)
o = logits - max_val - lse
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
@triton.jit
def _ranks_kernel(
output_ptr,
logits_ptr,
logits_stride,
token_ids_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = logits_ptr + req_idx * logits_stride
token_id = tl.load(token_ids_ptr + req_idx)
x = tl.load(row_ptr + token_id)
n = 0
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
n += tl.sum((logits > x).to(tl.int32))
tl.store(output_ptr + req_idx, n)
def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
logits.stride(0),
token_ids,
num_logprobs,
vocab_size,
BLOCK_SIZE=1024, # type: ignore
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
)
return logprobs
def compute_topk_logprobs(
logits: torch.Tensor,
num_logprobs: int,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
if num_logprobs == 0:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
else:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,
logits.stride(0),
sampled_token_ids,
vocab_size,
BLOCK_SIZE=8192, # type: ignore
)
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=token_ranks,
)
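# Reference semantics (illustrative PyTorch sketch; the Triton kernels above are
# the actual implementation and these helpers are not part of the module).
def _topk_logprobs_reference(
    logits: torch.Tensor,  # [batch_size, vocab_size]
    token_ids: torch.Tensor,  # [batch_size, num_logprobs]
) -> torch.Tensor:
    # Log-softmax in FP32, gathered at the requested token ids.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    return logprobs.gather(-1, token_ids.to(torch.int64))


def _token_ranks_reference(
    logits: torch.Tensor,  # [batch_size, vocab_size]
    sampled_token_ids: torch.Tensor,  # [batch_size]
) -> torch.Tensor:
    # Rank = number of logits strictly greater than the sampled token's logit.
    x = logits.gather(-1, sampled_token_ids.view(-1, 1).to(torch.int64))
    return (logits > x).sum(dim=-1)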
def compute_prompt_logprobs(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
prompt_token_ids[start_idx:end_idx],
)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks


@@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p = None
top_k = None
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.extra_data: dict[str, ExtraData] = {}
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
self.prefill_token_ids = np.zeros(
(self.max_num_reqs, self.max_model_len),
dtype=np.int32,
)
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
self.max_num_reqs,
1,
dtype=torch.int64,
device=device,
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def add_request(
self,
req_id: str,
prompt_len: int,
prefill_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
lora_request: LoRARequest | None,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.extra_data[req_id] = ExtraData(lora_request)
self.prompt_len[req_idx] = prompt_len
prefill_len = len(prefill_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
self.num_tokens[req_idx] = prefill_len
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None:
self.lora_ids[req_idx] = lora_request.lora_int_id
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only the logprob of each prompt token itself is computed;
# the requested number of top prompt logprobs is ignored.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_sampling_metadata(
self,
idx_mapping: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping]
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k.np[idx_mapping]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
seeds = self.seeds.np[idx_mapping]
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping]
max_num_logprobs: int | None = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
return SamplingMetadata(
temperature=temperature,
top_p=top_p,
top_k=top_k,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
def make_lora_inputs(
self,
req_ids: list[str],
idx_mapping: np.ndarray,
num_scheduled_tokens: np.ndarray,
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
lora_ids = self.lora_ids[idx_mapping]
prompt_lora_mapping = tuple(lora_ids)
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.extra_data[req_id].lora_request
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
class Param:
def __init__(
self,
size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.buffer = CpuGpuBuffer(
size,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
self.np = np.zeros_like(self.buffer.np)
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
n = x.shape[0]
self.buffer.np[:n] = x
return self.buffer.copy_to_gpu(n)
@dataclass
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
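# Minimal usage sketch (illustrative; assumes a CUDA device and is not part of
# this module):
#
#     states = RequestState(
#         max_num_reqs=8,
#         max_model_len=2048,
#         max_num_batched_tokens=8192,
#         vocab_size=32000,
#         device=torch.device("cuda"),
#         pin_memory=True,
#     )
#     states.add_request(
#         req_id="req-0",
#         prompt_len=4,
#         prefill_token_ids=[11, 22, 33, 44],
#         num_computed_tokens=0,
#         sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
#         lora_request=None,
#     )
#     assert states.num_reqs == 1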


@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers
def apply_grammar_bitmask(
logits: torch.Tensor,
req_ids: list[str],
grammar_req_ids: list[str],
grammar_bitmask: np.ndarray,
input_buffers: InputBuffers,
) -> None:
input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])
batch_size = logits.shape[0]
grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
# logits -> bitmask mapping
mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
input_buffers.bitmask_indices.np[:batch_size] = mapping
input_buffers.bitmask_indices.copy_to_gpu(batch_size)
vocab_size = logits.shape[-1]
BLOCK_SIZE = 8192
grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
_apply_grammar_bitmask_kernel[grid](
logits,
logits.stride(0),
input_buffers.grammar_bitmask.gpu,
input_buffers.grammar_bitmask.gpu.stride(0),
input_buffers.bitmask_indices.gpu,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
logits_ptr,
logits_stride,
bitmask_ptr,
bitmask_stride,
bitmask_indices_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
logits_idx = tl.program_id(0)
bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
if bitmask_idx == -1:
# No bitmask to apply.
return
# Load the bitmask.
block_id = tl.program_id(1)
bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_bitmask = tl.load(
bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
mask=bitmask_offset < bitmask_stride,
)
# Unpack the bitmask.
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
# Apply the bitmask to the logits.
block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
logits_ptr + logits_idx * logits_stride + block_offset,
-float("inf"),
mask=bitmask & (block_offset < vocab_size),
)
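# Reference semantics (illustrative PyTorch sketch; the Triton kernel above is
# the actual implementation and this helper is not part of the module). Bit j of
# packed word w says whether token 32 * w + j is allowed (bit == 1); disallowed
# tokens are masked to -inf.
def _apply_bitmask_reference(
    logits: torch.Tensor,  # [batch_size, vocab_size]
    packed_bitmask: torch.Tensor,  # [batch_size, ceil(vocab_size / 32)], int32
) -> torch.Tensor:
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, device=logits.device, dtype=torch.int32)
    bits = (packed_bitmask.unsqueeze(-1) >> shifts) & 1  # [batch, words, 32]
    allowed = bits.flatten(1)[:, :vocab_size].bool()
    return logits.masked_fill(~allowed, float("-inf"))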


@@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
-from vllm.v1.core.sched.output import GrammarOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (
@@ -58,7 +58,6 @@ logger = init_logger(__name__)

 if TYPE_CHECKING:
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-    from vllm.v1.core.sched.output import SchedulerOutput


 class Worker(WorkerBase):
@@ -101,6 +100,8 @@ class Worker(WorkerBase):
         else:
             self.profiler = None

+        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
+
     def sleep(self, level: int = 1) -> None:
         from vllm.device_allocator.cumem import CuMemAllocator
@@ -237,9 +238,17 @@ class Worker(WorkerBase):
             raise RuntimeError(f"Not support device type: {self.device_config.device}")

         # Construct the model runner
-        self.model_runner: GPUModelRunner = GPUModelRunner(
-            self.vllm_config, self.device
-        )
+        if self.use_v2_model_runner:
+            from vllm.v1.worker.gpu.model_runner import (
+                GPUModelRunner as GPUModelRunnerV2,
+            )
+
+            # HACK(woosuk): This is a temporary fix to avoid type errors.
+            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
+                self.vllm_config, self.device
+            )
+        else:
+            self.model_runner = GPUModelRunner(self.vllm_config, self.device)

         if self.rank == 0:
             # If usage stat is enabled, collect relevant info.
@@ -573,6 +582,11 @@ class Worker(WorkerBase):
             self.profiler.stop()

     def execute_dummy_batch(self) -> None:
-        self.model_runner._dummy_run(1, uniform_decode=True)
+        if self.use_v2_model_runner:
+            self.model_runner.execute_model(
+                SchedulerOutput.make_empty(), dummy_run=True
+            )
+        else:
+            self.model_runner._dummy_run(1, uniform_decode=True)

     def add_lora(self, lora_request: LoRARequest) -> bool: