[Hardware][Gaudi][Feature] Support Contiguous Cache Fetch (#12139)
Signed-off-by: yuzhou <yuzhou@habana.ai>
Signed-off-by: zhouyu5 <yu.zhou@intel.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
parent 00b69c2d27
commit d0a7a2769d
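The hunks below wire a contiguous-fetch mode into the HPU paged-attention metadata. As an illustrative sketch (not code from this commit), the idea is that decode metadata is laid out so that the entry for block id b sits at index b, which lets the kernel read one contiguous range of the KV cache instead of gathering scattered block rows:

import torch

# Illustrative only: toy KV cache with 10 blocks of 4 values each.
kv_cache = torch.arange(10 * 4, dtype=torch.float32).reshape(10, 4)
block_list = torch.tensor([5, 6, 9, 2])

# Gather-based fetch: pick exactly the referenced blocks (costly on Gaudi).
gathered = kv_cache.index_select(0, block_list)

# Contiguous fetch: read one slice covering all referenced block ids and let
# the (padded) block metadata mask out the unused rows.
contiguous = kv_cache[:int(block_list.max()) + 1]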
@@ -118,12 +118,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        # NOTE(kzawora): Contiguous PA is off until model runner supports it
         self.k_cache = VLLMKVCache()
-        self.k_cache.use_contiguous_pa = False
         self.v_cache = VLLMKVCache()
-        self.v_cache.use_contiguous_pa = False
-        # NOTE(kzawora): Pipelined PA is off until model runner supports it
         ops.pa_impl = ops.pa

         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

@@ -249,7 +245,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
             block_mapping=attn_metadata.block_mapping,
             block_bias=attn_metadata.attn_bias,
             block_scales=attn_metadata.block_scales,
-            block_groups=None,
+            block_groups=attn_metadata.block_groups,
             scale=self.scale,
             matmul_qk_op=self.matmul_qk,
             matmul_av_op=self.matmul_av,

@@ -23,6 +23,7 @@ class HPUPagedAttentionMetadata:
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
+    block_groups: Optional[torch.Tensor]


 class HPUPagedAttention:

@@ -89,6 +89,7 @@ if TYPE_CHECKING:
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
+    VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True


 def get_default_cache_root():

@@ -585,6 +586,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # specify the path through environment variable VLLM_CUDART_SO_PATH.
     "VLLM_CUDART_SO_PATH":
     lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
+
+    # Contiguous cache fetching to avoid using costly gather operation on
+    # Gaudi3. This is only applicable to HPU contiguous cache. If set to true,
+    # contiguous cache fetch will be used.
+    "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
+    lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
+    ("1", "true"),
 }

 # end-env-vars-definition

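Note the mismatch that the diff introduces deliberately: the key surfaced through vllm.envs is VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH, but the value is read from the VLLM_CONTIGUOUS_PA environment variable. A minimal sketch of the accepted values (the helper name here is illustrative, not part of the commit):

import os

def contiguous_cache_fetch_enabled() -> bool:
    # Same parsing as the entry above: only "1" or "true" (case-insensitive)
    # enable the feature; it defaults to enabled when the variable is unset.
    return os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in ("1", "true")

os.environ["VLLM_CONTIGUOUS_PA"] = "0"
assert not contiguous_cache_fetch_enabled()  # "0" is not an accepted truthy value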
@@ -25,9 +25,11 @@ import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 import torch.nn as nn
 from vllm_hpu_extension.ops import LoraMask as LoraMask
+from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler, format_bytes)

+import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group

@@ -260,10 +262,19 @@ def setup_profiler():
     return profiler


-def pad_list(list, k, v):
-    target_len = round_up(len(list), k)
-    padding = target_len - len(list)
-    return list + [v] * padding
+def pad_list(input, k, v):
+    input_len = len(input)
+    target_len = round_up(input_len, k)
+    padding = target_len - input_len
+    return input + [v] * padding
+
+
+def gather_list(input, indices, v):
+    return [input[i] if i is not None else v for i in indices]
+
+
+def flatten(in_list):
+    return list(itertools.chain(*in_list))


 def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):

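A quick, self-contained illustration of the three helpers introduced above (the round_up shim mirrors vLLM's utility of the same name; the example values are made up):

import itertools

def round_up(x, k):
    # Assumed to behave like vLLM's round_up: smallest multiple of k >= x.
    return (x + k - 1) // k * k

def pad_list(input, k, v):
    input_len = len(input)
    target_len = round_up(input_len, k)
    padding = target_len - input_len
    return input + [v] * padding

def gather_list(input, indices, v):
    return [input[i] if i is not None else v for i in indices]

def flatten(in_list):
    return list(itertools.chain(*in_list))

block_tables = [[3, 4], [7]]
print(flatten(block_tables))                     # [3, 4, 7]
print(pad_list([3, 4, 7], 4, -1))                # [3, 4, 7, -1]
print(gather_list([3, 4, 7], [None, 0, 2], -1))  # [-1, 3, 7]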
@@ -334,13 +345,23 @@ class HpuModelAdapter:
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
+        block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                     num_classes=batch_size)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
         return metadata

+    def _set_block_scales(self, metadata, device):
+        block_mapping = metadata.block_mapping
+        ones = torch.ones((block_mapping.size(0), ),
+                          device=device,
+                          dtype=block_mapping.dtype)
+        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+        block_scales = torch.reciprocal(torch.maximum(ones, sums))
+        metadata = metadata._replace(block_scales=block_scales)
+        return metadata
+
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:

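For intuition, here is a hedged sketch of what _set_block_scales produces, with plain matmuls standing in for vllm_hpu_extension's batch2block/block2batch (their internals are assumed here to be one-hot-mapping matmuls): every block ends up with the reciprocal of the number of blocks in its sequence, matching the 1.0 / blocks_in_group scales the model runner used to compute by hand (see the removal in the large hunk below).

import torch

# block i belongs to sequence block_groups[i]: two blocks for seq 0, one for seq 1.
block_groups = torch.tensor([0, 0, 1])
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=2).float()

ones = torch.ones(block_mapping.size(0))
blocks_per_seq = ones @ block_mapping          # tensor([2., 1.])
sums = block_mapping @ blocks_per_seq          # per block: tensor([2., 2., 1.])
block_scales = torch.reciprocal(torch.maximum(ones, sums))
print(block_scales)                            # tensor([0.5000, 0.5000, 1.0000])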
@@ -351,6 +372,7 @@ class HpuModelAdapter:
             meta = attn_metadata
             attn_metadata = self._set_block_mapping(meta, batch_size, device,
                                                     dtype)
+            attn_metadata = self._set_block_scales(attn_metadata, device)
         return attn_metadata

     def forward(self, *args, **kwargs):

@@ -586,6 +608,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH

     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold

@@ -911,6 +934,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             block_indices=block_indices,
             block_offsets=block_offsets,
             block_scales=None,
+            block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,

@@ -1008,65 +1032,69 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):

         num_decode_tokens = sum(seq_lens)

-        blocks_used = [len(bt) for bt in block_tables if bt]
-        block_list = []
-        block_scales = []
-        for i, bt in enumerate(block_tables):
-            block_list.extend(bt)
-            blocks_in_group = len(bt)
-            if blocks_in_group > 0:
-                scale = 1.0 / blocks_in_group
-                block_scales.extend([scale] * blocks_in_group)
-
-        block_mapping_nested: List[List[int]] = [
-            [i] * b_u for i, b_u in enumerate(blocks_used)
-        ]
-        block_mapping: List[int] = list(
-            itertools.chain.from_iterable(block_mapping_nested))
-
-        last_block = [
-            sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
-        ]
-        block_usage = [[self.block_size] * (b_u - 1) + [lb]
-                       for b_u, lb in zip(blocks_used, last_block)]
-        block_usage = list(itertools.chain(*block_usage))
-
-        block_bucket_size = find_bucket(
-            len(block_list),
-            self.bucketing_global_state.decode_block_bucket_cfg)
-        block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
-        block_mapping = pad_list(block_mapping, block_bucket_size, -1)
-        block_usage = pad_list(block_usage, block_bucket_size, 1)
-        block_scales = pad_list(block_scales, block_bucket_size, 0.0)
+        last_block_usage = [
+            slot[0] % self.block_size + 1 for slot in slot_mapping
+        ]
+        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+        block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
+                       for bt, lbu in zip(block_tables, last_block_usage)
+                       if bt]
+
+        block_list = flatten(block_tables)
+        block_groups = flatten(block_groups)
+        block_usage = flatten(block_usage)
+
+        assert len(block_list) == len(block_groups)
+        assert len(block_list) == len(block_usage)
+
+        padding_fn = None
+        if self.use_contiguous_pa:
+            block_bucket_size = max(max(block_list) + 1, len(block_list))
+            block_bucket_size = find_bucket(
+                block_bucket_size,
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            indices: List[Any]
+            indices = [None] * block_bucket_size
+            for i, bid in enumerate(block_list):
+                indices[bid] = i
+            padding_fn = lambda tensor, pad_value: gather_list(
+                tensor, indices, pad_value)
+        else:
+            block_bucket_size = find_bucket(
+                len(block_list),
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            padding_fn = lambda tensor, pad_value: pad_list(
+                tensor, block_bucket_size, pad_value)
+        block_list = padding_fn(block_list, _PAD_BLOCK_ID)
+        block_groups = padding_fn(block_groups, -1)
+        block_usage = padding_fn(block_usage, 1)

         block_list = torch.tensor(block_list,
                                   dtype=torch.int,
                                   device=self.device)
-        block_mapping = torch.tensor(block_mapping,
-                                     dtype=torch.long,
-                                     device=self.device)
+        block_groups = torch.tensor(block_groups,
+                                    dtype=torch.int,
+                                    device=self.device)
         block_usage = torch.tensor(block_usage,
                                    dtype=self.model_config.dtype,
                                    device=self.device)

         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)

         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, False)
-        block_scales = torch.tensor(block_scales,
-                                    dtype=self.model_config.dtype,
-                                    device=self.device)

         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
-            block_mapping=block_mapping,
+            block_mapping=None,
             block_usage=block_usage,
             block_indices=block_indices,
             block_offsets=block_offsets,
-            block_scales=block_scales,
+            block_scales=None,
+            block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,

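A worked example (made-up values; the find_bucket rounding is skipped) of the contiguous-PA branch above: the bucket is sized to address every block id directly, and each metadata list is scattered so its entry for block id b lands at position b, with gaps filled by padding.

_PAD_BLOCK_ID = -1   # placeholder for this sketch; the real constant lives in the model runner

block_tables = [[5, 6, 9], [2]]                                       # two decoding sequences
block_list = [b for bt in block_tables for b in bt]                   # [5, 6, 9, 2]
block_groups = [i for i, bt in enumerate(block_tables) for _ in bt]   # [0, 0, 0, 1]

block_bucket_size = max(max(block_list) + 1, len(block_list))         # 10
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
    indices[bid] = i

gather = lambda lst, pad: [lst[i] if i is not None else pad for i in indices]
print(gather(block_list, _PAD_BLOCK_ID))   # [-1, -1, 2, -1, -1, 5, 6, -1, -1, 9]
print(gather(block_groups, -1))            # [-1, -1, 1, -1, -1, 0, 0, -1, -1, 0]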
@@ -1280,7 +1308,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
             'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_scales'
+            'block_offsets', 'block_scales', 'block_groups'
         ])
         return attention_metadata
