commit 8407fa02ed
parent 82e591f7eb

    fix

    Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -206,12 +206,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         )
 
         if H < MAX_HEADS:
-            # Extract the subsets of the outputs
-            returned_lse = lse[:, :H].contiguous(
-            ) if self.need_to_return_lse_for_decode else lse
             out = out[:, :H]
+            if self.need_to_return_lse_for_decode:
+                lse = lse[:, :H].contiguous()
 
-        return out, returned_lse
+        return out, lse
 
     def _forward_decode(
         self,
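The hunk above trims the padded kernel outputs back to the model's real head count H, and now only makes lse contiguous when it will actually be returned. A standalone sketch of that slicing pattern; the shapes, MAX_HEADS value, and need_lse flag are illustrative stand-ins, not values taken from the kernel call site:

import torch

MAX_HEADS = 128   # kernel output padded to this many heads (assumed)
H = 16            # actual number of query heads (assumed)
need_lse = True   # stand-in for self.need_to_return_lse_for_decode

out = torch.randn(4, MAX_HEADS, 512)   # [tokens, padded heads, head dim]
lse = torch.randn(4, MAX_HEADS)        # [tokens, padded heads]

if H < MAX_HEADS:
    out = out[:, :H]
    if need_lse:
        # slicing the padded head dim yields a non-contiguous view, so copy
        lse = lse[:, :H].contiguous()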
@@ -4,21 +4,26 @@
 
 import ast
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
+
+from vllm import _custom_ops as ops
+
 logger = init_logger(__name__)
 
 
@@ -179,6 +184,13 @@ class TreeAttentionMetadataBuilder(
             device=device,
         )
 
+    def reorder_batch(self, input_batch: "InputBatch",
+                      scheduler_output: "SchedulerOutput") -> bool:
+        return reorder_batch_to_split_decodes_and_prefills(
+            input_batch,
+            scheduler_output,
+            decode_threshold=self.tree_attn_bias.shape[0])
+
     def build(
         self,
         common_prefix_len: int,
@@ -703,6 +703,69 @@ def split_decodes_and_prefills(
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
 
 
+def reorder_batch_to_split_decodes_and_prefills(
+    input_batch: "InputBatch",
+    scheduler_output: "SchedulerOutput",
+    decode_threshold: int = 1,
+) -> bool:
+    """
+    Reorders the batch to split into prefill and decode requests; places all
+    requests with <= decode_threshold tokens at the front of the batch.
+
+    Returns:
+        True if the batch was modified, False otherwise.
+    """
+    # We now want to reorder the batch so that the "decode" requests are at
+    # the front and the "prefill" requests are at the back using the least
+    # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
+    # requests where attention is likely memory-bound and "prefill" to mean
+    # requests where attention is likely compute-bound, TODO(lucas): figure out
+    # a better naming here)
+    decodes = []
+    prefills = []
+    num_decode_tokens = 0
+    num_prefill_tokens = 0
+
+    for i, req_id in enumerate(input_batch.req_ids):
+        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+        # for now treat 1 scheduled token as "decode" even if it's not,
+        # we should update this to something like < 8 in the future but
+        # currently the TritonMLA._forward_decode only supports
+        # num_tokens = 1
+        if num_tokens <= decode_threshold:
+            decodes.append(i)
+            num_decode_tokens += num_tokens
+        else:
+            prefills.append(i)
+            num_prefill_tokens += num_tokens
+
+    # We hope that this is fairly minimal since decodes
+    # should be around for a number of iterations so hopefully they are
+    # relatively stationary (and new request are generally appended to the
+    # persistent batch so already should be at the back)
+    # To achieve this we loop over the decodes in descending order and
+    # the prefills in ascending order. We swap decodes from the "back"
+    # i.e. past where the last decode should be in the reodorered with
+    # prefills from the front of the batch.
+    # `decodes` and `prefills` are already in ascending order just based on
+    # the above loop
+    num_decodes = len(decodes)
+    num_prefills = len(prefills)
+    modified_batch = False
+
+    for i in range(1, min(num_decodes, num_prefills) + 1):
+        # If the decode is at the "back" of the batch, i, we can swap it
+        # with the prefill closest to the front of the batch
+        decode_idx = decodes[num_decodes - i]
+        if decode_idx < num_decodes:
+            break
+
+        input_batch.swap_states(prefills[i - 1], decode_idx)
+        modified_batch = True
+
+    return modified_batch
+
+
 KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
     ('logits_indices_padded', Optional[torch.Tensor], None),
     ('num_logits_indices', int, 0),
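The new helper only touches input_batch.req_ids, input_batch.swap_states() and scheduler_output.num_scheduled_tokens, so it can be exercised with small stand-in objects. A toy driver, assuming a vLLM build that already contains this commit; FakeInputBatch and FakeSchedulerOutput are made-up stand-ins for vLLM's InputBatch and SchedulerOutput, reduced to the attributes the helper uses:

from dataclasses import dataclass

from vllm.v1.attention.backends.utils import (
    reorder_batch_to_split_decodes_and_prefills)


@dataclass
class FakeInputBatch:
    req_ids: list[str]

    def swap_states(self, i: int, j: int) -> None:
        # the real InputBatch swaps all per-request state; swapping the ids
        # is enough to show the reordering here
        self.req_ids[i], self.req_ids[j] = self.req_ids[j], self.req_ids[i]


@dataclass
class FakeSchedulerOutput:
    num_scheduled_tokens: dict[str, int]


batch = FakeInputBatch(req_ids=["p0", "d0", "p1", "d1"])
sched = FakeSchedulerOutput(
    num_scheduled_tokens={"p0": 37, "d0": 1, "p1": 90, "d1": 1})

modified = reorder_batch_to_split_decodes_and_prefills(batch, sched,
                                                       decode_threshold=1)
print(modified, batch.req_ids)  # True ['d1', 'd0', 'p1', 'p0']

With decode_threshold=1 the two single-token requests end up at the front after a single swap, which is exactly the "minimal number of swaps" behaviour described in the function's comments.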
@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""
 
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional
 
 import torch
 
@@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
-                                              split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 try:
@@ -26,6 +26,10 @@ try:
 except ImportError:
     XFORMERS_AVAILABLE = False
 
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
+
 from vllm import _custom_ops as ops
 
 logger = init_logger(__name__)
@@ -209,6 +213,13 @@ class XFormersAttentionMetadataBuilder(
         self._num_decodes = 0
         self._num_decode_tokens = 0
 
+    def reorder_batch(self, input_batch: "InputBatch",
+                      scheduler_output: "SchedulerOutput") -> bool:
+        return reorder_batch_to_split_decodes_and_prefills(
+            input_batch,
+            scheduler_output,
+            decode_threshold=self.reorder_batch_threshold)
+
     def build(
         self,
         common_prefix_len: int,
@@ -6,14 +6,38 @@ from typing import Optional
 
 import torch
 
+from vllm.v1.sample.logits_processor import LogitsProcessors
+
 
 @dataclass
 class SamplingMetadata:
 
-    temperature: torch.Tensor
+    temperature: Optional[torch.Tensor]
+    all_greedy: bool
+    all_random: bool
+
     top_p: Optional[torch.Tensor]
     top_k: Optional[torch.Tensor]
 
+    generators: dict[int, torch.Generator]
+
     # None means no logprobs, 0 means sampled token logprobs only
     max_num_logprobs: Optional[int]
 
+    no_penalties: bool
+    prompt_token_ids: Optional[torch.Tensor]
+    frequency_penalties: torch.Tensor
+    presence_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+
+    output_token_ids: list[list[int]]
+
+    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
+    # vocab size).
+    allowed_token_ids_mask: Optional[torch.Tensor]
+
+    # req_index -> bad_words_token_ids
+    bad_words_token_ids: dict[int, list[list[int]]]
+
+    # Loaded logits processors
+    logitsprocs: LogitsProcessors
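Among the fields restored above, allowed_token_ids_mask is documented as a [max batch size, vocab size] bool tensor. A minimal illustration of how such a mask can gate logits; this is only a sketch of the idea (including the assumed mask polarity), not the code path vLLM's sampler actually uses:

import torch

batch_size, vocab_size = 2, 8
logits = torch.randn(batch_size, vocab_size)

# Assumed polarity for this sketch: True marks tokens the request must not emit.
disallowed = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
disallowed[0, 3] = True   # request 0 may never emit token id 3

masked = logits.masked_fill(disallowed, float("-inf"))
probs = torch.softmax(masked, dim=-1)
assert probs[0, 3] == 0.0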
@@ -90,9 +90,9 @@ class Sampler(nn.Module):
         # Apply bad words exclusion.
         logits = self.apply_bad_words(logits, sampling_metadata)
 
-        # # Apply logits processors which can impact greedy sampling
-        # for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
-        #     logits = processor.apply(logits)
+        # Apply logits processors which can impact greedy sampling
+        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
+            logits = processor.apply(logits)
 
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
@@ -167,10 +167,10 @@ class Sampler(nn.Module):
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
-        # # Apply logits processors that only apply to random sampling
-        # # (argmax invariant)
-        # for processor in sampling_metadata.logitsprocs.argmax_invariant:
-        #     logits = processor.apply(logits)
+        # Apply logits processors that only apply to random sampling
+        # (argmax invariant)
+        for processor in sampling_metadata.logitsprocs.argmax_invariant:
+            logits = processor.apply(logits)
 
         # Apply top_k and/or top_p.
         random_sampled, processed_logprobs = self.topk_topp_sampler(
|||||||
@ -79,6 +79,9 @@ class GPUModelRunner:
|
|||||||
)
|
)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> tuple[str]:
|
||||||
|
return ("generate", )
|
||||||
|
|
||||||
def load_model(self, *args, **kwargs) -> None:
|
def load_model(self, *args, **kwargs) -> None:
|
||||||
time_before_load = time.perf_counter()
|
time_before_load = time.perf_counter()
|
||||||
with DeviceMemoryProfiler() as m:
|
with DeviceMemoryProfiler() as m:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 
 from vllm.sampling_params import SamplingParams
+from vllm.v1.utils import CpuGpuBuffer
 
 _NP_INT64_MIN = np.iinfo(np.int64).min
 _NP_INT64_MAX = np.iinfo(np.int64).max
@@ -69,21 +70,21 @@ class RequestState:
         )
 
         # Sampling parameters.
-        self.temperature = self._make_buffer(self.max_num_reqs, torch.float32)
-        self.top_p = self._make_buffer(self.max_num_reqs, torch.float32)
-        self.top_k = self._make_buffer(self.max_num_reqs, torch.int32)
-        self.seeds = self._make_buffer(self.max_num_reqs, torch.int64)
+        self.temperature = self._make_param(self.max_num_reqs, torch.float32)
+        self.top_p = self._make_param(self.max_num_reqs, torch.float32)
+        self.top_k = self._make_param(self.max_num_reqs, torch.int32)
+        self.seeds = self._make_param(self.max_num_reqs, torch.int64)
 
         self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
         # -1 means no logprobs are requested.
         self.num_logprobs.fill(-1)
         self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
 
-    def _make_buffer(self, size, dtype: torch.dtype) -> "Buffer":
-        return Buffer(size,
-                      dtype=dtype,
-                      pin_memory=self.pin_memory,
-                      device=self.device)
+    def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
+        return Param(size,
+                     dtype=dtype,
+                     device=self.device,
+                     pin_memory=self.pin_memory)
 
     @property
     def num_reqs(self) -> int:
@@ -217,27 +218,24 @@ def _append_token_ids(
     num_tokens[req_idx] = end_idx
 
 
-class Buffer:
+class Param:
 
     def __init__(
         self,
-        *args,
+        size: int,
         dtype: torch.dtype,
-        pin_memory: bool,
         device: torch.device,
+        pin_memory: bool,
     ):
-        # NOTE(woosuk): Unlike CpuGpuBuffer, the Numpy array and CPU tensor
-        # in this class do not share the same storage.
-        self.np = np.zeros(*args, dtype=dtype)
-        self.cpu = torch.zeros(
-            *args,
+        self.buffer = CpuGpuBuffer(
+            size,
             dtype=dtype,
-            pin_memory=pin_memory,
             device=device,
+            pin_memory=pin_memory,
         )
-        self.gpu = self.cpu.to(device)
+        self.np = np.zeros_like(self.buffer.np)
 
     def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
         n = x.shape[0]
-        self.cpu[:n] = x
-        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
+        self.buffer.np[:n] = x
+        return self.buffer.copy_to_gpu(n)
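Param now delegates its storage to vllm.v1.utils.CpuGpuBuffer and keeps only a separate scratch numpy array. Either way, copy_np_to_gpu amounts to a staged host-to-device copy; a plain-torch sketch of that pattern, written without vLLM so it runs standalone and only approximating what CpuGpuBuffer does internally:

import numpy as np
import torch

size = 1024
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

scratch = np.zeros(size, dtype=np.float32)   # staging array, like Param.np
cpu = torch.zeros(size, dtype=torch.float32, pin_memory=use_cuda)
gpu = cpu.to(device)


def copy_np_to_gpu(x: np.ndarray) -> torch.Tensor:
    n = x.shape[0]
    cpu[:n] = torch.from_numpy(x)
    # non_blocking only overlaps the copy when the source is pinned host memory
    return gpu[:n].copy_(cpu[:n], non_blocking=True)


scratch[:16] = np.arange(16)
out = copy_np_to_gpu(scratch[:16])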
@@ -423,7 +423,6 @@ class Worker(WorkerBase):
         return self.model_runner.get_model()
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        return "generate"
         return self.model_runner.get_supported_tasks()
 
     @torch.inference_mode()
@@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.block_table import MultiGroupBlockTable
+from vllm.v1.worker.gpu_input_batch import CachedRequestState
 
 _SAMPLING_EPS = 1e-5
 