mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 17:37:02 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
82e591f7eb
commit
8407fa02ed
@ -206,12 +206,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
)
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
returned_lse = lse[:, :H].contiguous(
|
||||
) if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
if self.need_to_return_lse_for_decode:
|
||||
lse = lse[:, :H].contiguous()
|
||||
|
||||
return out, returned_lse
|
||||
return out, lse
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
|
||||
@ -4,21 +4,26 @@
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -179,6 +184,13 @@ class TreeAttentionMetadataBuilder(
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=self.tree_attn_bias.shape[0])
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
|
||||
@ -703,6 +703,69 @@ def split_decodes_and_prefills(
|
||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||
|
||||
|
||||
def reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput",
|
||||
decode_threshold: int = 1,
|
||||
) -> bool:
|
||||
"""
|
||||
Reorders the batch to split into prefill and decode requests; places all
|
||||
requests with <= decode_threshold tokens at the front of the batch.
|
||||
|
||||
Returns:
|
||||
True if the batch was modified, False otherwise.
|
||||
"""
|
||||
# We now want to reorder the batch so that the "decode" requests are at
|
||||
# the front and the "prefill" requests are at the back using the least
|
||||
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
|
||||
# requests where attention is likely memory-bound and "prefill" to mean
|
||||
# requests where attention is likely compute-bound, TODO(lucas): figure out
|
||||
# a better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if it's not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens <= decode_threshold:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new request are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reodorered with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
return modified_batch
|
||||
|
||||
|
||||
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
|
||||
('logits_indices_padded', Optional[torch.Tensor], None),
|
||||
('num_logits_indices', int, 0),
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
"""Attention layer with XFormersAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
try:
|
||||
@ -26,6 +26,10 @@ try:
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -209,6 +213,13 @@ class XFormersAttentionMetadataBuilder(
|
||||
self._num_decodes = 0
|
||||
self._num_decode_tokens = 0
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(
|
||||
input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
|
||||
@ -6,14 +6,38 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
|
||||
temperature: torch.Tensor
|
||||
temperature: Optional[torch.Tensor]
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
|
||||
generators: dict[int, torch.Generator]
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: Optional[int]
|
||||
|
||||
no_penalties: bool
|
||||
prompt_token_ids: Optional[torch.Tensor]
|
||||
frequency_penalties: torch.Tensor
|
||||
presence_penalties: torch.Tensor
|
||||
repetition_penalties: torch.Tensor
|
||||
|
||||
output_token_ids: list[list[int]]
|
||||
|
||||
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
|
||||
# vocab size).
|
||||
allowed_token_ids_mask: Optional[torch.Tensor]
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
bad_words_token_ids: dict[int, list[list[int]]]
|
||||
|
||||
# Loaded logits processors
|
||||
logitsprocs: LogitsProcessors
|
||||
|
||||
@ -90,9 +90,9 @@ class Sampler(nn.Module):
|
||||
# Apply bad words exclusion.
|
||||
logits = self.apply_bad_words(logits, sampling_metadata)
|
||||
|
||||
# # Apply logits processors which can impact greedy sampling
|
||||
# for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
||||
# logits = processor.apply(logits)
|
||||
# Apply logits processors which can impact greedy sampling
|
||||
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||
logits = self.apply_penalties(logits, sampling_metadata)
|
||||
@ -167,10 +167,10 @@ class Sampler(nn.Module):
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# # Apply logits processors that only apply to random sampling
|
||||
# # (argmax invariant)
|
||||
# for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
# logits = processor.apply(logits)
|
||||
# Apply logits processors that only apply to random sampling
|
||||
# (argmax invariant)
|
||||
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled, processed_logprobs = self.topk_topp_sampler(
|
||||
|
||||
@ -79,6 +79,9 @@ class GPUModelRunner:
|
||||
)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate", )
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
with DeviceMemoryProfiler() as m:
|
||||
|
||||
@ -9,6 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||
@ -69,21 +70,21 @@ class RequestState:
|
||||
)
|
||||
|
||||
# Sampling parameters.
|
||||
self.temperature = self._make_buffer(self.max_num_reqs, torch.float32)
|
||||
self.top_p = self._make_buffer(self.max_num_reqs, torch.float32)
|
||||
self.top_k = self._make_buffer(self.max_num_reqs, torch.int32)
|
||||
self.seeds = self._make_buffer(self.max_num_reqs, torch.int64)
|
||||
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
|
||||
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
|
||||
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
|
||||
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
|
||||
|
||||
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
|
||||
# -1 means no logprobs are requested.
|
||||
self.num_logprobs.fill(-1)
|
||||
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
||||
|
||||
def _make_buffer(self, size, dtype: torch.dtype) -> "Buffer":
|
||||
return Buffer(size,
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory,
|
||||
device=self.device)
|
||||
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
|
||||
return Param(size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
@ -217,27 +218,24 @@ def _append_token_ids(
|
||||
num_tokens[req_idx] = end_idx
|
||||
|
||||
|
||||
class Buffer:
|
||||
class Param:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
# NOTE(woosuk): Unlike CpuGpuBuffer, the Numpy array and CPU tensor
|
||||
# in this class do not share the same storage.
|
||||
self.np = np.zeros(*args, dtype=dtype)
|
||||
self.cpu = torch.zeros(
|
||||
*args,
|
||||
self.buffer = CpuGpuBuffer(
|
||||
size,
|
||||
dtype=dtype,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.gpu = self.cpu.to(device)
|
||||
self.np = np.zeros_like(self.buffer.np)
|
||||
|
||||
def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
|
||||
n = x.shape[0]
|
||||
self.cpu[:n] = x
|
||||
return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
|
||||
self.buffer.np[:n] = x
|
||||
return self.buffer.copy_to_gpu(n)
|
||||
|
||||
@ -423,7 +423,6 @@ class Worker(WorkerBase):
|
||||
return self.model_runner.get_model()
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return "generate"
|
||||
return self.model_runner.get_supported_tasks()
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user