mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 22:13:42 +08:00
[V0 Deprecation] Remove unused classes in attention (#25541)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
parent
8c853050e7
commit
e6750d0b18
@ -2,9 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
AttentionMetadata,
|
AttentionMetadata, AttentionType)
|
||||||
AttentionMetadataBuilder,
|
|
||||||
AttentionState, AttentionType)
|
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
|
||||||
@ -13,7 +11,5 @@ __all__ = [
|
|||||||
"AttentionBackend",
|
"AttentionBackend",
|
||||||
"AttentionMetadata",
|
"AttentionMetadata",
|
||||||
"AttentionType",
|
"AttentionType",
|
||||||
"AttentionMetadataBuilder",
|
|
||||||
"AttentionState",
|
|
||||||
"get_attn_backend",
|
"get_attn_backend",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -2,10 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
|
||||||
from dataclasses import dataclass, fields
|
|
||||||
from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
|
|
||||||
Type, TypeVar)
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -49,18 +46,13 @@ class AttentionBackend(ABC):
|
|||||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def get_state_cls() -> Type["AttentionState"]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
|
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
|
||||||
return cls.get_metadata_cls()(*args, **kwargs)
|
return cls.get_metadata_cls()(*args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
|
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -77,149 +69,18 @@ class AttentionBackend(ABC):
|
|||||||
def get_kv_cache_stride_order() -> Tuple[int, ...]:
|
def get_kv_cache_stride_order() -> Tuple[int, ...]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def swap_blocks(
|
|
||||||
src_kv_cache: torch.Tensor,
|
|
||||||
dst_kv_cache: torch.Tensor,
|
|
||||||
src_to_dst: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def copy_blocks(
|
|
||||||
kv_caches: List[torch.Tensor],
|
|
||||||
src_to_dists: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def full_cls_name(cls) -> tuple[str, str]:
|
def full_cls_name(cls) -> tuple[str, str]:
|
||||||
return (cls.__module__, cls.__qualname__)
|
return (cls.__module__, cls.__qualname__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AttentionMetadata:
|
class AttentionMetadata:
|
||||||
"""Attention metadata for prefill and decode batched together."""
|
pass
|
||||||
# Total number of prefill requests.
|
|
||||||
num_prefills: int
|
|
||||||
# Number of prefill tokens.
|
|
||||||
num_prefill_tokens: int
|
|
||||||
# Number of decode tokens. Note that it is equivalent to the number of
|
|
||||||
# decode requests.
|
|
||||||
num_decode_tokens: int
|
|
||||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
|
||||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
|
||||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
|
||||||
# in block 0, and 1st slot in block 1, respectively.
|
|
||||||
slot_mapping: torch.Tensor
|
|
||||||
|
|
||||||
# Enable/disable KV scales calculation. This is so that we can disable the
|
|
||||||
# calculation until after prefill and cuda graph capture.
|
|
||||||
enable_kv_scales_calculation: bool
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
|
|
||||||
"""Return the attention metadata that's required to run prefill
|
|
||||||
attention."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def decode_metadata(self) -> Optional["AttentionMetadata"]:
|
|
||||||
"""Return the attention metadata that's required to run decode
|
|
||||||
attention."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def asdict_zerocopy(self,
|
|
||||||
skip_fields: Optional[Set[str]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Similar to dataclasses.asdict, but avoids deepcopying."""
|
|
||||||
if skip_fields is None:
|
|
||||||
skip_fields = set()
|
|
||||||
# Note that if we add dataclasses as fields, they will need
|
|
||||||
# similar handling.
|
|
||||||
return {
|
|
||||||
field.name: getattr(self, field.name)
|
|
||||||
for field in fields(self) if field.name not in skip_fields
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=AttentionMetadata)
|
T = TypeVar("T", bound=AttentionMetadata)
|
||||||
|
|
||||||
|
|
||||||
class AttentionState(ABC, Generic[T]):
|
|
||||||
"""Holds attention backend-specific objects reused during the
|
|
||||||
lifetime of the model runner."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self, runner: Any):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
@contextmanager
|
|
||||||
def graph_capture(self, max_batch_size: int):
|
|
||||||
"""Context manager used when capturing CUDA graphs."""
|
|
||||||
yield
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
|
|
||||||
"""Clone attention state to save in CUDA graph metadata."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def graph_capture_get_metadata_for_batch(
|
|
||||||
self,
|
|
||||||
batch_size: int,
|
|
||||||
is_encoder_decoder_model: bool = False) -> T:
|
|
||||||
"""Get attention metadata for CUDA graph capture of batch_size."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_graph_input_buffers(
|
|
||||||
self,
|
|
||||||
attn_metadata: T,
|
|
||||||
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
|
|
||||||
"""Get attention-specific input buffers for CUDA graph capture."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def prepare_graph_input_buffers(
|
|
||||||
self,
|
|
||||||
input_buffers: Dict[str, Any],
|
|
||||||
attn_metadata: T,
|
|
||||||
is_encoder_decoder_model: bool = False) -> None:
|
|
||||||
"""In-place modify input buffers dict for CUDA graph replay."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def begin_forward(self, model_input) -> None:
|
|
||||||
"""Prepare state for forward pass."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionMetadataBuilder(ABC, Generic[T]):
|
|
||||||
"""Abstract class for attention metadata builders."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self, input_builder) -> None:
|
|
||||||
"""Create the builder, remember some configuration and parameters."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def prepare(self) -> None:
|
|
||||||
"""Prepare for one batch."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
|
||||||
cuda_graph_pad_size: int, batch_size: int) -> T:
|
|
||||||
"""Build attention metadata with on-device tensors."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionLayer(Protocol):
|
class AttentionLayer(Protocol):
|
||||||
|
|
||||||
_q_scale: torch.Tensor
|
_q_scale: torch.Tensor
|
||||||
|
|||||||
@ -1,559 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Attention backend utils"""
|
"""Attention backend utils"""
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import accumulate
|
from typing import Optional
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
|
|
||||||
AttentionState)
|
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
PAD_SLOT_ID = -1
|
PAD_SLOT_ID = -1
|
||||||
|
|
||||||
# Switch to numpy implementation of compute_slot_mapping
|
|
||||||
# if we have at least this many elements. Could be tuned further.
|
|
||||||
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
|
|
||||||
|
|
||||||
|
|
||||||
def is_block_tables_empty(block_tables: Union[None, Dict]):
|
|
||||||
"""
|
|
||||||
Check if block_tables is None or a dictionary with all None values.
|
|
||||||
"""
|
|
||||||
if block_tables is None:
|
|
||||||
return True
|
|
||||||
return (isinstance(block_tables, dict)
|
|
||||||
and all(value is None for value in block_tables.values()))
|
|
||||||
|
|
||||||
|
|
||||||
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
|
|
||||||
context_len: int, sliding_window: int):
|
|
||||||
"""
|
|
||||||
Compute the start index of slot mapping.
|
|
||||||
"""
|
|
||||||
start_idx = 0
|
|
||||||
if is_prompt and sliding_window is not None:
|
|
||||||
start_idx = max(0, query_len - sliding_window)
|
|
||||||
return start_idx
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_slot_mapping_python(slot_mapping: List[int],
|
|
||||||
block_table: List[int], range_start: int,
|
|
||||||
range_end: int, block_size: int):
|
|
||||||
for i in range(range_start, range_end):
|
|
||||||
block_number = block_table[i // block_size]
|
|
||||||
block_offset = i % block_size
|
|
||||||
slot = block_number * block_size + block_offset
|
|
||||||
slot_mapping.append(slot)
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_slot_mapping_numpy(slot_mapping: List[int],
|
|
||||||
block_table: List[int], range_start: int,
|
|
||||||
range_end: int, block_size: int):
|
|
||||||
block_table_array = np.array(block_table)
|
|
||||||
idx = np.arange(range_start, range_end)
|
|
||||||
block_offset = idx % block_size
|
|
||||||
idx //= block_size
|
|
||||||
seq_slot_mapping_array = block_table_array[idx]
|
|
||||||
seq_slot_mapping_array *= block_size
|
|
||||||
seq_slot_mapping_array += block_offset
|
|
||||||
slot_mapping.extend(seq_slot_mapping_array)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
|
|
||||||
seq_id: int, seq_len: int, context_len: int,
|
|
||||||
start_idx: int, block_size: int,
|
|
||||||
block_tables: Dict[int, List[int]]):
|
|
||||||
"""
|
|
||||||
Compute slot mapping.
|
|
||||||
"""
|
|
||||||
if is_profile_run:
|
|
||||||
# During memory profiling, the block tables are not
|
|
||||||
# initialized yet. In this case, we just use a dummy
|
|
||||||
# slot mapping.
|
|
||||||
# In embeddings, the block tables are {seq_id: None}.
|
|
||||||
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Mask the [0, start_idx) tokens of the prompt with
|
|
||||||
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
|
|
||||||
# sliding_window). For example, if the prompt len is 10,
|
|
||||||
# sliding window is 8, and block size is 4, the first two
|
|
||||||
# tokens are masked and the slot mapping will be
|
|
||||||
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
|
||||||
padding_mask_len = max(0, start_idx - context_len)
|
|
||||||
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
|
|
||||||
|
|
||||||
range_start = max(start_idx, context_len)
|
|
||||||
range_end = seq_len
|
|
||||||
numel = range_end - range_start
|
|
||||||
block_table = block_tables[seq_id]
|
|
||||||
|
|
||||||
# numpy implementation will be faster than python if we have
|
|
||||||
# many elements, otherwise it will be slower.
|
|
||||||
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
|
|
||||||
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
|
|
||||||
range_end, block_size)
|
|
||||||
else:
|
|
||||||
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
|
|
||||||
range_end, block_size)
|
|
||||||
|
|
||||||
|
|
||||||
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
|
|
||||||
|
|
||||||
|
|
||||||
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
|
||||||
|
|
||||||
_metadata_cls: Type[TAttentionMetadata]
|
|
||||||
|
|
||||||
def __init__(self, input_builder):
|
|
||||||
self.input_builder = input_builder
|
|
||||||
self.runner = input_builder.runner
|
|
||||||
|
|
||||||
self.sliding_window = input_builder.sliding_window
|
|
||||||
self.block_size = input_builder.block_size
|
|
||||||
|
|
||||||
def prepare(self):
|
|
||||||
self.slot_mapping: List[int] = []
|
|
||||||
self.prefill_seq_lens: List[int] = []
|
|
||||||
self.context_lens: List[int] = []
|
|
||||||
self.block_tables: List[List[int]] = []
|
|
||||||
self.curr_seq_lens: List[int] = []
|
|
||||||
self.num_prefills = 0
|
|
||||||
self.num_prefill_tokens = 0
|
|
||||||
self.num_decode_tokens = 0
|
|
||||||
|
|
||||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
|
|
||||||
is_prompt = inter_data.is_prompt
|
|
||||||
block_tables = inter_data.block_tables
|
|
||||||
|
|
||||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
|
||||||
curr_sliding_window_block) in zip(
|
|
||||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
|
||||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
|
||||||
inter_data.query_lens, inter_data.context_lens,
|
|
||||||
inter_data.curr_sliding_window_blocks):
|
|
||||||
self.context_lens.append(context_len)
|
|
||||||
if is_prompt:
|
|
||||||
self.num_prefills += 1
|
|
||||||
self.num_prefill_tokens += token_len
|
|
||||||
self.prefill_seq_lens.append(seq_len)
|
|
||||||
else:
|
|
||||||
assert query_len == 1, (
|
|
||||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
|
||||||
seq_len, context_len, query_len))
|
|
||||||
self.num_decode_tokens += query_len
|
|
||||||
self.curr_seq_lens.append(curr_seq_len)
|
|
||||||
|
|
||||||
# Compute block table.
|
|
||||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
|
||||||
# only allowing multiple of block_size chunk size.
|
|
||||||
# NOTE: This only works for oooooooxxx style attention.
|
|
||||||
block_table = []
|
|
||||||
if inter_data.prefix_cache_hit:
|
|
||||||
block_table = block_tables[seq_id]
|
|
||||||
elif ((chunked_prefill_enabled or not is_prompt)
|
|
||||||
and block_tables is not None):
|
|
||||||
if curr_sliding_window_block == 0:
|
|
||||||
block_table = block_tables[seq_id]
|
|
||||||
else:
|
|
||||||
block_table = block_tables[seq_id][
|
|
||||||
-curr_sliding_window_block:]
|
|
||||||
self.block_tables.append(block_table)
|
|
||||||
|
|
||||||
# Compute slot mapping.
|
|
||||||
is_profile_run = is_block_tables_empty(block_tables)
|
|
||||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
|
||||||
context_len,
|
|
||||||
self.sliding_window)
|
|
||||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
|
||||||
seq_len, context_len, start_idx,
|
|
||||||
self.block_size, inter_data.block_tables)
|
|
||||||
|
|
||||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
|
||||||
cuda_graph_pad_size: int, batch_size: int):
|
|
||||||
"""Build attention metadata with on-device tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
|
||||||
query_lens: The query lengths of the input sequences.
|
|
||||||
cuda_graph_pad_size: The padding size for cuda graph.
|
|
||||||
-1 if cuda graph is not used.
|
|
||||||
batch_size: The maybe padded batch size.
|
|
||||||
"""
|
|
||||||
for inter_data in self.input_builder.inter_data_list:
|
|
||||||
self._add_seq_group(inter_data,
|
|
||||||
self.input_builder.chunked_prefill_enabled)
|
|
||||||
|
|
||||||
device = self.runner.device
|
|
||||||
use_captured_graph = cuda_graph_pad_size != -1
|
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
|
||||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
|
||||||
num_decode_tokens = self.num_decode_tokens
|
|
||||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
|
||||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
|
||||||
|
|
||||||
if use_captured_graph:
|
|
||||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
|
||||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
|
||||||
num_decode_tokens = batch_size
|
|
||||||
|
|
||||||
# The shape of graph_block_tables is
|
|
||||||
# [max batch size, max context len // block size].
|
|
||||||
input_block_tables = self.runner.graph_block_tables[:batch_size]
|
|
||||||
for i, block_table in enumerate(self.block_tables):
|
|
||||||
if block_table:
|
|
||||||
input_block_tables[i, :len(block_table)] = block_table
|
|
||||||
block_tables = torch.from_numpy(input_block_tables).to(
|
|
||||||
device, non_blocking=True)
|
|
||||||
else:
|
|
||||||
block_tables = make_tensor_with_pad(
|
|
||||||
self.block_tables,
|
|
||||||
pad=0,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
|
||||||
|
|
||||||
assert device is not None
|
|
||||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
|
||||||
self.runner.pin_memory)
|
|
||||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
|
||||||
device,
|
|
||||||
self.runner.pin_memory)
|
|
||||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
|
||||||
device, self.runner.pin_memory)
|
|
||||||
|
|
||||||
return self._metadata_cls( # type: ignore
|
|
||||||
num_prefills=self.num_prefills,
|
|
||||||
slot_mapping=slot_mapping_tensor,
|
|
||||||
enable_kv_scales_calculation=True,
|
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
|
||||||
max_decode_seq_len=max_decode_seq_len,
|
|
||||||
query_start_loc=query_start_loc_tensor,
|
|
||||||
seq_start_loc=seq_start_loc_tensor,
|
|
||||||
context_lens_tensor=context_lens_tensor,
|
|
||||||
block_tables=block_tables,
|
|
||||||
use_cuda_graph=use_captured_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CommonAttentionState(AttentionState):
|
|
||||||
|
|
||||||
def __init__(self, runner):
|
|
||||||
self.runner = runner
|
|
||||||
self._is_graph_capturing = False
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def graph_capture(self, max_batch_size: int):
|
|
||||||
|
|
||||||
self._is_graph_capturing = True
|
|
||||||
|
|
||||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
|
||||||
PAD_SLOT_ID,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.runner.device)
|
|
||||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.runner.device)
|
|
||||||
self._graph_block_tables = torch.from_numpy(
|
|
||||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
self._is_graph_capturing = False
|
|
||||||
del self._graph_slot_mapping
|
|
||||||
del self._graph_seq_lens
|
|
||||||
del self._graph_block_tables
|
|
||||||
|
|
||||||
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
|
|
||||||
assert self._is_graph_capturing
|
|
||||||
return self.__class__(self.runner)
|
|
||||||
|
|
||||||
def graph_capture_get_metadata_for_batch(
|
|
||||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
|
||||||
assert self._is_graph_capturing
|
|
||||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
|
||||||
num_prefills=0,
|
|
||||||
num_prefill_tokens=0,
|
|
||||||
num_decode_tokens=batch_size,
|
|
||||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
|
||||||
enable_kv_scales_calculation=True,
|
|
||||||
seq_lens=None,
|
|
||||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
|
||||||
max_query_len=1,
|
|
||||||
max_decode_query_len=1,
|
|
||||||
max_prefill_seq_len=0,
|
|
||||||
max_decode_seq_len=self.runner.max_model_len,
|
|
||||||
query_start_loc=None,
|
|
||||||
seq_start_loc=None,
|
|
||||||
context_lens_tensor=None,
|
|
||||||
block_tables=self._graph_block_tables[:batch_size],
|
|
||||||
use_cuda_graph=True,
|
|
||||||
)
|
|
||||||
if is_encoder_decoder_model:
|
|
||||||
# The encoder decoder model works only with XFormers and
|
|
||||||
# Flash Attention backend. Assert the same.
|
|
||||||
assert self.runner.attn_backend.get_name() in \
|
|
||||||
["XFORMERS", "FLASH_ATTN"], \
|
|
||||||
f"Expected attn_backend name to be either 'XFORMERS' or " \
|
|
||||||
f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
|
|
||||||
self._update_captured_metadata_for_enc_dec_model(
|
|
||||||
batch_size=batch_size, attn_metadata=attn_metadata)
|
|
||||||
|
|
||||||
return attn_metadata
|
|
||||||
|
|
||||||
def get_graph_input_buffers(
|
|
||||||
self,
|
|
||||||
attn_metadata,
|
|
||||||
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
|
|
||||||
input_buffers = {
|
|
||||||
"slot_mapping": attn_metadata.slot_mapping,
|
|
||||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
|
||||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
|
||||||
}
|
|
||||||
if is_encoder_decoder_model:
|
|
||||||
# The encoder decoder model works only with XFormers and
|
|
||||||
# Flash Attention backend. Assert the same.
|
|
||||||
assert self.runner.attn_backend.get_name() in \
|
|
||||||
["XFORMERS", "FLASH_ATTN"], \
|
|
||||||
f"Expected attn_backend name to be either 'XFORMERS' or " \
|
|
||||||
f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
|
|
||||||
self._add_additional_input_buffers_for_enc_dec_model(
|
|
||||||
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
|
||||||
return input_buffers
|
|
||||||
|
|
||||||
def prepare_graph_input_buffers(
|
|
||||||
self,
|
|
||||||
input_buffers,
|
|
||||||
attn_metadata,
|
|
||||||
is_encoder_decoder_model: bool = False) -> None:
|
|
||||||
input_buffers["seq_lens_tensor"].copy_(
|
|
||||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
|
||||||
input_buffers["block_tables"].copy_(
|
|
||||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
|
||||||
if is_encoder_decoder_model:
|
|
||||||
# The encoder decoder model works only with XFormers and
|
|
||||||
# Flash Attention backend. Assert the same.
|
|
||||||
assert self.runner.attn_backend.get_name() in\
|
|
||||||
["XFORMERS", "FLASH_ATTN"], \
|
|
||||||
f"Expected attn_backend name to be either 'XFORMERS' or "\
|
|
||||||
f"'FLASH_ATTN', but "\
|
|
||||||
f"got '{self.runner.attn_backend.get_name()}'"
|
|
||||||
self._prepare_input_buffers_for_enc_dec_model(
|
|
||||||
attn_metadata, input_buffers)
|
|
||||||
|
|
||||||
def begin_forward(self, model_input) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
|
|
||||||
attn_metadata):
|
|
||||||
"""
|
|
||||||
Updates the attention metadata parameters for CUDA graph capture in an
|
|
||||||
encoder-decoder model.
|
|
||||||
|
|
||||||
This method modifies attention-related tensors and metadata required
|
|
||||||
for CUDA graph capture in encoder-decoder models. Specifically, it
|
|
||||||
updates the cross-attention and encoder sequence tensors in the
|
|
||||||
AttentionMetadata object.
|
|
||||||
"""
|
|
||||||
# During decode phase the cross_slot_mapping will be empty. Hence set
|
|
||||||
# an empty tensor for CUDA Graph capture.
|
|
||||||
attn_metadata.cross_slot_mapping = torch.tensor(
|
|
||||||
[], dtype=torch.int).cuda()
|
|
||||||
attn_metadata.cross_block_tables = torch.full(
|
|
||||||
(batch_size, self.runner.get_max_block_per_batch()),
|
|
||||||
1,
|
|
||||||
dtype=torch.int).cuda()
|
|
||||||
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
|
|
||||||
1,
|
|
||||||
dtype=torch.int).cuda()
|
|
||||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
|
||||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
|
||||||
attn_metadata.max_encoder_seq_len = self.runner.max_model_len
|
|
||||||
attn_metadata.num_encoder_tokens = 0
|
|
||||||
|
|
||||||
def _add_additional_input_buffers_for_enc_dec_model(
|
|
||||||
self, attn_metadata, input_buffers: Dict[str, Any]):
|
|
||||||
"""
|
|
||||||
Saves additional input buffers specific to the encoder-decoder model
|
|
||||||
from the attention metadata.
|
|
||||||
|
|
||||||
This method extracts and stores encoder-decoder related input buffers
|
|
||||||
from the `attn_metadata` into the `input_buffers` dictionary. The
|
|
||||||
buffers include encoder sequence lengths, cross-slot mappings, and
|
|
||||||
cross-block tables, which are essential for the encoder-decoder model
|
|
||||||
during CUDA graph replay.
|
|
||||||
"""
|
|
||||||
input_buffers["encoder_seq_lens_tensor"] = (
|
|
||||||
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
|
|
||||||
input_buffers["cross_slot_mapping"] = (
|
|
||||||
attn_metadata.decode_metadata.cross_slot_mapping)
|
|
||||||
input_buffers["cross_block_tables"] = (
|
|
||||||
attn_metadata.decode_metadata.cross_block_tables)
|
|
||||||
|
|
||||||
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
|
|
||||||
input_buffers: Dict[str,
|
|
||||||
Any]):
|
|
||||||
"""
|
|
||||||
Populates input buffers with data from the encoder-decoder model's
|
|
||||||
attention metadata.
|
|
||||||
|
|
||||||
This method fills the input buffers with encoder-decoder specific
|
|
||||||
tensors. It copies data from the `attn_metadata` and keyword arguments
|
|
||||||
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
|
|
||||||
The copied data includes attention-related metadata as well as input
|
|
||||||
IDs and positional information for the encoder.
|
|
||||||
"""
|
|
||||||
input_buffers["encoder_seq_lens_tensor"].copy_(
|
|
||||||
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
|
|
||||||
non_blocking=True)
|
|
||||||
input_buffers["cross_slot_mapping"].copy_(
|
|
||||||
attn_metadata.decode_metadata.cross_slot_mapping,
|
|
||||||
non_blocking=True)
|
|
||||||
input_buffers["cross_block_tables"].copy_(
|
|
||||||
attn_metadata.decode_metadata.cross_block_tables,
|
|
||||||
non_blocking=True)
|
|
||||||
|
|
||||||
|
|
||||||
def is_all_encoder_attn_metadata_set(attn_metadata):
|
|
||||||
'''
|
|
||||||
All attention metadata required for encoder attention is set.
|
|
||||||
'''
|
|
||||||
return ((attn_metadata.encoder_seq_lens is not None)
|
|
||||||
and (attn_metadata.encoder_seq_lens_tensor is not None)
|
|
||||||
and (attn_metadata.max_encoder_seq_len is not None))
|
|
||||||
|
|
||||||
|
|
||||||
def is_all_cross_attn_metadata_set(attn_metadata):
|
|
||||||
'''
|
|
||||||
All attention metadata required for enc/dec cross-attention is set.
|
|
||||||
|
|
||||||
Superset of encoder attention required metadata.
|
|
||||||
'''
|
|
||||||
return (attn_metadata.is_all_encoder_attn_metadata_set
|
|
||||||
and (attn_metadata.cross_slot_mapping is not None)
|
|
||||||
and (attn_metadata.cross_block_tables is not None))
|
|
||||||
|
|
||||||
|
|
||||||
def get_seq_len_block_table_args(
|
|
||||||
attn_metadata,
|
|
||||||
is_prompt: bool,
|
|
||||||
attn_type: str,
|
|
||||||
) -> tuple:
|
|
||||||
'''
|
|
||||||
The particular choice of sequence-length- and block-table-related
|
|
||||||
attributes which should be extracted from attn_metadata is dependent
|
|
||||||
on the type of attention operation.
|
|
||||||
|
|
||||||
Decoder attn -> select entirely decoder self-attention-related fields
|
|
||||||
Encoder/decoder cross-attn -> select encoder sequence lengths &
|
|
||||||
cross-attn block-tables fields
|
|
||||||
Encoder attn -> select encoder sequence lengths fields & no block tables
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
|
|
||||||
* attn_metadata: Attention metadata structure associated with attention op
|
|
||||||
* is_prompt: True if prefill, False otherwise
|
|
||||||
* attn_type: encoder attention, decoder self-attention,
|
|
||||||
encoder/decoder cross-attention
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
* Appropriate sequence-lengths tensor
|
|
||||||
* Appropriate max sequence-length scalar
|
|
||||||
* Appropriate block tables (or None)
|
|
||||||
'''
|
|
||||||
|
|
||||||
if attn_type == AttentionType.DECODER:
|
|
||||||
# Decoder self-attention
|
|
||||||
# Choose max_seq_len based on whether we are in prompt_run
|
|
||||||
if is_prompt:
|
|
||||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
|
||||||
else:
|
|
||||||
max_seq_len = attn_metadata.max_decode_seq_len
|
|
||||||
return (attn_metadata.seq_lens_tensor, max_seq_len,
|
|
||||||
attn_metadata.block_tables)
|
|
||||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
|
||||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
|
||||||
# cross-attention utilizes special "cross" block tables
|
|
||||||
return (attn_metadata.encoder_seq_lens_tensor,
|
|
||||||
attn_metadata.max_encoder_seq_len,
|
|
||||||
attn_metadata.cross_block_tables)
|
|
||||||
elif attn_type == AttentionType.ENCODER:
|
|
||||||
# No block tables associated with encoder attention
|
|
||||||
return (attn_metadata.encoder_seq_lens_tensor,
|
|
||||||
attn_metadata.max_encoder_seq_len, None)
|
|
||||||
else:
|
|
||||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Work out query and key/value token counts for one attention type.

    The counts are read from attributes of ``attn_metadata``; which
    attributes are used depends on ``attn_type``.

    Args:
        attn_metadata (AttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.

    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If ``attn_metadata.num_encoder_tokens`` is
            ``None`` for an attention type that needs it (encoder or
            encoder/decoder cross-attention).
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase.
        # The same input serves as both query and key, and there are no
        # decode tokens.
        encoder_tokens = attn_metadata.num_encoder_tokens
        assert encoder_tokens is not None
        return (encoder_tokens, encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention: queries come from the prefill/decode tokens,
        # while the keys/values come from the encoder.
        encoder_tokens = attn_metadata.num_encoder_tokens
        assert encoder_tokens is not None
        return (attn_metadata.num_prefill_tokens, encoder_tokens,
                attn_metadata.num_decode_tokens)

    # Every remaining attention type (AttentionType.DECODER or
    # AttentionType.ENCODER_ONLY): query and key/value both use the
    # prefill token count.
    prefill_tokens = attn_metadata.num_prefill_tokens
    return (prefill_tokens, prefill_tokens,
            attn_metadata.num_decode_tokens)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MLADims:
|
class MLADims:
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType,
|
AttentionMetadata, AttentionType,
|
||||||
is_quantized_kv_cache)
|
is_quantized_kv_cache)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
@ -65,10 +64,6 @@ class TorchSDPABackend(AttentionBackend):
|
|||||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||||
return TorchSDPAMetadata
|
return TorchSDPAMetadata
|
||||||
|
|
||||||
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
    """Return the attention-state class, ``CommonAttentionState``."""
    return CommonAttentionState
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
|
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
|
||||||
return TorchSDPAMetadataBuilderV1
|
return TorchSDPAMetadataBuilderV1
|
||||||
@ -835,16 +830,6 @@ class _PagedAttention:
|
|||||||
blocksparse_head_sliding_step,
|
blocksparse_head_sliding_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
def copy_blocks(
    kv_caches: list[torch.Tensor],
    src_to_dists: torch.Tensor,
    *args,
) -> None:
    """Split each KV cache into key/value parts and call ``ops.copy_blocks``.

    Args:
        kv_caches: One entry per cache; index 0 of each entry is used as
            the key cache and index 1 as the value cache.
        src_to_dists: Forwarded unchanged to ``ops.copy_blocks``
            (presumably a source-to-destination block mapping; confirm
            against the ``ops`` implementation).
        *args: Accepted but not used.
    """
    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_caches.append(kv_cache[0])
        value_caches.append(kv_cache[1])
    ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
|
||||||
|
|
||||||
|
|
||||||
class _IPEXPagedAttention(_PagedAttention):
|
class _IPEXPagedAttention(_PagedAttention):
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import torch
|
|||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionLayer, AttentionType)
|
AttentionLayer, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import cdiv, next_power_of_2
|
from vllm.utils import cdiv, next_power_of_2
|
||||||
@ -97,10 +96,6 @@ class PallasAttentionBackend(AttentionBackend):
|
|||||||
def get_metadata_cls() -> type["PallasMetadata"]:
|
def get_metadata_cls() -> type["PallasMetadata"]:
|
||||||
return PallasMetadata
|
return PallasMetadata
|
||||||
|
|
||||||
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
    """Return the attention-state class, ``CommonAttentionState``."""
    return CommonAttentionState
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
|
|||||||
@ -9,7 +9,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadataBuilder
|
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import (CompilationLevel, VllmConfig,
|
from vllm.config import (CompilationLevel, VllmConfig,
|
||||||
get_layers_from_vllm_config)
|
get_layers_from_vllm_config)
|
||||||
@ -25,7 +24,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
|||||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
||||||
TreeAttentionMetadataBuilder)
|
TreeAttentionMetadataBuilder)
|
||||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
|
CommonAttentionMetadata)
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
@ -184,8 +184,9 @@ class EagleProposer:
|
|||||||
builder = (self._get_attention_metadata_builder()
|
builder = (self._get_attention_metadata_builder()
|
||||||
if self.attn_metadata_builder is None else
|
if self.attn_metadata_builder is None else
|
||||||
self.attn_metadata_builder)
|
self.attn_metadata_builder)
|
||||||
attn_metadata = builder.build_for_drafting(
|
attn_metadata = builder.build_for_drafting( # type: ignore
|
||||||
common_attn_metadata=common_attn_metadata, draft_index=0)
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
draft_index=0)
|
||||||
|
|
||||||
# At this moment, we assume all eagle layers belong to the same KV
|
# At this moment, we assume all eagle layers belong to the same KV
|
||||||
# cache group, thus using the same attention metadata.
|
# cache group, thus using the same attention metadata.
|
||||||
@ -319,7 +320,7 @@ class EagleProposer:
|
|||||||
exceeds_max_model_len, PADDING_SLOT_ID)
|
exceeds_max_model_len, PADDING_SLOT_ID)
|
||||||
|
|
||||||
# Rebuild attention metadata
|
# Rebuild attention metadata
|
||||||
attn_metadata = builder.build_for_drafting(
|
attn_metadata = builder.build_for_drafting( # type: ignore
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
draft_index=token_index + 1)
|
draft_index=token_index + 1)
|
||||||
for layer_name in self.attn_layer_names:
|
for layer_name in self.attn_layer_names:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user