[v1] AttentionMetadata for each layer (#17394)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
commit cba31c47c4 (parent a6fed02068)
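
This change makes attention metadata per-layer in the v1 engine: ForwardContext.attn_metadata may now be a dict mapping each attention layer's name to its own AttentionMetadata, the GPU model runner builds one metadata object per KV cache group and points every layer of that group at it, and each attention layer looks up its own entry. A minimal sketch of the lookup pattern the layers adopt below, not part of the patch (the helper name is made up):

def resolve_attn_metadata(attn_metadata, layer_name):
    # v0 passes a single metadata object; v1 passes a dict keyed by layer name.
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata

per_layer = {"model.layers.0.self_attn.attn": object()}
assert resolve_attn_metadata(per_layer, "model.layers.0.self_attn.attn") is \
    per_layer["model.layers.0.self_attn.attn"]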
vllm/attention/layer.py
@@ -210,6 +210,8 @@ class Attention(nn.Module):
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(self,
                               query,
@@ -226,6 +228,8 @@ class Attention(nn.Module):
         if self.use_direct_call:
             forward_context = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             return self.impl.forward(self, query, key, value,
                                      self_kv_cache, attn_metadata)
@@ -343,7 +347,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
     attn_metadata = forward_context.attn_metadata
     if attn_metadata is None:
         return
+    assert isinstance(attn_metadata, dict)
     connector.wait_for_layer_load(layer_name)


@@ -360,8 +364,9 @@ def maybe_save_kv_layer_to_connector(
     attn_metadata = forward_context.attn_metadata
     if attn_metadata is None:
         return
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+    assert isinstance(attn_metadata, dict)
+    connector.save_kv_layer(layer_name, kv_cache_layer,
+                            attn_metadata[layer_name])


 def unified_attention(
@@ -374,6 +379,8 @@ def unified_attention(

     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     output = self.impl.forward(self, query, key, value, kv_cache,
@@ -411,6 +418,8 @@ def unified_attention_with_output(
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
vllm/forward_context.py
@@ -4,7 +4,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -38,8 +38,13 @@ class DPMetadata:
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
-    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    """
+    Type AttentionMetadata for v0,
+    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
+    attention layer to its attention metadata
+    set dynamically for each forward pass
+    """
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
    # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
vllm/v1/attention/backends/flash_attn.py
@@ -18,6 +18,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -309,13 +310,11 @@ class FlashAttentionMetadataBuilder:
         return False

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
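
The same build() signature change is applied to the FlashInfer and MLA builders below: each backend now receives a shared metadata object and reads tensors that are already on the device instead of slicing and copying the runner's CPU buffers itself. For illustration, a minimal sketch assuming only that the shared object exposes query_start_loc and seq_lens (the CommonAttentionMetadata added later in this diff); ToyMetadataBuilder is a made-up stand-in, not vLLM's builder:

import torch
from types import SimpleNamespace

class ToyMetadataBuilder:
    # Stand-in for the backend builders; not vLLM's class.
    def build(self, num_reqs, num_actual_tokens, max_query_len,
              common_prefix_len, common_attn_metadata):
        # Read the shared, already-on-device tensors once per build() call.
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        return {"query_start_loc": query_start_loc,
                "seq_lens": seq_lens,
                "max_seq_len": int(seq_lens.max())}

common = SimpleNamespace(query_start_loc=torch.tensor([0, 3, 5]),
                         seq_lens=torch.tensor([10, 7]))
meta = ToyMetadataBuilder().build(num_reqs=2, num_actual_tokens=5,
                                  max_query_len=3, common_prefix_len=0,
                                  common_attn_metadata=common)
assert meta["max_seq_len"] == 10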
vllm/v1/attention/backends/flashinfer.py
@@ -18,6 +18,7 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
                          get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -394,16 +395,15 @@ class FlashInferMetadataBuilder:
         )

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
         page_size = self.runner.block_size
         device = self.runner.device
-        qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            self.runner.device, non_blocking=True)
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
-                                                          non_blocking=True)
+        qo_indptr = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
vllm/v1/attention/backends/mla/common.py
@@ -207,6 +207,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -451,7 +452,8 @@ class MLACommonMetadataBuilder(Generic[M]):
         )

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int) -> M:
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -460,15 +462,13 @@ class MLACommonMetadataBuilder(Generic[M]):
         device = self.runner.device
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            device, non_blocking=True)
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens

         prefill_metadata = None
         if self._num_prefills > 0:
vllm/v1/attention/backends/utils.py (new file, +18 lines)
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class CommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: torch.Tensor
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: torch.Tensor
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
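
A hedged toy example of what the two fields hold, with invented numbers for a batch of three requests; the last line mirrors how the GPU model runner change further down derives logits_indices from query_start_loc:

import torch
from dataclasses import dataclass

@dataclass
class CommonAttentionMetadata:  # mirrors the new file above
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

# Three requests scheduling 4, 1 and 2 new tokens on top of 8, 8 and 28
# already-computed tokens.
common = CommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 4, 5, 7], dtype=torch.int32),
    seq_lens=torch.tensor([12, 9, 30], dtype=torch.int32),
)
# Position of each request's last scheduled token in the flattened query tensor.
last_token_idx = common.query_start_loc[1:] - 1  # tensor([3, 4, 6])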
vllm/v1/spec_decode/eagle.py
@@ -2,7 +2,9 @@
 import torch
 import torch.nn as nn

-from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
+from vllm.attention.layer import Attention
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
@@ -276,6 +278,8 @@ class EagleProposer:
         loader = get_model_loader(self.vllm_config.load_config)
         target_layer_num = self.vllm_config.model_config.get_num_layers(
             self.vllm_config.parallel_config)
+        target_attn_layer_names = set(
+            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
@@ -292,6 +296,11 @@ class EagleProposer:
                 vllm_config=self.vllm_config,
                 start_layer_id=target_layer_num).to(target_device)

+        draft_attn_layer_names = (
+            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
+            target_attn_layer_names)
+        assert len(draft_attn_layer_names) == 1
+        self.attn_layer_name = next(iter(draft_attn_layer_names))
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))
         if self.vllm_config.speculative_config.method == "eagle3":
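
For illustration, a minimal sketch of how the draft layer name falls out of the set difference above; the layer names are placeholders, not real vLLM names:

# Attention layer names registered before the draft model is constructed.
target_attn_layer_names = {"model.layers.0.self_attn.attn",
                           "model.layers.1.self_attn.attn"}
# After constructing the draft model, the registry holds one extra layer.
all_attn_layer_names = target_attn_layer_names | {"draft.layers.0.self_attn.attn"}
draft_attn_layer_names = all_attn_layer_names - target_attn_layer_names
assert len(draft_attn_layer_names) == 1
attn_layer_name = next(iter(draft_attn_layer_names))  # the draft layer's name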
vllm/v1/worker/gpu_model_runner.py
@@ -30,6 +30,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                         check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -157,9 +158,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Sampler
         self.sampler = Sampler()

-        # Lazy initialization
+        # Lazy initializations
         # self.model: nn.Module  # Set after load_model
+        # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
+        # self.kv_cache_config: KVCacheConfig
+
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

@@ -488,7 +492,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
+    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
               Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -585,20 +589,39 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        # Prepare for cascade attention if enabled & beneficial.
-        common_prefix_len = 0
-        if self.cascade_attn_enabled:
-            common_prefix_len = self._compute_cascade_attn_prefix_len(
-                num_scheduled_tokens,
-                scheduler_output.num_common_prefix_blocks,
-            )
+        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
+                                                   non_blocking=True)
+        common_attn_metadata = CommonAttentionMetadata(
+            query_start_loc=query_start_loc, seq_lens=seq_lens)

-        attn_metadata = self.attn_metadata_builder.build(
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            common_prefix_len=common_prefix_len,
-        )
+        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        # Prepare the attention metadata for each KV cache group and make layers
+        # in the same group share the same metadata.
+        # NOTE(Chen): there is exactly one KV cache group that contains all
+        # attetnion layers in the model for now, so the current logic for
+        # getting attn_metadata is not related to kv_cache_group information.
+        # Will extend this part to support multiple KV cache groups later.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+
+            # Prepare for cascade attention if enabled & beneficial.
+            common_prefix_len = 0
+            if self.cascade_attn_enabled:
+                common_prefix_len = self._compute_cascade_attn_prefix_len(
+                    num_scheduled_tokens,
+                    scheduler_output.num_common_prefix_blocks,
+                )
+
+            attn_metadata_i = self.attn_metadata_builder.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata)
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i

         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -608,7 +631,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
-            logits_indices = attn_metadata.query_start_loc[1:] - 1
+            logits_indices = query_start_loc[1:] - 1
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
@@ -1230,6 +1253,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
@@ -1241,8 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1256,7 +1280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    device=self.device,
                )
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
+                    eagle_attn_metadata.query_start_loc,
                    num_rejected_tokens,
                )
                target_token_ids = self.input_ids[token_indices]
@@ -1266,7 +1290,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]

            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
@@ -1275,7 +1300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_table,
+                block_table=eagle_attn_metadata.block_table,
                sampling_metadata=sampling_metadata,
            )
            spec_token_ids = draft_token_ids.tolist()
@@ -1708,6 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")
+        self.kv_cache_config = kv_cache_config

        kv_caches: dict[str, torch.Tensor] = {}

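
For illustration, a stripped-down sketch of the per-group loop above: one builder call per KV cache group, with every layer of that group pointing at the same metadata object rather than a copy. The group structure and helper name are made up:

def build_per_layer_metadata(kv_cache_groups, build_for_group):
    attn_metadata = {}
    for group in kv_cache_groups:
        metadata_i = build_for_group(group)       # one build() call per group
        for layer_name in group["layer_names"]:   # shared within the group
            attn_metadata[layer_name] = metadata_i
    return attn_metadata

groups = [{"layer_names": ["model.layers.0.self_attn.attn",
                           "model.layers.1.self_attn.attn"]}]
meta = build_per_layer_metadata(groups, build_for_group=lambda g: object())
# Layers of the same group reference the same metadata object.
assert meta["model.layers.0.self_attn.attn"] is meta["model.layers.1.self_attn.attn"]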
vllm/v1/worker/tpu_model_runner.py
@@ -588,7 +588,14 @@ class TPUModelRunner:
         # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
-        return attn_metadata, logits_indices, padded_num_reqs
+
+        layer_names = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention).keys()
+        per_layer_attn_metadata = {
+            layer_name: attn_metadata
+            for layer_name in layer_names
+        }
+        return per_layer_attn_metadata, logits_indices, padded_num_reqs

     def _scatter_placeholders(
         self,
@@ -956,7 +963,14 @@ class TPUModelRunner:
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        layer_names = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention).keys()
+        per_layer_attn_metadata = {
+            layer_name: attn_metadata
+            for layer_name in layer_names
+        }
+
+        with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
                              inputs_embeds=inputs_embeds)
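
For illustration, a minimal sketch of the TPU-side broadcast above: the TPU path still builds a single metadata object, so the per-layer dict simply maps every attention layer name to that same object before entering set_forward_context. The names below are placeholders:

attn_metadata = object()   # stands in for the single TPU metadata object
layer_names = ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]
per_layer_attn_metadata = {name: attn_metadata for name in layer_names}
assert all(m is attn_metadata for m in per_layer_attn_metadata.values())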