diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index deabf7a0b0770..6962810bdd09f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -945,36 +945,27 @@ def unified_attention_with_output( output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - if sink_key is None and sink_value is None: - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) - else: + kwargs = {} + if sink_key is not None or sink_value is not None: assert sink_key is not None and sink_value is not None, ( "Currently, it is only supported when " "sink_key and sink_value are both not None" ) - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - sink_key=sink_key, - sink_value=sink_value, - ) + kwargs["sink_key"] = sink_key + kwargs["sink_value"] = sink_value + + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + **kwargs, + ) def unified_attention_with_output_fake( diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index f7249475b49de..f46fd3c7f319d 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -854,15 +854,13 @@ class OpenPanguSinkAttention(nn.Module): ) else: self.param_sink_value = torch.zeros( - torch.empty( - ( - self.param_sink_number, - self.num_kv_heads, - self.v_channels, - ), - device=torch.cuda.current_device(), - dtype=config.torch_dtype, - ) + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, ) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): diff --git a/vllm/v1/attention/backends/flash_sink_attn.py b/vllm/v1/attention/backends/flash_sink_attn.py index 580f0f544102c..e532a69dcb4a9 100644 --- a/vllm/v1/attention/backends/flash_sink_attn.py +++ b/vllm/v1/attention/backends/flash_sink_attn.py @@ -614,7 +614,8 @@ class FlashSinkAttentionImpl(AttentionImpl): assert sink_len % block_size == 0 num_sink_blocks = sink_len // block_size sink_kv_slot_mapping = torch.arange( - sink_len, + block_size, + sink_len + block_size, device=attn_metadata.slot_mapping.device, dtype=attn_metadata.slot_mapping.dtype, ) @@ -654,7 +655,10 @@ class FlashSinkAttentionImpl(AttentionImpl): block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata sink_block_table = torch.arange( - num_sink_blocks, device=block_table.device, dtype=block_table.dtype + 1, + num_sink_blocks + 1, + device=block_table.device, + dtype=block_table.dtype, ) sink_block_table = sink_block_table[None, :].expand( block_table.shape[0], -1 @@ -939,13 +943,8 @@ def cascade_attention( descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) num_sink_blocks = sink_len // block_size - block_table = block_table + num_sink_blocks - block_table[block_table == num_sink_blocks] = 0 - sink_block_table = ( - torch.arange( - num_sink_blocks, device=block_table.device, dtype=block_table.dtype - ) - + 1 + sink_block_table = torch.arange( + 1, num_sink_blocks + 1, device=block_table.device, dtype=block_table.dtype ) sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) block_table = torch.cat((sink_block_table, block_table), dim=1)