[Hardware][Intel-Gaudi] Support Automatic Prefix Caching on HPU (#17648)

Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Authored by Agata Dobrzyniewicz on 2025-05-08 07:37:03 +02:00; committed by GitHub
parent e515668edf
commit 843b222723
3 changed files with 146 additions and 56 deletions
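Usage sketch (not part of the diff): automatic prefix caching is enabled through vLLM's standard engine flag. A minimal illustration, assuming a vLLM build with Gaudi/HPU support; the model name is an arbitrary choice for the example:

import os
from vllm import LLM, SamplingParams

# Mirrors the default set in this commit: contiguous PA is turned off when
# automatic prefix caching is enabled.
os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",   # model chosen only for illustration
          enable_prefix_caching=True)

# Two prompts sharing a long prefix, so the second prefill can reuse cached blocks.
shared = "You are a concise assistant. Answer in one sentence. " * 20
prompts = [shared + "What is paged attention?", shared + "What is prefix caching?"]
outputs = llm.generate(prompts, SamplingParams(max_tokens=32))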


@@ -57,16 +57,16 @@ class HPUAttentionBackend(AttentionBackend):
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dsts: torch.Tensor,
) -> None:
HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dsts: torch.Tensor,
) -> None:
HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)
@dataclass
@@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
is_prompt: bool
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -198,8 +199,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
is not AttentionType.ENCODER_ONLY:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -229,6 +229,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
block_list = attn_metadata.block_list if attn_metadata \
and attn_metadata.block_list is not None else None
out = ops.prompt_attention(
impl=self.prefill_impl,
query=query.view(query_shape),
@@ -237,23 +240,25 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
is_causal=True,
attn_bias=attn_bias,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args())
**self.common_attention_args(block_list, key_cache,
value_cache))
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_groups=attn_metadata.block_groups,
**self.common_attention_args())
**self.common_attention_args(attn_metadata.block_list,
key_cache, value_cache))
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self):
def common_attention_args(self,
block_list=None,
key_cache=None,
value_cache=None):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None
return {
@@ -266,6 +271,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax,
'block_list': block_list,
'key_cache': key_cache,
'value_cache': value_cache,
}


@@ -5,7 +5,7 @@
###############################################################################
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from vllm_hpu_extension import cache_ops, ops
@@ -63,43 +63,25 @@ class HPUPagedAttention:
def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs)
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
raise NotImplementedError(
"forward_prefix is not implemented for HPUPagedAttention")
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
src_to_dsts: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dsts: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)


@@ -14,7 +14,7 @@ import math
import os
import time
from array import array
from enum import IntEnum
from enum import Enum, IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union)
@@ -75,6 +75,12 @@ LORA_WARMUP_RANK = 8
DUMMY_TOKEN_ID = -1
class PhaseType(Enum):
PREFILL = 'prefill'
PREFIX_PREFILL = 'prefix_prefill'
DECODE = 'decode'
def subtuple(obj: object,
typename: str,
to_copy: List[str],
@@ -213,20 +219,40 @@ class HpuModelAdapter:
def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
prefill_metadata = attn_metadata
if prefill_metadata is None or self.prefill_use_fusedsdpa:
if (attn_metadata is None
or (self.prefill_use_fusedsdpa \
and attn_metadata.block_list is None)
or not attn_metadata.is_prompt):
return attn_metadata
prefill_metadata = attn_metadata
seq_lens_t = prefill_metadata.seq_lens_tensor
context_lens_t = prefill_metadata.context_lens_tensor
query_lens_t = seq_lens_t - context_lens_t
block_list = attn_metadata.block_list
max_context_len = (block_list.size(-1) //
batch_size if block_list is not None else 0)
max_context_len = max_context_len * self.block_size
past_mask = torch.arange(0,
max_context_len,
dtype=torch.int32,
device=device)
past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
len_mask = (torch.arange(0, seq_len, device=device,
dtype=torch.int32).view(1, seq_len).ge(
seq_lens_t.unsqueeze(-1)).view(
query_lens_t.unsqueeze(-1)).view(
batch_size, 1, 1, seq_len))
causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
device=device,
dtype=torch.bool),
diagonal=1)
mask = causal_mask.logical_or(len_mask)
mask = torch.concat((past_mask, mask), dim=-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
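Aside on the bias construction above: for a prefix prefill, the mask is a past-context mask over the block-padded cached tokens concatenated with the causal/length mask over the new query tokens. A minimal sketch with toy sizes; all names are local to this snippet:

import math
import torch

# Toy sizes for a single sequence: 6 cached (prefix) tokens stored in two
# blocks of 4, plus 4 new query tokens.
batch_size, seq_len, block_size = 1, 4, 4
context_lens_t = torch.tensor([6])       # cached tokens per sequence
query_lens_t = torch.tensor([4])         # new (query) tokens per sequence
max_context_len = 2 * block_size         # cached length padded to full blocks

# Past-context mask: True (masked) for padded cache slots beyond the real prefix.
past_mask = (torch.arange(max_context_len).view(1, -1)
             .expand(batch_size, -1)
             .ge(context_lens_t.view(-1, 1))
             .view(batch_size, 1, 1, -1)
             .expand(batch_size, 1, seq_len, -1))

# Length mask over the query part, combined with the usual causal mask.
len_mask = (torch.arange(seq_len).view(1, seq_len)
            .ge(query_lens_t.unsqueeze(-1))
            .view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(
    torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.bool), diagonal=1)

mask = torch.concat((past_mask, causal_mask.logical_or(len_mask)), dim=-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)
print(attn_bias.shape)  # torch.Size([1, 1, 4, 12]): 8 cache slots + 4 query positions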
@@ -517,6 +543,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
False, self.max_model_len)
self.graphed_buckets: Set[Any] = set()
self._set_gc_threshold()
if self.vllm_config.cache_config.enable_prefix_caching:
os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
assert os.environ.get(
"VLLM_CONTIGUOUS_PA",
"").lower() != "true", "Contiguous PA doesn't support APC"
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
@@ -702,6 +733,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size
if context_len == seq_len \
and self.vllm_config.cache_config.enable_prefix_caching:
# Fully cached prompt - compute only last token
context_len = context_len - 1
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
@@ -779,12 +814,33 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
lora_index_mapping += [lora_id] * max_prompt_len
lora_prompt_mapping.extend(
[lora_id] *
(max_prompt_len - context_len
(max_prompt_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if any(context_lens):
assert not self.scheduler_config.chunked_prefill_enabled
# prefix caching
max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
itertools.chain.from_iterable(
bt if len(bt) == max_num_block else bt +
([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
for bt in prefix_block_tables))
pad_len = len(prefix_block_list)
prefix_block_list = pad_list(prefix_block_list, pad_len,
_PAD_BLOCK_ID)
prefix_block_list_tensor = torch.tensor(prefix_block_list,
dtype=torch.long,
device=self.device)
else:
prefix_block_list_tensor = None
input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
pad=0,
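A simplified toy illustration of the block-table flattening above: each sequence's prefix block table is padded to the longest table and the result is flattened into one list (PAD_ID stands in for the runner's _PAD_BLOCK_ID; the real code additionally runs pad_list afterwards):

import itertools

PAD_ID = 0                                        # placeholder for _PAD_BLOCK_ID
prefix_block_tables = [[7, 8, 9], [4], [5, 6]]    # ragged per-sequence tables

max_num_block = max(len(bt) for bt in prefix_block_tables)
flat = list(
    itertools.chain.from_iterable(
        bt + [PAD_ID] * (max_num_block - len(bt))
        for bt in prefix_block_tables))
print(flat)  # [7, 8, 9, 4, 0, 0, 5, 6, 0] -> batch_size * max_num_block entries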
@@ -807,11 +863,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
dtype=torch.long,
device=self.device)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.long,
device=self.device)
block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, True)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
block_list=None,
block_list=prefix_block_list_tensor,
block_mapping=None,
block_usage=None,
block_indices=block_indices,
@@ -819,6 +879,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
context_lens_tensor=context_lens_tensor,
num_prefills=real_num_seqs,
num_prefill_tokens=sum_query_len,
num_decode_tokens=0,
@@ -987,6 +1048,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
context_lens_tensor=None,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
@@ -1091,7 +1153,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
# FIXME: We need to adjust selected_token_indices to accommodate
# for padding
max_len = input_tokens.size(1)
paddings = [max_len - s for s in seq_lens]
paddings = [max_len - q for q in query_lens]
paddings = [0] + paddings[:-1]
paddings = list(itertools.accumulate(paddings))
paddings_prompt_logprobs = []
@@ -1187,9 +1249,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
'block_offsets', 'block_groups'
'attn_bias',
'seq_lens_tensor',
'context_lens_tensor',
'block_list',
'block_mapping',
'block_usage',
'slot_mapping',
'is_prompt',
'block_indices',
'block_offsets',
'block_groups',
])
return attention_metadata
@@ -1733,14 +1803,44 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
from neural_compressor.torch.quantization import finalize_calibration
finalize_calibration(self.model.model)
def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
cfg = (batch_size, seq_len, is_prompt)
def _num_blocks(self, attn_metadata):
if attn_metadata.block_list is None:
return 0
return attn_metadata.block_list.numel()
def _phase(self, attn_metadata):
phase_type: PhaseType
is_prompt = attn_metadata.is_prompt
is_prefix_prefill = is_prompt and attn_metadata.block_list is not None
if is_prompt and is_prefix_prefill:
phase_type = PhaseType.PREFIX_PREFILL
elif is_prompt and not is_prefix_prefill:
phase_type = PhaseType.PREFILL
elif not is_prompt:
phase_type = PhaseType.DECODE
else:
raise ValueError("Unrecognized pass type, likely due to malformed "
"attention metadata")
return phase_type
def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
cfg: Optional[tuple] = None
assert cfg is None, "Configs changed between 2D and 3D"
if is_prefix_caching:
phase = self._phase(attn_metadata)
num_blocks = self._num_blocks(attn_metadata)
cfg = (batch_size, seq_len, num_blocks, phase)
else:
phase = 'prompt' if attn_metadata.is_prompt else 'decode'
cfg = (batch_size, seq_len, phase)
seen = cfg in self.seen_configs
self.seen_configs.add(cfg)
if not seen and not warmup_mode:
phase = 'prompt' if is_prompt else 'decode'
logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
phase, batch_size, seq_len)
logger.warning("Configuration: %s was not warmed-up!",
(phase.value, batch_size, seq_len,
num_blocks) if is_prefix_caching else
(phase, batch_size, seq_len))
def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
is_prompt: bool):
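An illustrative sketch of the warm-up bucket key used by _check_config above when prefix caching is on, under the assumption that plain strings stand in for the PhaseType enum and SimpleNamespace for the real attention metadata:

from types import SimpleNamespace
import torch

# Arbitrary example values; only the shape of the config tuple matters here.
meta = SimpleNamespace(is_prompt=True, block_list=torch.tensor([7, 8, 9]))

is_prefix_prefill = meta.is_prompt and meta.block_list is not None
phase = ('prefix_prefill' if is_prefix_prefill
         else 'prefill' if meta.is_prompt else 'decode')   # PhaseType in the real code
num_blocks = 0 if meta.block_list is None else meta.block_list.numel()

cfg_with_apc = (2, 128, num_blocks, phase)   # (batch_size, seq_len, num_blocks, phase)
cfg_without_apc = (2, 128, 'prompt')         # (batch_size, seq_len, phase)
print(cfg_with_apc)                          # (2, 128, 3, 'prefix_prefill')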
@@ -1912,7 +2012,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
batch_size = input_tokens.size(0)
seq_len = self._seq_len(attn_metadata)
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)
lora_mask: torch.Tensor = None
lora_logits_mask: torch.Tensor = None