[Hardware][Intel-Gaudi] Support Automatic Prefix Caching on HPU (#17648)
Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
parent e515668edf
commit 843b222723
@@ -57,16 +57,16 @@ class HPUAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)
 
 
 @dataclass
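Note: the backend now forwards block swaps and copies as index tensors rather than Python dicts. A minimal sketch of how such a mapping could be built from the old dict form (the helper name and the [N, 2] layout are illustrative assumptions, not part of this diff):

from typing import Dict

import torch


def build_block_mapping(src_to_dst: Dict[int, int]) -> torch.Tensor:
    # Flatten a {src_block: dst_block} dict into an [N, 2] int64 tensor,
    # the kind of pair list a tensor-based swap/copy kernel can consume.
    pairs = sorted(src_to_dst.items())
    return torch.tensor(pairs, dtype=torch.int64)


print(build_block_mapping({0: 3, 2: 5}))  # tensor([[0, 3], [2, 5]])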
@@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor]
 
 
 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -198,8 +199,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         key_cache = None
         value_cache = None
         if attn_metadata.is_prompt and self.attn_type \
-            is not AttentionType.ENCODER_ONLY \
-            and attn_metadata.block_list is None:
+            is not AttentionType.ENCODER_ONLY:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -229,6 +229,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                 attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                 attn_bias.add_(position_bias)
 
+            block_list = attn_metadata.block_list if attn_metadata \
+                and attn_metadata.block_list is not None else None
+
             out = ops.prompt_attention(
                 impl=self.prefill_impl,
                 query=query.view(query_shape),
@@ -237,23 +240,25 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                 is_causal=True,
                 attn_bias=attn_bias,
                 valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                **self.common_attention_args())
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_groups=attn_metadata.block_groups,
-                **self.common_attention_args())
+                **self.common_attention_args(attn_metadata.block_list,
+                                             key_cache, value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
 
-    def common_attention_args(self):
+    def common_attention_args(self,
+                              block_list=None,
+                              key_cache=None,
+                              value_cache=None):
         fsdpa_op = self.fused_scaled_dot_product_attention.apply \
             if self.fused_scaled_dot_product_attention is not None else None
         return {
@@ -266,6 +271,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
+            'block_list': block_list,
+            'key_cache': key_cache,
+            'value_cache': value_cache,
         }
 
 
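Note: common_attention_args() now optionally threads block_list, key_cache and value_cache into the shared kwargs, so the prompt and decode paths receive them through one dict. A stripped-down sketch of that pattern (names kept, everything HPU-specific omitted; not the real implementation):

def common_attention_args(block_list=None, key_cache=None, value_cache=None):
    # Only the newly added keys are shown; the real dict also carries the
    # HPU ops and cache fetch functions.
    return {
        'block_list': block_list,
        'key_cache': key_cache,
        'value_cache': value_cache,
    }


def prompt_attention(query, **kwargs):
    # Stand-in for ops.prompt_attention: the kernel can now see the prefix
    # blocks and the caches directly.
    return {'query': query, **kwargs}


print(prompt_attention('q', **common_attention_args(block_list=[0, 1])))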
@@ -5,7 +5,7 @@
 ###############################################################################
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from vllm_hpu_extension import cache_ops, ops
@@ -63,43 +63,25 @@ class HPUPagedAttention:
     def forward_decode(**kwargs) -> torch.Tensor:
         return ops.flat_pa(**kwargs)
 
-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        subquery_start_loc: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_query_len: int,
-        alibi_slopes: Optional[torch.Tensor],
-        sliding_window: Optional[int],
-    ) -> torch.Tensor:
-        raise NotImplementedError(
-            "forward_prefix is not implemented for HPUPagedAttention")
-
     @staticmethod
     def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)
 
     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
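Note: with the tuple-typed signatures, each KV cache is an explicit (key_cache, value_cache) pair and the mapping is a tensor of (src, dst) block pairs. A toy stand-in for cache_ops.swap_blocks that shows the intended data flow (block granularity and shapes are illustrative, not the real kernel):

import torch


def swap_blocks_sketch(src_kv_cache, dst_kv_cache, src_to_dsts):
    # Copy whole blocks from the source cache into the destination cache
    # for every (src, dst) pair, for keys and values alike.
    def _swap(src, dst, mapping):
        for s, d in mapping.tolist():
            dst[d].copy_(src[s])

    _swap(src_kv_cache[0], dst_kv_cache[0], src_to_dsts)  # key cache
    _swap(src_kv_cache[1], dst_kv_cache[1], src_to_dsts)  # value cache


src = (torch.arange(8.).view(4, 2), torch.arange(8.).view(4, 2) + 100)
dst = (torch.zeros(4, 2), torch.zeros(4, 2))
swap_blocks_sketch(src, dst, torch.tensor([[0, 3], [2, 1]]))
print(dst[0])  # rows 3 and 1 of the key cache now hold blocks 0 and 2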
@@ -14,7 +14,7 @@ import math
 import os
 import time
 from array import array
-from enum import IntEnum
+from enum import Enum, IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
@@ -75,6 +75,12 @@ LORA_WARMUP_RANK = 8
 DUMMY_TOKEN_ID = -1
 
 
+class PhaseType(Enum):
+    PREFILL = 'prefill'
+    PREFIX_PREFILL = 'prefix_prefill'
+    DECODE = 'decode'
+
+
 def subtuple(obj: object,
              typename: str,
              to_copy: List[str],
@@ -213,20 +219,40 @@ class HpuModelAdapter:
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
-        prefill_metadata = attn_metadata
-        if prefill_metadata is None or self.prefill_use_fusedsdpa:
+        if (attn_metadata is None
+                or (self.prefill_use_fusedsdpa \
+                    and attn_metadata.block_list is None)
+                or not attn_metadata.is_prompt):
             return attn_metadata
 
+        prefill_metadata = attn_metadata
+
         seq_lens_t = prefill_metadata.seq_lens_tensor
+        context_lens_t = prefill_metadata.context_lens_tensor
+        query_lens_t = seq_lens_t - context_lens_t
+
+        block_list = attn_metadata.block_list
+        max_context_len = (block_list.size(-1) //
+                           batch_size if block_list is not None else 0)
+        max_context_len = max_context_len * self.block_size
+        past_mask = torch.arange(0,
+                                 max_context_len,
+                                 dtype=torch.int32,
+                                 device=device)
+        past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
+            context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
+                batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
+
         len_mask = (torch.arange(0, seq_len, device=device,
                                  dtype=torch.int32).view(1, seq_len).ge(
-                                     seq_lens_t.unsqueeze(-1)).view(
+                                     query_lens_t.unsqueeze(-1)).view(
                                          batch_size, 1, 1, seq_len))
         causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                             device=device,
                                             dtype=torch.bool),
                                  diagonal=1)
         mask = causal_mask.logical_or(len_mask)
+        mask = torch.concat((past_mask, mask), dim=-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
         attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
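Note: the reworked bias builds a mask for the cached prefix (past_mask) and concatenates it in front of the causal/length mask of the new tokens, so a prefix-prefill query attends to its cached context plus the causal part of the freshly computed tokens. A self-contained toy version of the same construction (CPU tensors, shapes chosen only for illustration):

import math

import torch


def toy_prefix_prefill_bias(batch_size, seq_len, max_context_len,
                            context_lens_t, query_lens_t,
                            dtype=torch.float32):
    # Cached-prefix mask: context position i is visible only if i < context_len.
    past_mask = torch.arange(0, max_context_len, dtype=torch.int32)
    past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
        context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
            batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))

    # Mask padded query positions and enforce causality on the new tokens.
    len_mask = (torch.arange(0, seq_len, dtype=torch.int32).view(
        1, seq_len).ge(query_lens_t.unsqueeze(-1)).view(
            batch_size, 1, 1, seq_len))
    causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                        dtype=torch.bool), diagonal=1)
    mask = causal_mask.logical_or(len_mask)
    mask = torch.concat((past_mask, mask), dim=-1)
    return torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)


bias = toy_prefix_prefill_bias(batch_size=1, seq_len=4, max_context_len=8,
                               context_lens_t=torch.tensor([6]),
                               query_lens_t=torch.tensor([4]))
print(bias.shape)  # torch.Size([1, 1, 4, 12])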
@@ -517,6 +543,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             False, self.max_model_len)
         self.graphed_buckets: Set[Any] = set()
         self._set_gc_threshold()
+        if self.vllm_config.cache_config.enable_prefix_caching:
+            os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
+            assert os.environ.get(
+                "VLLM_CONTIGUOUS_PA",
+                "").lower() != "true", "Contiguous PA doesn't support APC"
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
 
         # For multi-step scheduling
@@ -702,6 +733,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                     computed_block_nums) > 0 and self.sliding_window is None:
                 # Prefix is not supported with sliding_window
                 context_len = len(computed_block_nums) * self.block_size
+                if context_len == seq_len \
+                and self.vllm_config.cache_config.enable_prefix_caching:
+                    # Fully cached prompt - compute only last token
+                    context_len = context_len - 1
                 prompt_tokens = prompt_tokens[context_len:]
                 prefix_block_tables.append(computed_block_nums)
             elif self.scheduler_config.chunked_prefill_enabled:
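Note: when every block of a prompt is already cached, context_len would equal seq_len and no query token would be left to run, so the last token is deliberately recomputed. A small numeric illustration (values are made up):

block_size = 128
computed_block_nums = [7, 12]   # two fully cached blocks
seq_len = 256                   # prompt length in tokens

context_len = len(computed_block_nums) * block_size   # 256 == seq_len
if context_len == seq_len:
    # Fully cached prompt: recompute only the last token so the model
    # still produces logits for sampling.
    context_len -= 1
print(context_len)  # 255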
@@ -779,12 +814,33 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)
 
-            lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
+            lora_index_mapping += [lora_id] * max_prompt_len
             lora_prompt_mapping.extend(
                 [lora_id] *
-                (max_prompt_len - context_len
+                (max_prompt_len
                  if seq_group_metadata.sampling_params.prompt_logprobs else 1))
 
+        if any(context_lens):
+            assert not self.scheduler_config.chunked_prefill_enabled
+            # prefix caching
+
+            max_num_block = max(len(bt) for bt in prefix_block_tables)
+            prefix_block_list = list(
+                itertools.chain.from_iterable(
+                    bt if len(bt) == max_num_block else bt +
+                    ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
+                    for bt in prefix_block_tables))
+
+            pad_len = len(prefix_block_list)
+            prefix_block_list = pad_list(prefix_block_list, pad_len,
+                                         _PAD_BLOCK_ID)
+
+            prefix_block_list_tensor = torch.tensor(prefix_block_list,
+                                                    dtype=torch.long,
+                                                    device=self.device)
+        else:
+            prefix_block_list_tensor = None
+
         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
                                             pad=0,
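Note: the per-sequence prefix block tables are ragged, so they are padded to the longest table and flattened before being turned into a tensor. A simplified sketch of that step (pad id and values are illustrative; the additional pad_list call from the diff is omitted):

import itertools

_PAD_BLOCK_ID = 0  # assumed pad id, for illustration only


def flatten_prefix_block_tables(prefix_block_tables):
    # Pad every block table to the longest one, then flatten, so the result
    # can later be viewed as (batch_size, max_num_block).
    max_num_block = max(len(bt) for bt in prefix_block_tables)
    return list(
        itertools.chain.from_iterable(
            bt + [_PAD_BLOCK_ID] * (max_num_block - len(bt))
            for bt in prefix_block_tables))


print(flatten_prefix_block_tables([[3, 4, 5], [9]]))  # [3, 4, 5, 9, 0, 0]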
@@ -807,11 +863,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                                            dtype=torch.long,
                                            device=self.device)
 
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.long,
+                                           device=self.device)
+
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
-            block_list=None,
+            block_list=prefix_block_list_tensor,
             block_mapping=None,
             block_usage=None,
             block_indices=block_indices,
@@ -819,6 +879,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
+            context_lens_tensor=context_lens_tensor,
             num_prefills=real_num_seqs,
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
@@ -987,6 +1048,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
+            context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
@@ -1091,7 +1153,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         # FIXME: We need to adjust selected_token_indices to accommodate
         # for padding
         max_len = input_tokens.size(1)
-        paddings = [max_len - s for s in seq_lens]
+        paddings = [max_len - q for q in query_lens]
         paddings = [0] + paddings[:-1]
         paddings = list(itertools.accumulate(paddings))
         paddings_prompt_logprobs = []
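Note: with prefix caching only the uncached suffix of each prompt occupies the padded input, so the sampling-index offsets must be derived from query lengths rather than full sequence lengths. A tiny worked example of the offset computation (numbers are made up):

import itertools

max_len = 8              # padded prompt length of the batch
query_lens = [5, 3]      # uncached tokens actually placed per sequence

paddings = [max_len - q for q in query_lens]   # [3, 5]
paddings = [0] + paddings[:-1]                 # first row needs no offset
paddings = list(itertools.accumulate(paddings))
print(paddings)  # [0, 3] -> shift added to selected_token_indices per row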
@@ -1187,9 +1249,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         # input_hash(123) != input_hash(321)
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_groups'
+            'attn_bias',
+            'seq_lens_tensor',
+            'context_lens_tensor',
+            'block_list',
+            'block_mapping',
+            'block_usage',
+            'slot_mapping',
+            'is_prompt',
+            'block_indices',
+            'block_offsets',
+            'block_groups',
         ])
         return attention_metadata
 
@@ -1733,14 +1803,44 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
             from neural_compressor.torch.quantization import finalize_calibration
             finalize_calibration(self.model.model)
 
-    def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
-        cfg = (batch_size, seq_len, is_prompt)
+    def _num_blocks(self, attn_metadata):
+        if attn_metadata.block_list is None:
+            return 0
+        return attn_metadata.block_list.numel()
+
+    def _phase(self, attn_metadata):
+        phase_type: PhaseType
+        is_prompt = attn_metadata.is_prompt
+        is_prefix_prefill = is_prompt and attn_metadata.block_list is not None
+        if is_prompt and is_prefix_prefill:
+            phase_type = PhaseType.PREFIX_PREFILL
+        elif is_prompt and not is_prefix_prefill:
+            phase_type = PhaseType.PREFILL
+        elif not is_prompt:
+            phase_type = PhaseType.DECODE
+        else:
+            raise ValueError("Unrecognized pass type, likely due to malformed "
+                             "attention metadata")
+        return phase_type
+
+    def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
+        is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        cfg: Optional[tuple] = None
+        assert cfg is None, "Configs changed between 2D and 3D"
+        if is_prefix_caching:
+            phase = self._phase(attn_metadata)
+            num_blocks = self._num_blocks(attn_metadata)
+            cfg = (batch_size, seq_len, num_blocks, phase)
+        else:
+            phase = 'prompt' if attn_metadata.is_prompt else 'decode'
+            cfg = (batch_size, seq_len, phase)
         seen = cfg in self.seen_configs
         self.seen_configs.add(cfg)
         if not seen and not warmup_mode:
-            phase = 'prompt' if is_prompt else 'decode'
-            logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
-                           phase, batch_size, seq_len)
+            logger.warning("Configuration: %s was not warmed-up!",
+                           (phase.value, batch_size, seq_len,
+                            num_blocks) if is_prefix_caching else
+                           (phase, batch_size, seq_len))
 
     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                          is_prompt: bool):
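Note: warm-up bookkeeping now distinguishes three phases, and prefix-caching buckets additionally carry the number of prefix blocks. A minimal sketch of the phase decision with a stand-in metadata object (the enum mirrors the one added above; the stub is not a real HPUAttentionMetadata):

from enum import Enum
from types import SimpleNamespace


class PhaseType(Enum):
    PREFILL = 'prefill'
    PREFIX_PREFILL = 'prefix_prefill'
    DECODE = 'decode'


def classify(attn_metadata) -> PhaseType:
    # Same decision rule as HPUModelRunner._phase: a prompt with a block
    # list is a prefix prefill, a prompt without one is a plain prefill.
    if attn_metadata.is_prompt:
        return (PhaseType.PREFIX_PREFILL
                if attn_metadata.block_list is not None else PhaseType.PREFILL)
    return PhaseType.DECODE


stub = SimpleNamespace(is_prompt=True, block_list=[3, 4])  # stand-in metadata
print(classify(stub))  # PhaseType.PREFIX_PREFILL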
@@ -1912,7 +2012,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
         batch_size = input_tokens.size(0)
         seq_len = self._seq_len(attn_metadata)
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
-        self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+        self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)
 
         lora_mask: torch.Tensor = None
         lora_logits_mask: torch.Tensor = None