mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-17 04:17:04 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
ba1a58f51b
commit
62d23b3006
@ -4,26 +4,21 @@
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -183,13 +178,6 @@ class TreeAttentionMetadataBuilder(
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=self.tree_attn_bias.shape[0])
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
|
||||
@ -684,69 +684,6 @@ def split_decodes_and_prefills(
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
|
||||
|
||||
def reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput",
|
||||
decode_threshold: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Reorders the batch to split into prefill and decode requests; places all
|
||||
requests with <= decode_threshold tokens at the front of the batch.
|
||||
|
||||
Returns:
|
||||
True if the batch was modified, False otherwise.
|
||||
"""
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the back using the least
|
||||
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
|
||||
# requests where attention is likely memory-bound and "prefill" to mean
|
||||
# requests where attention is likely compute-bound, TODO(lucas): figure out
|
||||
# a better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens <= decode_threshold:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
return modified_batch
|
||||
|
||||
|
||||
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
|
||||
('logits_indices_padded', Optional[torch.Tensor], None),
|
||||
('num_logits_indices', int, 0),
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""Attention layer with XFormersAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
try:
|
||||
@ -26,10 +26,6 @@ try:
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -210,12 +206,6 @@ class XFormersAttentionMetadataBuilder(
|
||||
self._num_decodes = 0
|
||||
self._num_decode_tokens = 0
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
|
||||
@ -462,6 +462,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
def _init_mrope_positions(self, req_id: str) -> None:
|
||||
req_idx = self.requests.req_id_to_index[req_id]
|
||||
req_data = self.requests.req_data[req_idx]
|
||||
prompt_len = self.requests.num_prompt_tokens.np[req_idx]
|
||||
prompt_token_ids = self.requests.token_ids.np[req_idx, :prompt_len]
|
||||
|
||||
image_grid_thw = []
|
||||
video_grid_thw = []
|
||||
@ -483,7 +485,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
req_data.mrope_positions, req_data.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
req_data.prompt_token_ids,
|
||||
prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
@ -905,7 +907,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# list of tuple (mm_hash, position_info)
|
||||
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
req_data = self.requests.req_data[req_id]
|
||||
req_idx = self.requests.req_id_to_index[req_id]
|
||||
req_data = self.requests.req_data[req_idx]
|
||||
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_hash = req_data.mm_hashes[mm_input_id]
|
||||
@ -1259,11 +1262,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return self.kv_connector_no_forward(scheduler_output,
|
||||
self.vllm_config)
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
assert not self.input_batch.num_prompt_logprobs, (
|
||||
"--kv-sharing-fast-prefill produces incorrect logprobs for "
|
||||
"prompt tokens, tokens, please disable it when the requests "
|
||||
"need prompt logprobs")
|
||||
# if self.cache_config.kv_sharing_fast_prefill:
|
||||
# assert not self.input_batch.num_prompt_logprobs, (
|
||||
# "--kv-sharing-fast-prefill produces incorrect logprobs for "
|
||||
# "prompt tokens, tokens, please disable it when the requests "
|
||||
# "need prompt logprobs")
|
||||
|
||||
# Prepare the decoder inputs.
|
||||
input_batch = self._prepare_inputs(scheduler_output)
|
||||
@ -1296,7 +1299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.supports_mm_inputs:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(input_batch)
|
||||
else:
|
||||
mm_embeds = []
|
||||
|
||||
@ -1328,7 +1331,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
inputs_embeds = None
|
||||
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
positions = self.mrope_positions.gpu[:, :num_input_tokens]
|
||||
else:
|
||||
positions = self.positions.gpu[:num_input_tokens]
|
||||
|
||||
@ -1448,7 +1451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
num_nans_in_logits = {}
|
||||
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
|
||||
num_nans_in_logits = self._get_nans_in_logits(logits)
|
||||
num_nans_in_logits = self._get_nans_in_logits(
|
||||
input_batch.req_ids, logits)
|
||||
|
||||
# NOTE: GPU -> CPU Sync happens here.
|
||||
# Move as many CPU operations as possible before this sync point.
|
||||
@ -1488,14 +1492,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.speculative_config:
|
||||
assert input_batch.spec_decode_common_attn_metadata is not None
|
||||
self._draft_token_ids = self.propose_draft_token_ids(
|
||||
scheduler_output,
|
||||
input_batch,
|
||||
valid_sampled_token_ids,
|
||||
sampling_metadata,
|
||||
hidden_states,
|
||||
sample_hidden_states,
|
||||
aux_hidden_states,
|
||||
input_batch.spec_decode_metadata,
|
||||
input_batch.spec_decode_common_attn_metadata,
|
||||
)
|
||||
self._draft_req_ids = input_batch.req_ids
|
||||
|
||||
@ -1889,16 +1891,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
def _get_nans_in_logits(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
req_ids: list[str],
|
||||
logits: Optional[torch.Tensor],
|
||||
) -> dict[str, int]:
|
||||
try:
|
||||
if logits is None:
|
||||
return {req_id: 0 for req_id in input_batch.req_ids}
|
||||
return {req_id: 0 for req_id in req_ids}
|
||||
|
||||
num_nans_in_logits = {}
|
||||
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
for i, req_id in enumerate(req_ids):
|
||||
num_nans_in_logits[req_id] = (int(num_nans_for_index[i])
|
||||
if num_nans_for_index is not None
|
||||
and i < logits.shape[0] else 0)
|
||||
@ -2092,7 +2094,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
model_kwargs = self._init_model_kwargs(num_tokens)
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_tokens]
|
||||
positions = self.mrope_positions.gpu[:, :num_tokens]
|
||||
else:
|
||||
positions = self.positions.gpu[:num_tokens]
|
||||
|
||||
|
||||
@ -110,7 +110,7 @@ class RequestState:
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.pooling_params = None
|
||||
self.block_sizes = block_sizes
|
||||
self.num_prompt_logprobs = {}
|
||||
self.num_prompt_logprobs: dict[int, int] = {}
|
||||
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
@ -378,6 +378,7 @@ def _make_sampling_metadata_kernel(
|
||||
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_spec_decode_kernel(
|
||||
query_start_loc, # [B + 1]
|
||||
cu_num_draft_tokens, # [B]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user