GPU Model Runner V2 (#25266)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Parent: 1f400c58b8
Commit: 30b44a1598
.github/CODEOWNERS (vendored), 3 lines changed
@@ -35,6 +35,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/offloading @ApostaC

# Model runner V2
/vllm/v1/worker/gpu @WoosukKwon

# Test ownership
/.buildkite/lm-eval-harness @mgoin
/tests/distributed/test_multi_node_assignment.py @youkaichao
@@ -231,6 +231,7 @@ if TYPE_CHECKING:
    VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
    VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
    VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
    VLLM_USE_V2_MODEL_RUNNER: bool = False


def get_default_cache_root():

@@ -1522,6 +1523,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
        "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
    ),
    # Flag to enable v2 model runner.
    "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
        int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
    ),
}

# --8<-- [end:env-vars-definition]
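For reference, a minimal sketch of opting into the new flag from user code. The flag is parsed as an integer ("0"/"1") and turned into a bool, as shown above; the model name below is purely illustrative and setting the variable before vLLM is imported is just a simple, safe convention rather than a documented requirement.

import os

# Opt in to the experimental V2 GPU model runner.
os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"

from vllm import LLM  # import after setting the env var

llm = LLM(model="facebook/opt-125m")  # any model; shown only for illustration
print(llm.generate(["Hello"])[0].outputs[0].text)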
@@ -593,6 +593,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            )
        return self._workspace_buffer

    def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
        self._workspace_buffer = workspace_buffer

    def _get_prefill_wrapper(
        self,
    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
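The new set_workspace_buffer hook lets several FlashInfer metadata builders share a single workspace allocation instead of each allocating its own. A rough sketch of the intended call pattern (the flashinfer_builders list is hypothetical; this mirrors the init_attn_backend logic added later in this commit):

shared_workspace = None  # torch.Tensor | None
for builder in flashinfer_builders:  # hypothetical list of FlashInferMetadataBuilder
    if shared_workspace is None:
        # The first builder allocates (or lazily creates) the workspace.
        shared_workspace = builder._get_workspace_buffer()
    else:
        # The remaining builders reuse the same allocation.
        builder.set_workspace_buffer(shared_workspace)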
@@ -44,11 +44,15 @@ class NewRequestData:
    lora_request: LoRARequest | None
    prompt_embeds: "torch.Tensor | None" = None

    # Only used for v2 model runner.
    prefill_token_ids: list[int] | None = None

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
        prefill_token_ids: list[int] | None = None,
    ) -> "NewRequestData":
        return cls(
            req_id=request.request_id,
@@ -60,6 +64,7 @@ class NewRequestData:
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
            prompt_embeds=request.prompt_embeds,
            prefill_token_ids=prefill_token_ids,
        )

    def __repr__(self) -> str:
@@ -68,6 +73,7 @@ class NewRequestData:
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids={self.prompt_token_ids},"
            f"prefill_token_ids={self.prefill_token_ids},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
@@ -183,6 +189,10 @@ class SchedulerOutput:
    # freed from the encoder cache.
    free_encoder_mm_hashes: list[str]

    # Request IDs that are preempted in this step.
    # Only used for v2 model runner.
    preempted_req_ids: set[str] | None = None

    # Whether the scheduled requests have all the output tokens they
    # need to perform grammar bitmask computation.
    pending_structured_output_tokens: bool = False
@@ -193,6 +203,20 @@ class SchedulerOutput:
    # EC Cache Connector metadata
    ec_connector_metadata: ECConnectorMetadata | None = None

    @classmethod
    def make_empty(cls) -> "SchedulerOutput":
        return cls(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=CachedRequestData.make_empty(),
            num_scheduled_tokens={},
            total_num_scheduled_tokens=0,
            scheduled_spec_decode_tokens={},
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=[],
            finished_req_ids=set(),
            free_encoder_mm_hashes=[],
        )


@dataclass
class GrammarOutput:
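The new preempted_req_ids field lets the V2 runner drop per-request state for requests the scheduler preempted in this step, before the newly scheduled work is applied. A minimal sketch of a consumer, assuming a request_states dict keyed by request ID (the function and dict names are hypothetical, not part of the diff):

def apply_scheduler_output(scheduler_output, request_states: dict[str, object]) -> None:
    # Preempted requests will be re-sent later as new requests, so their
    # runner-side state can be discarded now.
    for req_id in scheduler_output.preempted_req_ids or ():
        request_states.pop(req_id, None)
    # Finished requests are removed as well.
    for req_id in scheduler_output.finished_req_ids:
        request_states.pop(req_id, None)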
@@ -6,6 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Any

from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
    ECConnectorMetadata,
@@ -187,6 +188,7 @@ class Scheduler(SchedulerInterface):
            pcp_world_size=self.pcp_world_size,
        )
        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

    def schedule(self) -> SchedulerOutput:
        # NOTE(woosuk) on the scheduling algorithm:
@@ -658,12 +660,25 @@ class Scheduler(SchedulerInterface):
            )

        # Construct the scheduler output.
        new_reqs_data = [
            NewRequestData.from_request(
                req, req_to_new_blocks[req.request_id].get_block_ids()
            )
            for req in scheduled_new_reqs
        ]
        if self.use_v2_model_runner:
            scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
            scheduled_resumed_reqs = []
            new_reqs_data = [
                NewRequestData.from_request(
                    req,
                    req_to_new_blocks[req.request_id].get_block_ids(),
                    req._all_token_ids,
                )
                for req in scheduled_new_reqs
            ]
        else:
            new_reqs_data = [
                NewRequestData.from_request(
                    req, req_to_new_blocks[req.request_id].get_block_ids()
                )
                for req in scheduled_new_reqs
            ]

        with record_function_or_nullcontext("schedule: make_cached_request_data"):
            cached_reqs_data = self._make_cached_request_data(
                scheduled_running_reqs,
@@ -685,6 +700,7 @@ class Scheduler(SchedulerInterface):
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            preempted_req_ids={req.request_id for req in preempted_reqs},
            # finished_req_ids is an existing state in the scheduler,
            # instead of being newly scheduled in this step.
            # It contains the request IDs that are finished in between
vllm/v1/worker/gpu/README.md (new file, 4 lines)
@@ -0,0 +1,4 @@
# [Experimental] Model Runner V2

This directory contains the new model runner, which is under active development.
Please ping [Woosuk Kwon](https://github.com/WoosukKwon) before making any changes.
vllm/v1/worker/gpu/__init__.py (new empty file)
vllm/v1/worker/gpu/async_utils.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager

import numpy as np
import torch

from vllm.v1.outputs import (
    AsyncModelRunnerOutput,
    ModelRunnerOutput,
    SamplerOutput,
)


class AsyncOutput(AsyncModelRunnerOutput):
    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampler_output: SamplerOutput,
        num_sampled_tokens: np.ndarray,
        copy_stream: torch.cuda.Stream,
        copy_event: torch.cuda.Event,
    ):
        self.model_runner_output = model_runner_output
        self.sampler_output = sampler_output
        self.num_sampled_tokens = num_sampled_tokens
        self.copy_stream = copy_stream
        self.copy_event = copy_event

        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(self.copy_stream):
            self.copy_stream.wait_stream(default_stream)

            # NOTE(woosuk): We must ensure that CPU tensors are not freed
            # before the device-to-host copy is fully completed. For instance,
            # operations like
            # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy()
            # are unsafe because the underlying CPU tensor can be prematurely freed and
            # reused by other tensors before the asynchronous copy finishes, potentially
            # causing race conditions. To prevent this, we delay freeing by holding
            # references until the copy event signals completion.
            # Likewise, we also need to keep the reference to the GPU tensors.
            # This is done by keeping the reference to sampler_output and
            # model_runner_output.
            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            if sampler_output.logprobs_tensors is not None:
                self.logprobs_tensors = (
                    sampler_output.logprobs_tensors.to_cpu_nonblocking()
                )
            else:
                self.logprobs_tensors = None
            self.prompt_logprobs_dict = {}
            if self.model_runner_output.prompt_logprobs_dict:
                for k, v in self.model_runner_output.prompt_logprobs_dict.items():
                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
            self.copy_event.record(self.copy_stream)

    def get_output(self) -> ModelRunnerOutput:
        self.copy_event.synchronize()

        # NOTE(woosuk): The following code is to ensure compatibility with
        # the existing model runner.
        # Going forward, we should keep the data structures as NumPy arrays
        # rather than Python lists.
        sampled_token_ids_np = self.sampled_token_ids.numpy()
        num_reqs = sampled_token_ids_np.shape[0]
        sampled_token_ids: list[np.ndarray] = [
            sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
            for i in range(num_reqs)
        ]
        self.model_runner_output.sampled_token_ids = sampled_token_ids

        if self.logprobs_tensors is not None:
            self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
        self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
        return self.model_runner_output


@contextmanager
def async_barrier(event: torch.cuda.Event | None):
    if event is not None:
        event.synchronize()
    try:
        yield
    finally:
        if event is not None:
            event.record()
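The NOTE above explains why the CPU-side tensors must stay referenced until the copy event fires. A stripped-down illustration of that pattern in plain PyTorch, outside vLLM (all names here are hypothetical and only demonstrate the stream/event discipline):

import torch

copy_stream = torch.cuda.Stream()
copy_event = torch.cuda.Event()

gpu_tensor = torch.randn(4, 8, device="cuda")

with torch.cuda.stream(copy_stream):
    copy_stream.wait_stream(torch.cuda.current_stream())
    # Keep a reference to the CPU tensor; do NOT convert to .numpy() yet,
    # otherwise the Python-side tensor could be garbage collected and its
    # storage reused before the asynchronous copy finishes.
    cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
    copy_event.record(copy_stream)

# ... later, when the result is actually needed:
copy_event.synchronize()
result = cpu_tensor.numpy()  # safe: the copy completed and the reference was held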
vllm/v1/worker/gpu/attn_utils.py (new file, 187 lines)
@@ -0,0 +1,187 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
)
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
# Skip modules that don't need KV cache (e.g., encoder-only attention)
|
||||
if spec := attn_module.get_kv_cache_spec(vllm_config):
|
||||
kv_cache_spec[layer_name] = spec
|
||||
return kv_cache_spec
|
||||
|
||||
|
||||
def init_attn_backend(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
attn_backends: dict[str, AttentionBackend] = {}
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
flashinfer_workspace: torch.Tensor | None = None
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
layer_names = kv_cache_group_spec.layer_names
|
||||
any_layer_name = next(iter(layer_names))
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(
|
||||
vllm_config, AttentionLayerBase, layer_names
|
||||
)
|
||||
attn_backend = attn_layers[any_layer_name].get_attn_backend()
|
||||
for layer_name in layer_names:
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
attn_metadata_builder = attn_backend.get_builder_cls()(
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
)
|
||||
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
||||
|
||||
if "FLASHINFER" in attn_backend.get_name():
|
||||
if flashinfer_workspace is None:
|
||||
flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
|
||||
else:
|
||||
attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
|
||||
return attn_backends, attn_metadata_builders
|
||||
|
||||
|
||||
def _allocate_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
||||
"Some layers are not correctly initialized"
|
||||
)
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
|
||||
def _reshape_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
attn_backend = attn_backends[layer_name]
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
# FIXME(woosuk): Add kv_cache_stride_order to all attention backends.
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
||||
|
||||
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
|
||||
dtype = kv_cache_spec.dtype
|
||||
raw_tensor = raw_tensor.view(dtype)
|
||||
raw_tensor = raw_tensor.view(kv_cache_shape)
|
||||
kv_caches[layer_name] = raw_tensor.permute(*inv_order)
|
||||
return kv_caches
|
||||
|
||||
|
||||
def init_kv_cache(
|
||||
runner_kv_caches: list[torch.Tensor],
|
||||
forward_context: dict[str, Any],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_backends: dict[str, AttentionBackend],
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
|
||||
kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors, attn_backends)
|
||||
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
|
||||
|
||||
|
||||
def build_attn_metadata(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_computed_tokens_cpu: torch.Tensor,
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
|
||||
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
|
||||
seq_lens_gpu = seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = seq_lens.cpu[:num_reqs]
|
||||
max_seq_len = int(seq_lens.np[:num_reqs].max())
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for i, kv_cache_spec in enumerate(kv_cache_groups):
|
||||
block_table = block_tables[i]
|
||||
slot_mapping = slot_mappings[i]
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens_gpu,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
max_seq_len=max_seq_len,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
block_table_tensor=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
return attn_metadata
|
||||
vllm/v1/worker/gpu/block_table.py (new file, 315 lines)
@@ -0,0 +1,315 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class BlockTables:
|
||||
def __init__(
|
||||
self,
|
||||
block_sizes: list[int],
|
||||
max_num_reqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.block_sizes = block_sizes
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.num_kv_cache_groups = len(self.block_sizes)
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.block_tables: list[torch.Tensor] = []
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
block_size = self.block_sizes[i]
|
||||
max_num_blocks = cdiv(self.max_model_len, block_size)
|
||||
block_table = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
max_num_blocks,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.block_tables.append(block_table)
|
||||
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
|
||||
|
||||
# Block tables used for model's forward pass.
|
||||
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
|
||||
self.input_block_tables: list[torch.Tensor] = [
|
||||
torch.zeros_like(block_table) for block_table in self.block_tables
|
||||
]
|
||||
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
|
||||
|
||||
self.block_table_strides = torch.tensor(
|
||||
[b.stride(0) for b in self.block_tables],
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.block_sizes_tensor = torch.tensor(
|
||||
self.block_sizes, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self.num_blocks = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.slot_mappings = torch.zeros(
|
||||
self.num_kv_cache_groups,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Misc buffers.
|
||||
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
|
||||
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
|
||||
self.cu_num_new_blocks = self._make_buffer(
|
||||
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
|
||||
)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
|
||||
ptrs_tensor_cpu = torch.tensor(
|
||||
[t.data_ptr() for t in x],
|
||||
dtype=torch.uint64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
def append_block_ids(
|
||||
self,
|
||||
# [num_reqs]
|
||||
req_indices: list[int],
|
||||
# [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks: tuple[list[int], ...],
|
||||
# [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids: tuple[list[int], ...],
|
||||
# [num_reqs]
|
||||
overwrite: list[bool],
|
||||
) -> None:
|
||||
num_reqs = len(req_indices)
|
||||
self.req_indices.np[:num_reqs] = req_indices
|
||||
self.overwrite.np[:num_reqs] = overwrite
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
|
||||
|
||||
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
|
||||
# no clear upper bound to the number of new blocks in a single step.
|
||||
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
|
||||
# guarantee that the buffer is not freed before the copy is completed.
|
||||
self.new_block_ids_cpu = torch.empty(
|
||||
self.num_kv_cache_groups,
|
||||
max(len(x) for x in new_block_ids),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
new_block_ids_np = self.new_block_ids_cpu.numpy()
|
||||
for i in range(self.num_kv_cache_groups):
|
||||
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i]
|
||||
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True)
|
||||
|
||||
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
self.req_indices.copy_to_gpu(num_reqs),
|
||||
self.cu_num_new_blocks.copy_to_gpu(),
|
||||
self.cu_num_new_blocks.gpu.stride(0),
|
||||
new_block_ids_gpu,
|
||||
new_block_ids_gpu.stride(0),
|
||||
self.overwrite.copy_to_gpu(num_reqs),
|
||||
self.block_table_strides,
|
||||
self.block_table_ptrs,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
|
||||
def gather_block_tables(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
idx_mapping,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks,
|
||||
self.num_blocks.stride(0),
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
|
||||
def compute_slot_mappings(
|
||||
self,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = query_start_loc.shape[0] - 1
|
||||
num_tokens = positions.shape[0]
|
||||
num_groups = self.num_kv_cache_groups
|
||||
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
|
||||
num_tokens,
|
||||
self.max_num_batched_tokens,
|
||||
query_start_loc,
|
||||
positions,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.block_sizes_tensor,
|
||||
self.slot_mappings,
|
||||
self.slot_mappings.stride(0),
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
|
||||
self.slot_mappings.fill_(PAD_SLOT_ID)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _append_block_ids_kernel(
|
||||
# Inputs
|
||||
req_indices, # [num_reqs]
|
||||
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
|
||||
cu_num_new_blocks_stride,
|
||||
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
|
||||
new_block_ids_stride,
|
||||
overwrite, # [num_reqs]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
# Outputs
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
# Constants
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(req_indices + batch_idx)
|
||||
do_overwrite = tl.load(overwrite + batch_idx)
|
||||
|
||||
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
|
||||
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
|
||||
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
|
||||
num_new_blocks = end_idx - start_idx
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
|
||||
dst_end_idx = dst_start_idx + num_new_blocks
|
||||
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
|
||||
|
||||
# Destination
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
row_ptr = block_table_ptr + req_idx * block_table_stride
|
||||
|
||||
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
|
||||
for i in range(0, num_new_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(
|
||||
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
|
||||
)
|
||||
tl.store(
|
||||
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gather_block_tables_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_block_table_ptrs, # [num_kv_cache_groups]
|
||||
dst_block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
|
||||
src_row_ptr = src_block_table_ptr + req_idx * stride
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
for i in tl.range(0, num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
|
||||
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_slot_mappings_kernel(
|
||||
num_tokens,
|
||||
max_num_tokens,
|
||||
cu_num_tokens, # [num_reqs + 1]
|
||||
pos, # [num_tokens]
|
||||
block_table_ptrs, # [num_kv_cache_groups]
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
page_sizes, # [num_kv_cache_groups]
|
||||
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
|
||||
slot_mappings_stride,
|
||||
PAD_ID: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
req_idx = tl.program_id(1)
|
||||
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
|
||||
|
||||
if req_idx == tl.num_programs(1) - 1:
|
||||
# Pad remaining slots to -1. This is needed for CUDA graphs.
|
||||
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
|
||||
return
|
||||
|
||||
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
|
||||
block_table_stride = tl.load(block_table_strides + group_id)
|
||||
page_size = tl.load(page_sizes + group_id)
|
||||
|
||||
start_idx = tl.load(cu_num_tokens + req_idx)
|
||||
end_idx = tl.load(cu_num_tokens + req_idx + 1)
|
||||
for i in range(start_idx, end_idx, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
|
||||
block_indices = positions // page_size
|
||||
block_numbers = tl.load(
|
||||
block_table_ptr + req_idx * block_table_stride + block_indices
|
||||
)
|
||||
slot_ids = block_numbers * page_size + positions % page_size
|
||||
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _load_ptr(ptr_to_ptr, elem_dtype):
|
||||
ptr = tl.load(ptr_to_ptr)
|
||||
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
|
||||
return tl.multiple_of(ptr, 16)
|
||||
vllm/v1/worker/gpu/cudagraph_utils.py (new file, 198 lines)
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
|
||||
self.padded_sizes = self._init_padded_sizes()
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
|
||||
def _init_padded_sizes(self) -> dict[int, int]:
|
||||
if not self.cudagraph_mode.has_full_cudagraphs():
|
||||
# Full cuda graphs are not used.
|
||||
return {}
|
||||
|
||||
padded_sizes: dict[int, int] = {}
|
||||
assert len(self.cudagraph_sizes) > 0
|
||||
for i in range(1, self.cudagraph_sizes[-1] + 1):
|
||||
for x in self.cudagraph_sizes:
|
||||
if i <= x:
|
||||
padded_sizes[i] = x
|
||||
break
|
||||
return padded_sizes
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
return len(self.padded_sizes) > 0
|
||||
|
||||
def get_cudagraph_size(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
num_tokens_after_padding: int,
|
||||
) -> int | None:
|
||||
if not self.cudagraph_mode.has_full_cudagraphs():
|
||||
return None
|
||||
if self.cudagraph_mode != CUDAGraphMode.FULL:
|
||||
# TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
|
||||
all_decode = all(
|
||||
x == 1 for x in scheduler_output.num_scheduled_tokens.values()
|
||||
)
|
||||
if not all_decode:
|
||||
# Prefill is included.
|
||||
return None
|
||||
return self.padded_sizes.get(num_tokens_after_padding)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
batch_size: int,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert batch_size not in self.graphs
|
||||
|
||||
# Prepare dummy inputs.
|
||||
input_ids = input_buffers.input_ids.gpu[:batch_size]
|
||||
positions = input_buffers.positions.gpu[:batch_size]
|
||||
|
||||
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
|
||||
input_buffers.query_start_loc.np[batch_size:] = batch_size
|
||||
input_buffers.query_start_loc.copy_to_gpu()
|
||||
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
|
||||
input_buffers.seq_lens.np[batch_size:] = 0
|
||||
input_buffers.seq_lens.copy_to_gpu()
|
||||
|
||||
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :batch_size]
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
num_reqs=batch_size,
|
||||
num_tokens=batch_size,
|
||||
query_start_loc=input_buffers.query_start_loc,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=None, # FIXME
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
if self.dp_size > 1:
|
||||
num_tokens_across_dp = torch.full(
|
||||
(self.dp_size,),
|
||||
batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
num_tokens_across_dp = None
|
||||
|
||||
# Warm up.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=batch_size,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
)
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture the graph.
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=batch_size,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
)
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
self.graphs[batch_size] = graph
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert self.needs_capture()
|
||||
# Capture larger graphs first.
|
||||
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
|
||||
if is_global_first_rank():
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
|
||||
|
||||
with freeze_gc(), graph_capture(device=self.device):
|
||||
for batch_size in sizes_to_capture:
|
||||
self.capture_graph(
|
||||
batch_size,
|
||||
model,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_metadata_builders,
|
||||
kv_cache_config,
|
||||
)
|
||||
|
||||
def run(self, batch_size: int) -> torch.Tensor:
|
||||
assert batch_size in self.graphs
|
||||
self.graphs[batch_size].replay()
|
||||
assert self.hidden_states is not None
|
||||
return self.hidden_states[:batch_size]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_gc():
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
gc.unfreeze()
|
||||
vllm/v1/worker/gpu/dp_utils.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist

from vllm.distributed.parallel_state import get_dp_group


def get_batch_metadata_across_dp(
    num_tokens: int,
    cudagraph_size: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert dp_size > 1
    # Use CPU group to avoid CPU-GPU synchronization.
    group = get_dp_group().cpu_group
    tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
    tensor[0][dp_rank] = num_tokens
    tensor[1][dp_rank] = cudagraph_size
    dist.all_reduce(tensor, group=group)
    return tensor[0], tensor[1]
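For context, this helper gathers every data-parallel rank's token count and chosen CUDA graph size with a single CPU all_reduce, so ranks can agree on a common batch shape. A hedged sketch of how a caller might consume it; the variable names and the max-across-ranks policy are illustrative assumptions, not the actual call site:

# Hypothetical caller inside a model runner step.
num_tokens_across_dp, cudagraph_sizes = get_batch_metadata_across_dp(
    num_tokens=num_scheduled_tokens,  # this rank's token count (assumed variable)
    cudagraph_size=padded_size or 0,  # 0 if no full CUDA graph applies (assumed)
    dp_size=dp_size,
    dp_rank=dp_rank,
)
# One possible policy: all ranks run the same graph shape, so take the max.
common_size = int(cudagraph_sizes.max().item())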
vllm/v1/worker/gpu/input_batch.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numba
|
||||
import numba.types as types
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class InputBuffers:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
|
||||
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
# Structured outputs.
|
||||
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
self.grammar_bitmask = self._make_buffer(
|
||||
max_num_reqs, cdiv(vocab_size, 32), dtype=torch.int32
|
||||
)
|
||||
|
||||
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(
|
||||
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
num_reqs: int
|
||||
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
|
||||
# [num_reqs]
|
||||
# batch_idx -> num_scheduled_tokens
|
||||
num_scheduled_tokens: np.ndarray
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
num_tokens_after_padding: int
|
||||
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
seq_lens_np: np.ndarray
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
|
||||
# [num_reqs]
|
||||
logits_indices: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
|
||||
input_buffers.query_start_loc.np[0] = 0
|
||||
input_buffers.query_start_loc.np[1 : num_reqs + 1] = np.cumsum(
|
||||
num_scheduled_tokens
|
||||
)
|
||||
input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
|
||||
# seq_len equals query_len
|
||||
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
|
||||
input_buffers.seq_lens.np[num_reqs:] = 0
|
||||
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
|
||||
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
|
||||
|
||||
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
|
||||
positions = input_buffers.positions.copy_to_gpu(num_tokens)
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_np=seq_lens_np,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: With the type annotations, this function is pre-compiled
|
||||
# before the first call.
|
||||
@numba.jit(
|
||||
[
|
||||
types.none(
|
||||
types.int32[:], # idx_mapping
|
||||
types.int32[:, :], # token_ids
|
||||
types.int32[:], # num_computed_tokens
|
||||
types.int32[:], # num_scheduled_tokens
|
||||
types.int32[:], # input_ids
|
||||
types.int64[:], # positions
|
||||
types.int32[:], # query_start_loc
|
||||
types.int32[:], # seq_lens
|
||||
)
|
||||
],
|
||||
nopython=True,
|
||||
cache=True,
|
||||
)
|
||||
def _prepare_inputs(
|
||||
idx_mapping: np.ndarray, # batch_idx -> req_idx
|
||||
token_ids: np.ndarray, # [N, max_model_len]
|
||||
num_computed_tokens: np.ndarray, # [N]
|
||||
num_scheduled_tokens: np.ndarray, # [B]
|
||||
input_ids: np.ndarray, # [num_input_tokens]
|
||||
positions: np.ndarray, # [num_input_tokens]
|
||||
query_start_loc: np.ndarray, # [B + 1]
|
||||
seq_lens: np.ndarray, # [B]
|
||||
) -> None:
|
||||
num_reqs = num_scheduled_tokens.shape[0]
|
||||
query_start_loc[0] = 0
|
||||
|
||||
cu_num_tokens = 0
|
||||
for i in range(num_reqs):
|
||||
req_idx = idx_mapping[i]
|
||||
query_len = num_scheduled_tokens[i]
|
||||
start = num_computed_tokens[req_idx]
|
||||
end = start + query_len
|
||||
seq_lens[i] = end
|
||||
|
||||
start_idx = cu_num_tokens
|
||||
end_idx = start_idx + query_len
|
||||
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
|
||||
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
|
||||
|
||||
cu_num_tokens = end_idx
|
||||
query_start_loc[i + 1] = cu_num_tokens
|
||||
|
||||
# Pad the inputs for CUDA graphs.
|
||||
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require that.
|
||||
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
|
||||
# Fill unused with 0 for full cuda graph mode.
|
||||
seq_lens[num_reqs:].fill(0)
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
idx_mapping: np.ndarray,
|
||||
prefill_token_ids: np.ndarray,
|
||||
num_computed_tokens: np.ndarray,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
input_ids: CpuGpuBuffer,
|
||||
positions: CpuGpuBuffer,
|
||||
query_start_loc: CpuGpuBuffer,
|
||||
seq_lens: CpuGpuBuffer,
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
_prepare_inputs(
|
||||
idx_mapping,
|
||||
prefill_token_ids,
|
||||
num_computed_tokens,
|
||||
num_scheduled_tokens,
|
||||
input_ids.np,
|
||||
positions.np,
|
||||
query_start_loc.np,
|
||||
seq_lens.np,
|
||||
)
|
||||
input_ids.copy_to_gpu(num_tokens)
|
||||
positions.copy_to_gpu(num_tokens)
|
||||
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
|
||||
# tensors from CPU to GPU, because they may include paddings needed
|
||||
# for full CUDA graph mode.
|
||||
query_start_loc.copy_to_gpu()
|
||||
seq_lens.copy_to_gpu()
|
||||
return
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _combine_last_token_ids_kernel(
|
||||
input_ids_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_token_ids_ptr,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
prefill_len_ptr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if seq_len <= prefill_len:
|
||||
# Handling prefill tokens.
|
||||
return
|
||||
|
||||
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
|
||||
end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
tl.store(input_ids_ptr + end - 1, last_token_id)
|
||||
|
||||
|
||||
def combine_last_token_ids(
|
||||
input_ids: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
last_token_ids: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
_combine_last_token_ids_kernel[(num_reqs,)](
|
||||
input_ids,
|
||||
idx_mapping,
|
||||
last_token_ids,
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
prefill_len,
|
||||
)
|
||||
return input_ids
|
||||
vllm/v1/worker/gpu/model_runner.py (new file, 814 lines)
@@ -0,0 +1,814 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.sample.sampler import SamplerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
init_kv_cache,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
combine_last_token_ids,
|
||||
prepare_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.load_config = vllm_config.load_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.device = device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self.kv_cache_dtype = self.dtype
|
||||
if self.cache_config.cache_dtype != "auto":
|
||||
# Quantized KV cache.
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype
|
||||
]
|
||||
self.is_pooling_model = False
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.hidden_size = self.model_config.get_hidden_size()
|
||||
|
||||
self.dp_size = self.parallel_config.data_parallel_size
|
||||
self.dp_rank = self.parallel_config.data_parallel_rank
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
self.output_copy_event = torch.cuda.Event()
|
||||
if self.use_async_scheduling:
|
||||
self.input_prep_event = torch.cuda.Event()
|
||||
self.structured_outputs_event = torch.cuda.Event()
|
||||
else:
|
||||
self.input_prep_event = None
|
||||
self.structured_outputs_event = None
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate",)
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
with DeviceMemoryProfiler() as m:
|
||||
model_loader = get_model_loader(self.vllm_config.load_config)
|
||||
logger.info("Loading model from scratch...")
|
||||
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.model_config,
|
||||
)
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(
|
||||
self.model,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
time_after_load = time.perf_counter()
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info(
|
||||
"Model loading took %.4f GiB and %.6f seconds",
|
||||
m.consumed_memory / GiB_bytes,
|
||||
time_after_load - time_before_load,
|
||||
)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_kv_cache_spec(self):
|
||||
return get_kv_cache_spec(self.vllm_config)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
|
||||
self.block_tables = BlockTables(
|
||||
block_sizes=block_sizes,
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
|
||||
self.kv_cache_config,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
|
||||
self.kv_caches: list[torch.Tensor] = []
|
||||
init_kv_cache(
|
||||
self.kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_cache_config,
|
||||
self.attn_backends,
|
||||
self.device,
|
||||
)
|
||||
# Attention groups are not supported.
|
||||
self.attn_groups = [] # type: ignore
|
||||
|
||||
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
|
||||
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
num_computed_tokens_cpu = torch.zeros(
|
||||
input_batch.num_reqs, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=input_batch.num_reqs,
|
||||
num_tokens=input_batch.num_tokens,
|
||||
query_start_loc=self.input_buffers.query_start_loc,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
input_batch.attn_metadata = attn_metadata
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
skip_attn: bool = True,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
if not skip_attn:
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
|
||||
if self.dp_size == 1:
|
||||
num_tokens_across_dp: torch.Tensor | None = None
|
||||
else:
|
||||
num_tokens_across_dp = torch.full(
|
||||
(self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
|
||||
with (
|
||||
self.maybe_dummy_run_with_lora(
|
||||
self.lora_config,
|
||||
input_batch.num_scheduled_tokens,
|
||||
num_sampled_tokens,
|
||||
),
|
||||
set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
),
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
|
||||
    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> None:
        num_reqs = hidden_states.shape[0]
        sampling_metadata = SamplingMetadata.make_dummy(
            num_reqs=num_reqs,
            device=self.device,
        )
        logits = self.model.compute_logits(hidden_states)
        self.sampler(logits, sampling_metadata)

    @torch.inference_mode()
    def profile_run(self) -> None:
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens,
            skip_attn=True,
        )
        self._dummy_sampler_run(sample_hidden_states)
        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        pass

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        # SP is not supported yet.
        return num_scheduled_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            self.cudagraph_manager.capture(
                model=self.model,
                input_buffers=self.input_buffers,
                block_tables=self.block_tables,
                attn_metadata_builders=self.attn_metadata_builders,
                kv_cache_config=self.kv_cache_config,
            )

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        # For FlashInfer, we would like to execute a dummy prefill run
        # to trigger JIT compilation.
        if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
            self._dummy_run(self.max_num_tokens, skip_attn=False)
            torch.cuda.synchronize()

    def update_states(self, scheduler_output: SchedulerOutput) -> None:
        for req_id in scheduler_output.preempted_req_ids:
            self.req_states.remove_request(req_id)
        for req_id in scheduler_output.finished_req_ids:
            self.req_states.remove_request(req_id)

        # TODO(woosuk): Change SchedulerOutput.
        req_indices: list[int] = []
        cu_num_new_blocks = tuple(
            [0] for _ in range(self.block_tables.num_kv_cache_groups)
        )
        new_block_ids: tuple[list[int], ...] = tuple(
            [] for _ in range(self.block_tables.num_kv_cache_groups)
        )
        overwrite: list[bool] = []

        # Add new requests.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=len(new_req_data.prompt_token_ids),
                prefill_token_ids=new_req_data.prefill_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                sampling_params=new_req_data.sampling_params,
                lora_request=new_req_data.lora_request,
            )

            req_index = self.req_states.req_id_to_index[req_id]
            req_indices.append(req_index)
            for i, block_ids in enumerate(new_req_data.block_ids):
                x = cu_num_new_blocks[i][-1]
                cu_num_new_blocks[i].append(x + len(block_ids))
                new_block_ids[i].extend(block_ids)
            overwrite.append(True)

        # Add new blocks for the existing requests.
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            req_index = self.req_states.req_id_to_index[req_id]

            req_new_block_ids = cached_reqs.new_block_ids[i]
            if req_new_block_ids is not None:
                req_indices.append(req_index)
                for group_id, block_ids in enumerate(req_new_block_ids):
                    x = cu_num_new_blocks[group_id][-1]
                    cu_num_new_blocks[group_id].append(x + len(block_ids))
                    new_block_ids[group_id].extend(block_ids)
                overwrite.append(False)

        if req_indices:
            self.block_tables.append_block_ids(
                req_indices=req_indices,
                cu_num_new_blocks=cu_num_new_blocks,
                new_block_ids=new_block_ids,
                overwrite=overwrite,
            )

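    # A small illustrative example (hypothetical numbers, not part of the runner) of
    # the bookkeeping in update_states() above: for each KV cache group we keep a
    # prefix sum of how many new blocks each touched request contributes, plus a flat
    # list of the new block ids, so BlockTables.append_block_ids() can slice them
    # back out per request.
    #
    #   cu_num_new_blocks = ([0],)            # one KV cache group
    #   new_block_ids = ([],)
    #   for block_ids in ([3, 7], [9]):       # req A gets 2 new blocks, req B gets 1
    #       cu_num_new_blocks[0].append(cu_num_new_blocks[0][-1] + len(block_ids))
    #       new_block_ids[0].extend(block_ids)
    #   # cu_num_new_blocks == ([0, 2, 3],)   # req A's blocks: new_block_ids[0][0:2]
    #   # new_block_ids == ([3, 7, 9],)       # req B's blocks: new_block_ids[0][2:3]
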
    def prepare_inputs(
        self,
        scheduler_output: SchedulerOutput,
        num_tokens_after_padding: int,
    ) -> InputBatch:
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
        num_reqs = len(scheduler_output.num_scheduled_tokens)

        # Decode first, then prefill.
        # batch_idx -> req_id
        req_ids = sorted(
            scheduler_output.num_scheduled_tokens,
            key=scheduler_output.num_scheduled_tokens.get,
        )
        num_scheduled_tokens = np.array(
            [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
        )

        idx_mapping_list = [
            self.req_states.req_id_to_index[req_id] for req_id in req_ids
        ]
        idx_mapping = self.input_buffers.idx_mapping
        idx_mapping.np[:num_reqs] = idx_mapping_list
        idx_mapping_np = idx_mapping.np[:num_reqs]
        idx_mapping = idx_mapping.copy_to_gpu(num_reqs)

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

        prepare_inputs(
            idx_mapping_np,
            self.req_states.prefill_token_ids,
            self.req_states.num_computed_tokens,
            num_scheduled_tokens,
            self.input_buffers.input_ids,
            self.input_buffers.positions,
            self.input_buffers.query_start_loc,
            self.input_buffers.seq_lens,
            num_tokens,
        )

        query_start_loc = self.input_buffers.query_start_loc
        query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
        query_start_loc_np = query_start_loc.np[: num_reqs + 1]
        seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
        seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]

        # Some input token ids are directly read from the last sampled tokens.
        combine_last_token_ids(
            self.input_buffers.input_ids.gpu,
            idx_mapping,
            self.req_states.last_sampled_tokens,
            query_start_loc_gpu,
            seq_lens_gpu,
            self.req_states.prefill_len.copy_to_gpu(),
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
        )

        num_computed_tokens_cpu = torch.from_numpy(
            self.req_states.num_computed_tokens[idx_mapping_np]
        )

        # Logits indices to sample next token from.
        logits_indices = query_start_loc_gpu[1:] - 1

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=self.attn_metadata_builders,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc=self.input_buffers.query_start_loc,
            seq_lens=self.input_buffers.seq_lens,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
        )

        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
        positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
            query_start_loc=query_start_loc_gpu,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens_gpu,
            seq_lens_np=seq_lens_np,
            input_ids=input_ids,
            positions=positions,
            attn_metadata=attn_metadata,
            logits_indices=logits_indices,
        )

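    # A toy example (hypothetical values) of the batch layout produced by
    # prepare_inputs() above: with three requests scheduled for 1, 1, and 4 tokens
    # (decodes sorted first), query_start_loc is the exclusive prefix sum of the
    # per-request token counts and logits_indices picks the last token of each
    # request's slice.
    #
    #   num_scheduled_tokens = [1, 1, 4]
    #   query_start_loc      = [0, 1, 2, 6]
    #   logits_indices       = query_start_loc[1:] - 1  ->  [0, 1, 5]
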
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        grammar_output: GrammarOutput | None,
    ) -> SamplerOutput:
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        logits = self.model.compute_logits(sample_hidden_states)
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            with async_barrier(self.structured_outputs_event):
                apply_grammar_bitmask(
                    logits,
                    input_batch.req_ids,
                    grammar_output.structured_output_request_ids,
                    grammar_output.grammar_bitmask,
                    self.input_buffers,
                )
        sampler_output = self.sampler(logits, sampling_metadata)
        return sampler_output

    def compute_prompt_logprobs(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
    ) -> dict[str, LogprobsTensors]:
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # No request asks for prompt logprobs.
            return {}

        num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
        prompt_lens = self.req_states.prompt_len[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
        includes_prompt = num_computed_tokens < prompt_lens - 1
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have been computed before preemption. Skip.
        resumed_after_prompt = (
            prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
        )
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}

        # Just to be safe, clone the input ids.
        n = input_batch.num_tokens
        # Shift the input ids by one.
        token_ids = torch.empty_like(input_batch.input_ids[:n])
        token_ids[: n - 1] = input_batch.input_ids[1:n]
        # To avoid out-of-bound access, set the last token id to 0.
        token_ids[n - 1] = 0

        # Handle chunked prompts.
        seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
        is_prompt_chunked = seq_lens < prompt_lens
        prefill_token_ids = self.req_states.prefill_token_ids
        query_start_loc = self.input_buffers.query_start_loc.np
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue
            if not is_prompt_chunked[i]:
                continue
            # The prompt is chunked. Get the next prompt token.
            req_idx = input_batch.idx_mapping_np[i]
            next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
            idx = int(query_start_loc[i + 1] - 1)
            # Set the next prompt token.
            # NOTE(woosuk): This triggers a GPU operation.
            token_ids[idx] = next_prompt_token

        # NOTE(woosuk): We mask out logprobs for negative tokens.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
            token_ids,
            hidden_states[:n],
            self.model.compute_logits,
        )

        prompt_token_ids = token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            start_idx = query_start_loc[i]
            end_idx = query_start_loc[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )

            req_extra_data = self.req_states.extra_data[req_id]
            prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
            if is_prompt_chunked[i]:
                # Prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue

            if prompt_logprobs_list:
                # Merge the in-progress logprobs.
                prompt_logprobs_list.append(logprobs)
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [x.logprob_token_ids for x in prompt_logprobs_list]
                    ),
                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
                    selected_token_ranks=torch.cat(
                        [x.selected_token_ranks for x in prompt_logprobs_list]
                    ),
                )
                prompt_logprobs_list.clear()

            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict

    def postprocess(
        self,
        sampler_output: SamplerOutput,
        prompt_logprobs_dict: dict[str, LogprobsTensors],
        input_batch: InputBatch,
    ) -> AsyncOutput | ModelRunnerOutput:
        # Store the last sampled token ids.
        self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
            sampler_output.sampled_token_ids
        )
        # Get the number of sampled tokens.
        # 0 if chunked-prefilling, 1 if not.
        idx_mapping_np = input_batch.idx_mapping_np
        is_chunked_prefilling = (
            input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
        )
        num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
        # Increment the number of tokens.
        self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
        # Increment the number of computed tokens.
        self.req_states.num_computed_tokens[idx_mapping_np] += (
            input_batch.num_scheduled_tokens
        )

        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            sampled_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            kv_connector_output=None,
            num_nans_in_logits=None,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled_tokens,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )
        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def get_cudagraph_and_dp_padding(
        self,
        scheduler_output: SchedulerOutput,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.dp_size == 1:
            # No DP. Only consider CUDA graphs.
            if total_num_scheduled_tokens == 0:
                # Special case: no tokens to run.
                return CUDAGraphMode.NONE, 0, None

            cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
                scheduler_output, total_num_scheduled_tokens
            )
            if cudagraph_size is not None:
                # Use full CUDA graph.
                return CUDAGraphMode.FULL, cudagraph_size, None
            # Fall back to eager mode.
            # TODO(woosuk): Support piecewise CUDA graphs.
            return CUDAGraphMode.NONE, total_num_scheduled_tokens, None

        # Consider DP padding and CUDA graph.
        if total_num_scheduled_tokens == 0:
            # Special handling is needed for 0.
            cudagraph_size_before_dp: int | None = 0
        else:
            cudagraph_size_before_dp = self.cudagraph_manager.get_cudagraph_size(
                scheduler_output, total_num_scheduled_tokens
            )
            if cudagraph_size_before_dp is None:
                cudagraph_size_before_dp = -1

        assert cudagraph_size_before_dp is not None
        num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
            total_num_scheduled_tokens,
            cudagraph_size_before_dp,
            self.dp_size,
            self.dp_rank,
        )
        if all(cudagraph_size_across_dp >= 0):
            # If all ranks can use CUDA graph, pad to the maximum number of tokens
            # across DP and use CUDA graph.
            num_tokens_after_padding = int(cudagraph_size_across_dp.max().item())
            cudagraph_mode = CUDAGraphMode.FULL
        else:
            # If any of the ranks cannot use CUDA graph, use eager mode for all ranks.
            # No padding is needed except for ranks that have no tokens to run.
            num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
            num_tokens_after_padding = num_tokens_across_dp[self.dp_rank]
            cudagraph_mode = CUDAGraphMode.NONE
        return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp

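    # Example of the DP decision above (hypothetical sizes, dp_size == 2):
    #   cudagraph_size_across_dp == [256, 512] -> every rank has a captured graph,
    #       so both pad to 512 tokens and run with CUDAGraphMode.FULL;
    #   cudagraph_size_across_dp == [-1, 512]  -> rank 0 cannot use a graph, so every
    #       rank falls back to eager mode and only zero-token ranks are padded to 1.
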
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: Any | None = None,
        dummy_run: bool = False,
    ) -> ModelRunnerOutput | None:
        assert intermediate_tensors is None
        if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run:
            # No need to run the model.
            with async_barrier(self.input_prep_event):
                self.update_states(scheduler_output)
                return EMPTY_MODEL_RUNNER_OUTPUT

        # NOTE: Call this before the async barrier so CPU all-reduce and
        # GPU execution can overlap.
        cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
            self.get_cudagraph_and_dp_padding(scheduler_output)
        )
        with async_barrier(self.input_prep_event):
            self.update_states(scheduler_output)
            if num_tokens_after_padding == 0:
                # All DP ranks have zero tokens to run.
                return EMPTY_MODEL_RUNNER_OUTPUT

            if not dummy_run:
                # Common case.
                # Prepare all the inputs and copy to the input buffers.
                input_batch = self.prepare_inputs(
                    scheduler_output,
                    num_tokens_after_padding,
                )

                # NOTE(woosuk): Sampling metadata should be built under the async
                # barrier to avoid race conditions.
                pos = input_batch.positions[input_batch.logits_indices]
                sampling_metadata = self.req_states.make_sampling_metadata(
                    input_batch.idx_mapping_np, pos
                )

                if self.lora_config:
                    # Activate LoRA adapters.
                    lora_inputs = self.req_states.make_lora_inputs(
                        input_batch.req_ids,
                        input_batch.idx_mapping_np,
                        input_batch.num_scheduled_tokens,
                    )
                    self._set_active_loras(*lora_inputs)
            else:
                # No actual tokens to run. A dummy run for DP.
                num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
                input_batch = InputBatch.make_dummy(
                    num_reqs=num_reqs,
                    num_tokens=num_tokens_after_padding,
                    input_buffers=self.input_buffers,
                    device=self.device,
                )
                self.prepare_dummy_attn_metadata(input_batch)
                sampling_metadata = None

        # Run model.
        if cudagraph_mode == CUDAGraphMode.FULL:
            # Run CUDA graph.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            hidden_states = self.cudagraph_manager.run(
                input_batch.num_tokens_after_padding
            )
        else:
            # Run PyTorch model in eager mode.
            with set_forward_context(
                input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=cudagraph_mode,
                num_tokens_across_dp=num_tokens_across_dp,
            ):
                hidden_states = self.model(
                    input_ids=input_batch.input_ids,
                    positions=input_batch.positions,
                )

        self.execute_model_state = hidden_states, input_batch, sampling_metadata
        return None

    @torch.inference_mode()
    def sample_tokens(
        self,
        grammar_output: GrammarOutput | None,
    ) -> AsyncOutput | ModelRunnerOutput:
        assert self.execute_model_state is not None
        hidden_states, input_batch, sampling_metadata = self.execute_model_state
        self.execute_model_state = None  # type: ignore
        assert sampling_metadata is not None

        sampler_output = self.sample(
            hidden_states, input_batch, sampling_metadata, grammar_output
        )
        prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
        output = self.postprocess(
            sampler_output,
            prompt_logprobs_dict,
            input_batch,
        )
        return output
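

# Sketch of how a worker is expected to drive this runner per step (call order
# only; all names are defined above): execute_model() prepares inputs and runs the
# forward pass, then sample_tokens() consumes the stashed (hidden_states,
# input_batch, sampling_metadata) tuple to sample and build the output.
#
#   runner.execute_model(scheduler_output)          # forward pass, returns None
#   output = runner.sample_tokens(grammar_output)   # sampling + logprobs + postprocess
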
327
vllm/v1/worker/gpu/sampler.py
Normal file
@ -0,0 +1,327 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch
import triton
import triton.language as tl

from vllm.config.model import LogprobsMode
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p


class Sampler:
    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode

    def __call__(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        if sampling_metadata.max_num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                sampled, logits = self.sample(
                    logits, sampling_metadata, return_logits=True
                )
            else:
                assert self.logprobs_mode == "raw_logprobs"
                sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)

            logprobs_tensors = compute_topk_logprobs(
                logits,
                sampling_metadata.max_num_logprobs,
                sampled,
            )
        else:
            sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
            logprobs_tensors = None

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        is_greedy = sampling_metadata.temperature == 0
        temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
        logits = logits / temp.view(-1, 1)
        logits = apply_top_k_top_p(
            logits, sampling_metadata.top_k, sampling_metadata.top_p
        )

        sampled = gumbel_sample(
            logits,
            is_greedy,
            sampling_metadata.seeds,
            sampling_metadata.pos,
        )
        return sampled, logits if return_logits else None


@triton.jit
def _gumbel_sample_kernel(
    sampled_ptr,
    logits_ptr,
    logits_stride,
    seeds_ptr,
    pos_ptr,
    is_greedy_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)

    if is_greedy:
        # Greedy sampling. Don't apply gumbel noise.
        max_val = float("-inf")
        max_idx = 0
        for i in range(0, vocab_size, BLOCK_SIZE):
            block = i + tl.arange(0, BLOCK_SIZE)
            mask = block < vocab_size
            logits = tl.load(
                logits_ptr + req_idx * logits_stride + block,
                mask=mask,
                other=float("-inf"),
            )

            idx = tl.argmax(logits, axis=0)
            value = tl.max(logits, axis=0)
            is_greater = value > max_val
            max_val = tl.where(is_greater, value, max_val)
            max_idx = tl.where(is_greater, i + idx, max_idx)
        tl.store(sampled_ptr + req_idx, max_idx)
        return

    # Random sampling.
    # Calculate gumbel seed.
    seed = tl.load(seeds_ptr + req_idx)
    pos = tl.load(pos_ptr + req_idx)
    gumbel_seed = tl.randint(seed, pos)

    max_val = float("-inf")
    max_idx = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size

        # Generate gumbel noise.
        r = tl.rand(gumbel_seed, block).to(tl.float64)
        gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
        gumbel_noise = gumbel_noise.to(tl.float32)

        # Apply gumbel noise.
        logits = tl.load(logits_ptr + req_idx * logits_stride + block, mask=mask)
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

        # Argmax to get the sampled token.
        idx = tl.argmax(logits, axis=0)
        value = tl.max(logits, axis=0)
        is_greater = value > max_val
        max_val = tl.where(is_greater, value, max_val)
        max_idx = tl.where(is_greater, i + idx, max_idx)
    tl.store(sampled_ptr + req_idx, max_idx)


def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    is_greedy: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
) -> torch.Tensor:
    num_reqs, vocab_size = logits.shape
    # NOTE(woosuk): Use int64 for later indexing.
    sampled = torch.empty(
        num_reqs,
        dtype=torch.int64,
        device=logits.device,
    )
    _gumbel_sample_kernel[(num_reqs,)](
        sampled,
        logits,
        logits.stride(0),
        seed,
        pos,
        is_greedy,
        vocab_size,
        num_warps=8,
        BLOCK_SIZE=16384,  # type: ignore
    )
    return sampled


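# A minimal pure-PyTorch sketch of the Gumbel-max trick implemented by the kernel
# above (illustrative only; it matches the argmax-of-noised-logits semantics but not
# the kernel's per-request seeded RNG, so it is not bit-exact with gumbel_sample).
def _gumbel_sample_reference(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    is_greedy: torch.Tensor,  # [num_reqs], bool
) -> torch.Tensor:
    uniform = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-20) + 1e-20)
    # Greedy rows take a plain argmax; random rows take argmax of noised logits.
    noised = torch.where(is_greedy.unsqueeze(1), logits, logits + gumbel_noise)
    return noised.argmax(dim=-1)

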
@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
        logits = logits.to(tl.float32)
        e = tl.exp(logits - max_val)
        e = tl.where(block < vocab_size, e, 0.0)
        se += tl.sum(e)
    lse = tl.log(se)

    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)

    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


@triton.jit
def _ranks_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    token_ids_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    token_id = tl.load(token_ids_ptr + req_idx)
    x = tl.load(row_ptr + token_id)

    n = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        n += tl.sum((logits > x).to(tl.int32))
    tl.store(output_ptr + req_idx, n)


def compute_token_logprobs(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
) -> torch.Tensor:
    batch_size = logits.shape[0]
    vocab_size = logits.shape[1]
    token_ids = token_ids.to(torch.int64)
    num_logprobs = token_ids.shape[1]
    logprobs = torch.empty(
        batch_size,
        num_logprobs,
        dtype=torch.float32,
        device=logits.device,
    )
    _topk_log_softmax_kernel[(batch_size,)](
        logprobs,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        PADDED_TOPK=triton.next_power_of_2(num_logprobs),
    )
    return logprobs


def compute_topk_logprobs(
    logits: torch.Tensor,
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
    assert num_logprobs >= 0
    batch_size, vocab_size = logits.shape
    if num_logprobs == 0:
        logprob_token_ids = sampled_token_ids.unsqueeze(-1)
    else:
        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
        logprob_token_ids = torch.cat(
            (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
        )

    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
    # logprobs tensor. Instead, we only compute and return the logprobs of
    # the topk + 1 tokens.
    logprobs = compute_token_logprobs(logits, logprob_token_ids)
    token_ranks = torch.empty(
        batch_size,
        dtype=torch.int64,
        device=logits.device,
    )
    _ranks_kernel[(batch_size,)](
        token_ranks,
        logits,
        logits.stride(0),
        sampled_token_ids,
        vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
    )
    return LogprobsTensors(
        logprob_token_ids=logprob_token_ids,
        logprobs=logprobs,
        selected_token_ranks=token_ranks,
    )


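# An illustrative reference for compute_topk_logprobs (assumes small inputs; it
# materializes the full log-softmax, which the Triton path above deliberately avoids).
def _topk_logprobs_reference(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,  # [num_reqs], int64
) -> LogprobsTensors:
    full_logprobs = torch.log_softmax(logits.float(), dim=-1)
    token_ids = sampled_token_ids.unsqueeze(-1)
    if num_logprobs > 0:
        topk_ids = torch.topk(logits, num_logprobs, dim=-1).indices
        token_ids = torch.cat((token_ids, topk_ids), dim=1)
    # Rank = number of logits strictly greater than the sampled token's logit.
    sampled_logits = logits.gather(-1, sampled_token_ids.unsqueeze(-1))
    ranks = (logits > sampled_logits).sum(dim=-1)
    return LogprobsTensors(
        logprob_token_ids=token_ids,
        logprobs=full_logprobs.gather(-1, token_ids),
        selected_token_ranks=ranks,
    )

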
def compute_prompt_logprobs(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    # Since materializing the full prompt logits can take too much memory,
    # we compute it in chunks.
    CHUNK_SIZE = 1024
    logprobs = []
    ranks = []
    prompt_token_ids = prompt_token_ids.to(torch.int64)
    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
        end_idx = start_idx + CHUNK_SIZE
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
        prompt_logprobs = compute_topk_logprobs(
            prompt_logits,
            0,  # num_logprobs
            prompt_token_ids[start_idx:end_idx],
        )
        logprobs.append(prompt_logprobs.logprobs)
        ranks.append(prompt_logprobs.selected_token_ranks)

    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
    return logprobs, ranks
265
vllm/v1/worker/gpu/states.py
Normal file
@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field

import numpy as np
import torch

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer

_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0


@dataclass
class SamplingMetadata:
    temperature: torch.Tensor

    top_p: torch.Tensor | None
    top_k: torch.Tensor | None

    seeds: torch.Tensor
    pos: torch.Tensor

    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: int | None

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        device: torch.device,
    ) -> "SamplingMetadata":
        assert num_reqs > 0
        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
        temperature[0] = 0.5
        # TODO(woosuk): Use top-p and top-k for dummy sampler.
        # Currently, they are disabled because of memory usage.
        # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
        # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
        top_p = None
        top_k = None
        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
        max_num_logprobs = 20

        return cls(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            seeds=seeds,
            pos=pos,
            max_num_logprobs=max_num_logprobs,
        )


class RequestState:
    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.pin_memory = pin_memory

        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))
        self.extra_data: dict[str, ExtraData] = {}

        self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.prefill_token_ids = np.zeros(
            (self.max_num_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)

        # Last sampled tokens.
        self.last_sampled_tokens = torch.zeros(
            self.max_num_reqs,
            1,
            dtype=torch.int64,
            device=device,
        )

        # LoRA.
        self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.lora_ids.fill(NO_LORA_ID)

        # Sampling parameters.
        self.temperature = self._make_param(self.max_num_reqs, torch.float32)
        self.top_p = self._make_param(self.max_num_reqs, torch.float32)
        self.top_k = self._make_param(self.max_num_reqs, torch.int32)
        self.seeds = self._make_param(self.max_num_reqs, torch.int64)

        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs.fill(-1)
        self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)

    def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
        return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)

    def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
        return CpuGpuBuffer(
            size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
        )

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_len: int,
        prefill_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
        lora_request: LoRARequest | None,
    ) -> None:
        assert len(self.free_indices) > 0, "No free indices"
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id
        self.extra_data[req_id] = ExtraData(lora_request)

        self.prompt_len[req_idx] = prompt_len
        prefill_len = len(prefill_token_ids)
        assert prefill_len >= prompt_len, (
            f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
        self.prefill_len.np[req_idx] = prefill_len
        self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
        self.num_tokens[req_idx] = prefill_len
        self.num_computed_tokens[req_idx] = num_computed_tokens

        if lora_request is not None:
            self.lora_ids[req_idx] = lora_request.lora_int_id
        else:
            self.lora_ids[req_idx] = NO_LORA_ID

        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k

        if sampling_params.seed is not None:
            seed = sampling_params.seed
        else:
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
        self.seeds.np[req_idx] = seed

        if sampling_params.logprobs is not None:
            num_logprobs = sampling_params.logprobs
        else:
            num_logprobs = -1
        self.num_logprobs[req_idx] = num_logprobs

        # For now, only support prompt logprobs for the prompt tokens.
        needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
        self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs

    def remove_request(self, req_id: str) -> None:
        self.extra_data.pop(req_id, None)
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
        pos: torch.Tensor,
    ) -> SamplingMetadata:
        temperature = self.temperature.np[idx_mapping]
        temperature = self.temperature.copy_np_to_gpu(temperature)

        top_p = self.top_p.np[idx_mapping]
        no_top_p = np.all(top_p == 1.0)
        top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None

        top_k = self.top_k.np[idx_mapping]
        no_top_k = np.all(top_k == self.vocab_size)
        top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None

        seeds = self.seeds.np[idx_mapping]
        seeds = self.seeds.copy_np_to_gpu(seeds)

        num_logprobs = self.num_logprobs[idx_mapping]
        max_num_logprobs: int | None = int(np.max(num_logprobs))
        if max_num_logprobs == -1:
            max_num_logprobs = None

        return SamplingMetadata(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            seeds=seeds,
            pos=pos,
            max_num_logprobs=max_num_logprobs,
        )

    def make_lora_inputs(
        self,
        req_ids: list[str],
        idx_mapping: np.ndarray,
        num_scheduled_tokens: np.ndarray,
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        lora_ids = self.lora_ids[idx_mapping]
        prompt_lora_mapping = tuple(lora_ids)
        token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))

        active_lora_requests: set[LoRARequest] = set()
        for req_id in req_ids:
            lora_request = self.extra_data[req_id].lora_request
            if lora_request is not None:
                active_lora_requests.add(lora_request)
        return prompt_lora_mapping, token_lora_mapping, active_lora_requests


class Param:
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
    ):
        self.buffer = CpuGpuBuffer(
            size,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory,
        )
        self.np = np.zeros_like(self.buffer.np)

    def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
        n = x.shape[0]
        self.buffer.np[:n] = x
        return self.buffer.copy_to_gpu(n)


@dataclass
class ExtraData:
    lora_request: LoRARequest | None
    in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
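

# Illustrative sketch (plain PyTorch, not used by the code above) of the staging
# pattern Param wraps: gather per-request values on the CPU with fancy indexing,
# write them into a pinned staging buffer, then issue one async host-to-device copy.
def _gather_and_copy_example(
    values: np.ndarray,  # e.g. a per-request float32 array such as temperature
    idx_mapping: np.ndarray,  # batch index -> request index
    device: torch.device,
) -> torch.Tensor:
    gathered = values[idx_mapping]
    staging = torch.empty(
        gathered.shape[0], dtype=torch.float32, pin_memory=torch.cuda.is_available()
    )
    staging.numpy()[:] = gathered
    return staging.to(device, non_blocking=True)
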
76
vllm/v1/worker/gpu/structured_outputs.py
Normal file
@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBuffers


def apply_grammar_bitmask(
    logits: torch.Tensor,
    req_ids: list[str],
    grammar_req_ids: list[str],
    grammar_bitmask: np.ndarray,
    input_buffers: InputBuffers,
) -> None:
    input_buffers.grammar_bitmask.np[: grammar_bitmask.shape[0]] = grammar_bitmask
    input_buffers.grammar_bitmask.copy_to_gpu(grammar_bitmask.shape[0])

    batch_size = logits.shape[0]
    grammar_req_id_to_idx = {req_id: i for i, req_id in enumerate(grammar_req_ids)}
    # logits -> bitmask mapping
    mapping = [grammar_req_id_to_idx.get(req_id, -1) for req_id in req_ids]
    input_buffers.bitmask_indices.np[:batch_size] = mapping
    input_buffers.bitmask_indices.copy_to_gpu(batch_size)

    vocab_size = logits.shape[-1]
    BLOCK_SIZE = 8192
    grid = (batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))
    _apply_grammar_bitmask_kernel[grid](
        logits,
        logits.stride(0),
        input_buffers.grammar_bitmask.gpu,
        input_buffers.grammar_bitmask.gpu.stride(0),
        input_buffers.bitmask_indices.gpu,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )


# Adapted from
# https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
@triton.jit
def _apply_grammar_bitmask_kernel(
    logits_ptr,
    logits_stride,
    bitmask_ptr,
    bitmask_stride,
    bitmask_indices_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    logits_idx = tl.program_id(0)
    bitmask_idx = tl.load(bitmask_indices_ptr + logits_idx)
    if bitmask_idx == -1:
        # No bitmask to apply.
        return

    # Load the bitmask.
    block_id = tl.program_id(1)
    bitmask_offset = (block_id * BLOCK_SIZE) // 32 + tl.arange(0, BLOCK_SIZE // 32)
    packed_bitmask = tl.load(
        bitmask_ptr + bitmask_idx * bitmask_stride + bitmask_offset,
        mask=bitmask_offset < bitmask_stride,
    )
    # Unpack the bitmask.
    bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
    bitmask = bitmask.reshape(BLOCK_SIZE)

    # Apply the bitmask to the logits.
    block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(
        logits_ptr + logits_idx * logits_stride + block_offset,
        -float("inf"),
        mask=bitmask & (block_offset < vocab_size),
    )
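

# A pure-PyTorch sketch of the unpacking step in the kernel above (illustrative
# only): each int32 word of the packed bitmask covers 32 vocabulary ids; bit == 0
# means the token is disallowed, so its logit is forced to -inf.
def _apply_bitmask_reference(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    packed_bitmask: torch.Tensor,  # [num_reqs, ceil(vocab_size / 32)], int32
) -> None:
    num_reqs, vocab_size = logits.shape
    shifts = torch.arange(32, device=packed_bitmask.device, dtype=torch.int32)
    bits = (packed_bitmask.unsqueeze(-1) >> shifts) & 1  # [num_reqs, n_words, 32]
    allowed = bits.reshape(num_reqs, -1)[:, :vocab_size].bool()
    logits.masked_fill_(~allowed, float("-inf"))
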
@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
@ -58,7 +58,6 @@ logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput


class Worker(WorkerBase):
@ -101,6 +100,8 @@ class Worker(WorkerBase):
        else:
            self.profiler = None

        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

    def sleep(self, level: int = 1) -> None:
        from vllm.device_allocator.cumem import CuMemAllocator

@ -237,9 +238,17 @@ class Worker(WorkerBase):
            raise RuntimeError(f"Not support device type: {self.device_config.device}")

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device
        )
        if self.use_v2_model_runner:
            from vllm.v1.worker.gpu.model_runner import (
                GPUModelRunner as GPUModelRunnerV2,
            )

            # HACK(woosuk): This is a temporary fix to avoid type errors.
            self.model_runner: GPUModelRunner = GPUModelRunnerV2(  # type: ignore
                self.vllm_config, self.device
            )
        else:
            self.model_runner = GPUModelRunner(self.vllm_config, self.device)

        if self.rank == 0:
            # If usage stat is enabled, collect relevant info.
@ -573,7 +582,12 @@ class Worker(WorkerBase):
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        self.model_runner._dummy_run(1, uniform_decode=True)
        if self.use_v2_model_runner:
            self.model_runner.execute_model(
                SchedulerOutput.make_empty(), dummy_run=True
            )
        else:
            self.model_runner._dummy_run(1, uniform_decode=True)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)