Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-17 14:38:24 -07:00
parent 6d243efeda
commit 33a3a26ca5
3 changed files with 269 additions and 753 deletions

View File

@ -12,12 +12,12 @@ from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass @dataclass
class SamplingMetadata: class SamplingMetadata:
temperature: Optional[torch.Tensor] temperature: torch.Tensor
all_greedy: bool all_greedy: bool
all_random: bool all_random: bool
top_p: Optional[torch.Tensor] top_p: torch.Tensor
top_k: Optional[torch.Tensor] top_k: torch.Tensor
generators: dict[int, torch.Generator] generators: dict[int, torch.Generator]
@ -25,12 +25,11 @@ class SamplingMetadata:
max_num_logprobs: Optional[int] max_num_logprobs: Optional[int]
no_penalties: bool no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor repetition_penalties: torch.Tensor
output_token_ids: list[list[int]] token_ids: Optional[torch.Tensor]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size). # vocab size).

File diff suppressed because it is too large Load Diff

View File

@ -57,8 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata, make_kv_sharing_fast_prefill_attention_metadata)
reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
@ -288,35 +287,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=self.dtype, dtype=self.dtype,
device=self.device) device=self.device)
# OPTIMIZATION: Cache the tensors rather than creating them every step.
# Keep in int64 to avoid overflow with long context
self.arange_np = np.arange(max(self.max_num_reqs + 1,
self.max_model_len,
self.max_num_tokens),
dtype=np.int64)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# Layer pairings for cross-layer KV sharing. # Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it # If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values # means this layer will perform attention using the keys and values
@ -344,8 +314,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) if self.supports_mm_inputs \ ) if self.supports_mm_inputs \
else None) else None)
self.reorder_batch_threshold: Optional[int] = None
def _init_model_kwargs(self, num_tokens: int): def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]() model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
@ -381,30 +349,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device=self.device) device=self.device)
return model_kwargs return model_kwargs
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
"""
# Attention free models have zero kv_cache_groups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
# for self.model_config.is_attention_free.
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
if self.reorder_batch_threshold is not None:
reorder_batch_to_split_decodes_and_prefills(
self.input_batch,
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
# Note: used for model runner override. # Note: used for model runner override.
def _init_device_properties(self) -> None: def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties """Initialize attributes from torch.cuda.get_device_properties
@ -621,13 +565,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id] req_state = self.requests[req_id]
self.input_batch.add_request(req_state) self.input_batch.add_request(req_state)
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
# Allow the attention backend to potentially reorder the batch
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _extract_mm_kwargs( def _extract_mm_kwargs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",