[Hardware][Gaudi][Feature] Support Contiguous Cache Fetch (#12139)
Signed-off-by: yuzhou <yuzhou@habana.ai>
Signed-off-by: zhouyu5 <yu.zhou@intel.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
parent 00b69c2d27
commit d0a7a2769d
@@ -118,12 +118,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        # NOTE(kzawora): Contiguous PA is off until model runner supports it
         self.k_cache = VLLMKVCache()
-        self.k_cache.use_contiguous_pa = False
         self.v_cache = VLLMKVCache()
-        self.v_cache.use_contiguous_pa = False
-        # NOTE(kzawora): Pipelined PA is off until model runner supports it
         ops.pa_impl = ops.pa

         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
@@ -249,7 +245,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_scales=attn_metadata.block_scales,
-                block_groups=None,
+                block_groups=attn_metadata.block_groups,
                 scale=self.scale,
                 matmul_qk_op=self.matmul_qk,
                 matmul_av_op=self.matmul_av,
@@ -23,6 +23,7 @@ class HPUPagedAttentionMetadata:
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
+    block_groups: Optional[torch.Tensor]


 class HPUPagedAttention:
@@ -89,6 +89,7 @@ if TYPE_CHECKING:
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
+    VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True


 def get_default_cache_root():
@@ -585,6 +586,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # specify the path through environment variable VLLM_CUDART_SO_PATH.
     "VLLM_CUDART_SO_PATH":
     lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
+
+    # Contiguous cache fetching to avoid using costly gather operation on
+    # Gaudi3. This is only applicable to HPU contiguous cache. If set to true,
+    # contiguous cache fetch will be used.
+    "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
+    lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
+    ("1", "true"),
 }

 # end-env-vars-definition
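The flag defined above is exposed as envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH, but the getter reads the VLLM_CONTIGUOUS_PA environment variable and defaults to enabled. A minimal sketch of toggling it from a launcher script (illustrative only; vllm.envs resolves attributes lazily through these getters, so the variable just needs to be set before the flag is first read):

import os

# Opt out of contiguous cache fetch. Note that the getter above reads
# VLLM_CONTIGUOUS_PA, not the attribute name itself.
os.environ["VLLM_CONTIGUOUS_PA"] = "false"

import vllm.envs as envs

assert envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH is False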
@@ -25,9 +25,11 @@ import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 import torch.nn as nn
 from vllm_hpu_extension.ops import LoraMask as LoraMask
+from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler, format_bytes)

+import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group
@@ -260,10 +262,19 @@ def setup_profiler():
     return profiler


-def pad_list(list, k, v):
-    target_len = round_up(len(list), k)
-    padding = target_len - len(list)
-    return list + [v] * padding
+def pad_list(input, k, v):
+    input_len = len(input)
+    target_len = round_up(input_len, k)
+    padding = target_len - input_len
+    return input + [v] * padding
+
+
+def gather_list(input, indices, v):
+    return [input[i] if i is not None else v for i in indices]
+
+
+def flatten(in_list):
+    return list(itertools.chain(*in_list))


 def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
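For reference, the renamed pad_list and the new gather_list/flatten helpers have simple list semantics; a small self-contained illustration follows (round_up is assumed to round its first argument up to the next multiple of k, which matches how it is used here):

import itertools


def round_up(n, k):
    # Assumed semantics: smallest multiple of k that is >= n.
    return ((n + k - 1) // k) * k


def pad_list(input, k, v):
    input_len = len(input)
    target_len = round_up(input_len, k)
    padding = target_len - input_len
    return input + [v] * padding


def gather_list(input, indices, v):
    return [input[i] if i is not None else v for i in indices]


def flatten(in_list):
    return list(itertools.chain(*in_list))


# pad_list pads up to the next bucket boundary with a fill value:
assert pad_list([7, 8, 9], 4, -1) == [7, 8, 9, -1]
# gather_list scatters input elements into the slots named by `indices`,
# filling slots whose index is None with the fill value:
assert gather_list([10, 20, 30], [2, None, 0, 1], 0) == [30, 0, 10, 20]
# flatten concatenates a list of lists:
assert flatten([[1, 2], [3]]) == [1, 2, 3]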
@@ -334,13 +345,23 @@ class HpuModelAdapter:
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
+        block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                     num_classes=batch_size)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
         return metadata

+    def _set_block_scales(self, metadata, device):
+        block_mapping = metadata.block_mapping
+        ones = torch.ones((block_mapping.size(0), ),
+                          device=device,
+                          dtype=block_mapping.dtype)
+        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+        block_scales = torch.reciprocal(torch.maximum(ones, sums))
+        metadata = metadata._replace(block_scales=block_scales)
+        return metadata
+
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:
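The new _set_block_scales method moves the per-block scale computation (previously the Python-side 1.0 / blocks_in_group loop in the decode path) onto the device: pushing a vector of ones through block2batch and back through batch2block gives every block the size of its group, and the reciprocal of that count is its scale. A rough sketch with dense stand-ins for the vllm_hpu_extension ops (the matrix-product forms below are only an assumption consistent with a one-hot block_mapping, not the extension's actual implementation):

import torch


def block2batch(tensor, block_mapping):
    # [num_blocks] -> [batch_size]: accumulate per-block values into their batch slot.
    return block_mapping.t() @ tensor


def batch2block(tensor, block_mapping):
    # [batch_size] -> [num_blocks]: broadcast per-batch values back to the blocks.
    return block_mapping @ tensor


block_groups = torch.tensor([0, 0, 0, 1])  # three blocks for seq 0, one for seq 1
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=2).float()

ones = torch.ones(block_mapping.size(0))
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
print(block_scales)  # tensor([0.3333, 0.3333, 0.3333, 1.0000])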
@@ -351,6 +372,7 @@ class HpuModelAdapter:
             meta = attn_metadata
             attn_metadata = self._set_block_mapping(meta, batch_size, device,
                                                     dtype)
+        attn_metadata = self._set_block_scales(attn_metadata, device)
         return attn_metadata

     def forward(self, *args, **kwargs):
@@ -586,6 +608,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH

     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -911,6 +934,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             block_indices=block_indices,
             block_offsets=block_offsets,
             block_scales=None,
+            block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,
@@ -1008,65 +1032,69 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):

         num_decode_tokens = sum(seq_lens)

-        blocks_used = [len(bt) for bt in block_tables if bt]
-        block_list = []
-        block_scales = []
-        for i, bt in enumerate(block_tables):
-            block_list.extend(bt)
-            blocks_in_group = len(bt)
-            if blocks_in_group > 0:
-                scale = 1.0 / blocks_in_group
-                block_scales.extend([scale] * blocks_in_group)
-
-        block_mapping_nested: List[List[int]] = [
-            [i] * b_u for i, b_u in enumerate(blocks_used)
+        last_block_usage = [
+            slot[0] % self.block_size + 1 for slot in slot_mapping
         ]
-        block_mapping: List[int] = list(
-            itertools.chain.from_iterable(block_mapping_nested))
+        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+        block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
+                       for bt, lbu in zip(block_tables, last_block_usage)
+                       if bt]

-        last_block = [
-            sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
-        ]
-        block_usage = [[self.block_size] * (b_u - 1) + [lb]
-                       for b_u, lb in zip(blocks_used, last_block)]
-        block_usage = list(itertools.chain(*block_usage))
+        block_list = flatten(block_tables)
+        block_groups = flatten(block_groups)
+        block_usage = flatten(block_usage)

-        block_bucket_size = find_bucket(
-            len(block_list),
-            self.bucketing_global_state.decode_block_bucket_cfg)
-        block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
-        block_mapping = pad_list(block_mapping, block_bucket_size, -1)
-        block_usage = pad_list(block_usage, block_bucket_size, 1)
-        block_scales = pad_list(block_scales, block_bucket_size, 0.0)
+        assert len(block_list) == len(block_groups)
+        assert len(block_list) == len(block_usage)
+        padding_fn = None
+        if self.use_contiguous_pa:
+            block_bucket_size = max(max(block_list) + 1, len(block_list))
+            block_bucket_size = find_bucket(
+                block_bucket_size,
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            indices: List[Any]
+            indices = [None] * block_bucket_size
+            for i, bid in enumerate(block_list):
+                indices[bid] = i
+            padding_fn = lambda tensor, pad_value: gather_list(
+                tensor, indices, pad_value)
+        else:
+            block_bucket_size = find_bucket(
+                len(block_list),
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            padding_fn = lambda tensor, pad_value: pad_list(
+                tensor, block_bucket_size, pad_value)
+
+        block_list = padding_fn(block_list, _PAD_BLOCK_ID)
+        block_groups = padding_fn(block_groups, -1)
+        block_usage = padding_fn(block_usage, 1)

         block_list = torch.tensor(block_list,
                                   dtype=torch.int,
                                   device=self.device)
-        block_mapping = torch.tensor(block_mapping,
-                                     dtype=torch.long,
-                                     device=self.device)
+        block_groups = torch.tensor(block_groups,
+                                    dtype=torch.int,
+                                    device=self.device)
         block_usage = torch.tensor(block_usage,
                                    dtype=self.model_config.dtype,
                                    device=self.device)

         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)

         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, False)
-        block_scales = torch.tensor(block_scales,
-                                    dtype=self.model_config.dtype,
-                                    device=self.device)

         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
-            block_mapping=block_mapping,
+            block_mapping=None,
             block_usage=block_usage,
             block_indices=block_indices,
             block_offsets=block_offsets,
-            block_scales=block_scales,
+            block_scales=None,
+            block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,
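What the contiguous branch changes is the layout of the padded decode metadata: the bucket is sized to cover the highest physical block id in use, and each block's entry is scattered to the slot equal to its block id, which (per the VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH comment) lets Gaudi fetch a contiguous range of the KV cache instead of issuing a costly gather. A small illustration of that indexing (find_bucket is skipped and _PAD_BLOCK_ID is a placeholder here; in the real code the bucket size is additionally rounded up via decode_block_bucket_cfg):

def gather_list(input, indices, v):
    return [input[i] if i is not None else v for i in indices]


_PAD_BLOCK_ID = 0  # placeholder pad id for this illustration

block_list = [7, 3, 12]   # physical KV-cache blocks used by the running sequences
block_groups = [0, 0, 1]  # which sequence each block belongs to

# Contiguous path: the bucket must span block ids 0..max(block_list).
block_bucket_size = max(max(block_list) + 1, len(block_list))  # 13

indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
    indices[bid] = i  # destination slot == physical block id

padded_list = gather_list(block_list, indices, _PAD_BLOCK_ID)
padded_groups = gather_list(block_groups, indices, -1)

# Each block's metadata now sits at the row equal to its block id.
assert padded_list[3] == 3 and padded_list[7] == 7 and padded_list[12] == 12
assert padded_groups[3] == 0 and padded_groups[7] == 0 and padded_groups[12] == 1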
@@ -1280,7 +1308,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
             'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_scales'
+            'block_offsets', 'block_scales', 'block_groups'
         ])
         return attention_metadata
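The subtuple helper used here appears to copy only the listed attribute names into a lightweight trimmed view of the metadata, which is why 'block_groups' has to be added to the list for the new field to survive trimming. A minimal sketch of that kind of attribute-subset trimming (not the actual subtuple implementation, whose signature may differ):

from collections import namedtuple


def trim_to_fields(obj, typename, fields):
    # Copy only the named attributes of `obj` into a lightweight namedtuple.
    trimmed_cls = namedtuple(typename, fields)
    return trimmed_cls(**{name: getattr(obj, name) for name in fields})


# Hypothetical usage mirroring the call above:
# trimmed = trim_to_fields(metadata, 'TrimmedAttentionMetadata',
#                          ['block_list', 'block_usage', 'block_groups'])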