[Hardware][Gaudi][Feature] Support Contiguous Cache Fetch (#12139)

Signed-off-by: yuzhou <yuzhou@habana.ai>
Signed-off-by: zhouyu5 <yu.zhou@intel.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Yu-Zhou 2025-02-19 11:40:19 +08:00 committed by GitHub
parent 00b69c2d27
commit d0a7a2769d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 80 additions and 47 deletions

View File

@ -118,12 +118,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.matmul_av = Matmul()
self.batch2block_matmul = Matmul()
self.block2batch_matmul = Matmul()
# NOTE(kzawora): Contiguous PA is off until model runner supports it
self.k_cache = VLLMKVCache()
self.k_cache.use_contiguous_pa = False
self.v_cache = VLLMKVCache()
self.v_cache.use_contiguous_pa = False
# NOTE(kzawora): Pipelined PA is off until model runner supports it
ops.pa_impl = ops.pa
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
@ -249,7 +245,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=None,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,

View File

@ -23,6 +23,7 @@ class HPUPagedAttentionMetadata:
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
class HPUPagedAttention:

View File

@ -89,6 +89,7 @@ if TYPE_CHECKING:
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
def get_default_cache_root():
@ -585,6 +586,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# specify the path through environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH":
lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
# Contiguous cache fetching to avoid using costly gather operation on
# Gaudi3. This is only applicable to HPU contiguous cache. If set to true,
# contiguous cache fetch will be used.
"VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
("1", "true"),
}
# end-env-vars-definition

View File

@ -25,9 +25,11 @@ import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import DeviceConfig, VllmConfig
from vllm.distributed.parallel_state import get_world_group
@ -260,10 +262,19 @@ def setup_profiler():
return profiler
def pad_list(list, k, v):
target_len = round_up(len(list), k)
padding = target_len - len(list)
return list + [v] * padding
def pad_list(input, k, v):
input_len = len(input)
target_len = round_up(input_len, k)
padding = target_len - input_len
return input + [v] * padding
def gather_list(input, indices, v):
return [input[i] if i is not None else v for i in indices]
def flatten(in_list):
return list(itertools.chain(*in_list))
def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
@ -334,13 +345,23 @@ class HpuModelAdapter:
mask = mask >= metadata.block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
block_mapping = block_mapping.to(dtype)
metadata = metadata._replace(block_mapping=block_mapping,
attn_bias=attn_bias)
return metadata
def _set_block_scales(self, metadata, device):
block_mapping = metadata.block_mapping
ones = torch.ones((block_mapping.size(0), ),
device=device,
dtype=block_mapping.dtype)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
metadata = metadata._replace(block_scales=block_scales)
return metadata
def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
dtype):
if attn_metadata.is_prompt:
@ -351,6 +372,7 @@ class HpuModelAdapter:
meta = attn_metadata
attn_metadata = self._set_block_mapping(meta, batch_size, device,
dtype)
attn_metadata = self._set_block_scales(attn_metadata, device)
return attn_metadata
def forward(self, *args, **kwargs):
@ -586,6 +608,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.bucketing_global_state = HPUBucketingGlobalState()
self._setup_buckets()
self._set_gc_threshold()
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@ -911,6 +934,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
num_prefills=real_num_seqs,
@ -1008,65 +1032,69 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_decode_tokens = sum(seq_lens)
blocks_used = [len(bt) for bt in block_tables if bt]
block_list = []
block_scales = []
for i, bt in enumerate(block_tables):
block_list.extend(bt)
blocks_in_group = len(bt)
if blocks_in_group > 0:
scale = 1.0 / blocks_in_group
block_scales.extend([scale] * blocks_in_group)
block_mapping_nested: List[List[int]] = [
[i] * b_u for i, b_u in enumerate(blocks_used)
last_block_usage = [
slot[0] % self.block_size + 1 for slot in slot_mapping
]
block_mapping: List[int] = list(
itertools.chain.from_iterable(block_mapping_nested))
block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
for bt, lbu in zip(block_tables, last_block_usage)
if bt]
last_block = [
sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
]
block_usage = [[self.block_size] * (b_u - 1) + [lb]
for b_u, lb in zip(blocks_used, last_block)]
block_usage = list(itertools.chain(*block_usage))
block_list = flatten(block_tables)
block_groups = flatten(block_groups)
block_usage = flatten(block_usage)
block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
block_mapping = pad_list(block_mapping, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_scales = pad_list(block_scales, block_bucket_size, 0.0)
assert len(block_list) == len(block_groups)
assert len(block_list) == len(block_usage)
padding_fn = None
if self.use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
block_bucket_size = find_bucket(
block_bucket_size,
self.bucketing_global_state.decode_block_bucket_cfg)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
padding_fn = lambda tensor, pad_value: gather_list(
tensor, indices, pad_value)
else:
block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
padding_fn = lambda tensor, pad_value: pad_list(
tensor, block_bucket_size, pad_value)
block_list = padding_fn(block_list, _PAD_BLOCK_ID)
block_groups = padding_fn(block_groups, -1)
block_usage = padding_fn(block_usage, 1)
block_list = torch.tensor(block_list,
dtype=torch.int,
device=self.device)
block_mapping = torch.tensor(block_mapping,
dtype=torch.long,
device=self.device)
block_groups = torch.tensor(block_groups,
dtype=torch.int,
device=self.device)
block_usage = torch.tensor(block_usage,
dtype=self.model_config.dtype,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, False)
block_scales = torch.tensor(block_scales,
dtype=self.model_config.dtype,
device=self.device)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
block_list=block_list,
block_mapping=block_mapping,
block_mapping=None,
block_usage=block_usage,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=block_scales,
block_scales=None,
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
num_prefills=0,
@ -1280,7 +1308,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
'block_offsets', 'block_scales'
'block_offsets', 'block_scales', 'block_groups'
])
return attention_metadata