diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 1ad5e6e8e4e17..9eb533685dbd2 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -118,12 +118,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        # NOTE(kzawora): Contiguous PA is off until model runner supports it
         self.k_cache = VLLMKVCache()
-        self.k_cache.use_contiguous_pa = False
         self.v_cache = VLLMKVCache()
-        self.v_cache.use_contiguous_pa = False
-        # NOTE(kzawora): Pipelined PA is off until model runner supports it
         ops.pa_impl = ops.pa
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
@@ -249,7 +245,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_scales=attn_metadata.block_scales,
-                block_groups=None,
+                block_groups=attn_metadata.block_groups,
                 scale=self.scale,
                 matmul_qk_op=self.matmul_qk,
                 matmul_av_op=self.matmul_av,
diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py
index 8bb536343ed8c..49ea420d092cc 100644
--- a/vllm/attention/ops/hpu_paged_attn.py
+++ b/vllm/attention/ops/hpu_paged_attn.py
@@ -23,6 +23,7 @@ class HPUPagedAttentionMetadata:
     block_indices: Optional[torch.Tensor]
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
+    block_groups: Optional[torch.Tensor]
 
 
 class HPUPagedAttention:
diff --git a/vllm/envs.py b/vllm/envs.py
index f8a18cc662ab0..45547416314fb 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -89,6 +89,7 @@ if TYPE_CHECKING:
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
+    VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
 
 
 def get_default_cache_root():
@@ -585,6 +586,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # specify the path through environment variable VLLM_CUDART_SO_PATH.
     "VLLM_CUDART_SO_PATH":
     lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
+
+    # Contiguous cache fetching to avoid using costly gather operation on
+    # Gaudi3. This is only applicable to HPU contiguous cache. If set to true,
+    # contiguous cache fetch will be used.
+    "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
+    lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
+    ("1", "true"),
 }
 
 # end-env-vars-definition
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 774049a5281ee..fe7c776d0a238 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -25,9 +25,11 @@ import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 import torch.nn as nn
 from vllm_hpu_extension.ops import LoraMask as LoraMask
+from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler, format_bytes)
 
+import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import DeviceConfig, VllmConfig
 from vllm.distributed.parallel_state import get_world_group
@@ -260,10 +262,19 @@ def setup_profiler():
     return profiler
 
 
-def pad_list(list, k, v):
-    target_len = round_up(len(list), k)
-    padding = target_len - len(list)
-    return list + [v] * padding
+def pad_list(input, k, v):
+    input_len = len(input)
+    target_len = round_up(input_len, k)
+    padding = target_len - input_len
+    return input + [v] * padding
+
+
+def gather_list(input, indices, v):
+    return [input[i] if i is not None else v for i in indices]
+
+
+def flatten(in_list):
+    return list(itertools.chain(*in_list))
 
 
 def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
@@ -334,13 +345,23 @@ class HpuModelAdapter:
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
+        block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                     num_classes=batch_size)
         block_mapping = block_mapping.to(dtype)
         metadata = metadata._replace(block_mapping=block_mapping,
                                      attn_bias=attn_bias)
         return metadata
 
+    def _set_block_scales(self, metadata, device):
+        block_mapping = metadata.block_mapping
+        ones = torch.ones((block_mapping.size(0), ),
+                          device=device,
+                          dtype=block_mapping.dtype)
+        sums = batch2block(block2batch(ones, block_mapping), block_mapping)
+        block_scales = torch.reciprocal(torch.maximum(ones, sums))
+        metadata = metadata._replace(block_scales=block_scales)
+        return metadata
+
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:
@@ -351,6 +372,7 @@ class HpuModelAdapter:
             meta = attn_metadata
             attn_metadata = self._set_block_mapping(meta, batch_size, device,
                                                     dtype)
+            attn_metadata = self._set_block_scales(attn_metadata, device)
         return attn_metadata
 
     def forward(self, *args, **kwargs):
@@ -586,6 +608,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -911,6 +934,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             block_indices=block_indices,
             block_offsets=block_offsets,
             block_scales=None,
+            block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
             num_prefills=real_num_seqs,
@@ -1008,65 +1032,69 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         num_decode_tokens = sum(seq_lens)
 
-        blocks_used = [len(bt) for bt in block_tables if bt]
-        block_list = []
-        block_scales = []
-        for i, bt in enumerate(block_tables):
-            block_list.extend(bt)
-            blocks_in_group = len(bt)
-            if blocks_in_group > 0:
-                scale = 1.0 / blocks_in_group
-                block_scales.extend([scale] * blocks_in_group)
-
-        block_mapping_nested: List[List[int]] = [
-            [i] * b_u for i, b_u in enumerate(blocks_used)
+        last_block_usage = [
+            slot[0] % self.block_size + 1 for slot in slot_mapping
         ]
-        block_mapping: List[int] = list(
-            itertools.chain.from_iterable(block_mapping_nested))
+        block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
+        block_usage = [[self.block_size] * (len(bt) - 1) + [lbu]
+                       for bt, lbu in zip(block_tables, last_block_usage)
+                       if bt]
 
-        last_block = [
-            sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
-        ]
-        block_usage = [[self.block_size] * (b_u - 1) + [lb]
-                       for b_u, lb in zip(blocks_used, last_block)]
-        block_usage = list(itertools.chain(*block_usage))
+        block_list = flatten(block_tables)
+        block_groups = flatten(block_groups)
+        block_usage = flatten(block_usage)
 
-        block_bucket_size = find_bucket(
-            len(block_list),
-            self.bucketing_global_state.decode_block_bucket_cfg)
-        block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
-        block_mapping = pad_list(block_mapping, block_bucket_size, -1)
-        block_usage = pad_list(block_usage, block_bucket_size, 1)
-        block_scales = pad_list(block_scales, block_bucket_size, 0.0)
+        assert len(block_list) == len(block_groups)
+        assert len(block_list) == len(block_usage)
+
+        padding_fn = None
+        if self.use_contiguous_pa:
+            block_bucket_size = max(max(block_list) + 1, len(block_list))
+            block_bucket_size = find_bucket(
+                block_bucket_size,
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            indices: List[Any]
+            indices = [None] * block_bucket_size
+            for i, bid in enumerate(block_list):
+                indices[bid] = i
+            padding_fn = lambda tensor, pad_value: gather_list(
+                tensor, indices, pad_value)
+        else:
+            block_bucket_size = find_bucket(
+                len(block_list),
+                self.bucketing_global_state.decode_block_bucket_cfg)
+            padding_fn = lambda tensor, pad_value: pad_list(
+                tensor, block_bucket_size, pad_value)
+
+        block_list = padding_fn(block_list, _PAD_BLOCK_ID)
+        block_groups = padding_fn(block_groups, -1)
+        block_usage = padding_fn(block_usage, 1)
 
         block_list = torch.tensor(block_list,
                                   dtype=torch.int,
                                   device=self.device)
-        block_mapping = torch.tensor(block_mapping,
-                                     dtype=torch.long,
-                                     device=self.device)
+        block_groups = torch.tensor(block_groups,
+                                    dtype=torch.int,
+                                    device=self.device)
         block_usage = torch.tensor(block_usage,
                                    dtype=self.model_config.dtype,
                                    device=self.device)
-
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)
 
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, False)
-        block_scales = torch.tensor(block_scales,
-                                    dtype=self.model_config.dtype,
-                                    device=self.device)
 
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
-            block_mapping=block_mapping,
+            block_mapping=None,
             block_usage=block_usage,
             block_indices=block_indices,
             block_offsets=block_offsets,
-            block_scales=block_scales,
+            block_scales=None,
+            block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
             num_prefills=0,
@@ -1280,7 +1308,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
             'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
             'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_scales'
+            'block_offsets', 'block_scales', 'block_groups'
         ])
         return attention_metadata
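
Note: the gather-based padding path added to _prepare_decode can be exercised on its own. The sketch below is a minimal, standalone illustration, not the vLLM implementation: pad_list and gather_list mirror the helpers introduced in hpu_model_runner.py, while round_up, find_bucket and _PAD_BLOCK_ID are simplified stand-ins (the real find_bucket consults the HPU decode-block bucketing config).

from typing import Any, List


def round_up(value, k):
    return (value + k - 1) // k * k


def find_bucket(value, step=16):
    # Simplified stand-in: the real find_bucket rounds up according to the
    # decode block bucketing configuration.
    return round_up(value, step)


def pad_list(input, k, v):
    input_len = len(input)
    target_len = round_up(input_len, k)
    padding = target_len - input_len
    return input + [v] * padding


def gather_list(input, indices, v):
    return [input[i] if i is not None else v for i in indices]


_PAD_BLOCK_ID = 0  # placeholder pad id, for illustration only


def pad_blocks(block_list: List[int], block_groups: List[int],
               use_contiguous_pa: bool):
    if use_contiguous_pa:
        # Contiguous fetch: scatter each block to the slot equal to its block
        # id, so the KV cache can be read as one contiguous region instead of
        # gathered. The bucket must cover the largest block id in use.
        block_bucket_size = find_bucket(max(max(block_list) + 1,
                                            len(block_list)))
        indices: List[Any] = [None] * block_bucket_size
        for i, bid in enumerate(block_list):
            indices[bid] = i
        padding_fn = lambda tensor, pad_value: gather_list(
            tensor, indices, pad_value)
    else:
        # Fallback: keep the original order and pad at the tail.
        block_bucket_size = find_bucket(len(block_list))
        padding_fn = lambda tensor, pad_value: pad_list(
            tensor, block_bucket_size, pad_value)
    return padding_fn(block_list, _PAD_BLOCK_ID), padding_fn(block_groups, -1)


# Two sequences owning blocks [7, 2] and [5]; block_groups marks the owner.
blocks, groups = pad_blocks([7, 2, 5], [0, 0, 1], use_contiguous_pa=True)
print(blocks)  # block 7 lands at index 7, block 2 at index 2, block 5 at index 5
print(groups)  # owner ids land at the same slots; unused slots carry -1

With contiguous PA enabled the same padding_fn is applied to block_list, block_groups and block_usage, so all three stay aligned; the else branch reproduces the previous tail-padding behaviour.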
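The new _set_block_scales relies on batch2block and block2batch from vllm_hpu_extension. The snippet below emulates them with plain matmuls against the one-hot block_mapping produced by _set_block_mapping; this emulation is an assumption about those ops, but it reproduces the per-group 1.0 / blocks_in_group scales that the removed Python loop in _prepare_decode used to compute.

import torch


def block2batch(x, block_mapping):
    # Assumed semantics: reduce per-block values into per-batch totals.
    return block_mapping.t() @ x


def batch2block(x, block_mapping):
    # Assumed semantics: broadcast per-batch values back onto each batch's blocks.
    return block_mapping @ x


batch_size = 2
block_groups = torch.tensor([0, 0, 0, 1])  # three blocks for seq 0, one for seq 1
block_mapping = torch.nn.functional.one_hot(block_groups,
                                            num_classes=batch_size).float()

ones = torch.ones((block_mapping.size(0), ), dtype=block_mapping.dtype)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
print(block_scales)  # tensor([0.3333, 0.3333, 0.3333, 1.0000])

Each block ends up with the reciprocal of its group's block count, matching the scale that the old per-sequence loop assigned.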