diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py
index 7687651ee682b..beb9add10024b 100644
--- a/vllm/attention/layers/static_sink_attention.py
+++ b/vllm/attention/layers/static_sink_attention.py
@@ -66,9 +66,6 @@ def create_static_sink_attention_backend(
             common_attn_metadata.seq_lens[:] = (
                 common_attn_metadata.seq_lens + self.sink_len
             )
-            common_attn_metadata.seq_lens_cpu = (
-                common_attn_metadata.seq_lens_cpu + self.sink_len
-            )
             common_attn_metadata.max_seq_len = (
                 common_attn_metadata.max_seq_len + self.sink_len
             )
@@ -152,7 +149,7 @@ class StaticSinkAttention(Attention):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        output_shape: torch.size | None = None,
+        output_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         assert self.sink_key is not None and self.sink_value is not None, (
             "sink_key and sink_value have not been prepared"
diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py
index c119033896ec6..a383de0ac76cc 100644
--- a/vllm/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -238,7 +238,7 @@ def reshape_and_cache_kernel_flash_diffkv(

     # [TILE_SIZE]
     value_load = tl.load(
-        value_ptr + src_value_idx + tile_offs, mask=tile_offs * head_size_v
+        value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
     )
     if FP8_KV_CACHE:
         if value_load.dtype.is_fp8():
@@ -322,8 +322,6 @@ def triton_reshape_and_cache_flash_diffkv(
     else:  # cuda
         num_stages = 10
         num_warps = 16
-        if torch.cuda.get_device_capability(key.device)[0] < 9:
-            TILE_SIZE = min(512, TILE_SIZE)

     # TODO(ngl): maybe replace with static launch grid to avoid overhead if
     # using cudagraphs
diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py
index 8e4bb62e137a8..43bfa4f8324cc 100644
--- a/vllm/model_executor/models/openpangu.py
+++ b/vllm/model_executor/models/openpangu.py
@@ -778,10 +778,10 @@ class OpenPanguSinkAttention(nn.Module):
         quant_config: QuantizationConfig | None,
     ) -> None:
         is_neox_style = False
+        rope_parameters = {"partial_rotary_factor": self.qk_rope_dim / self.head_dim}
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.qk_rope_dim,
             max_position=self.max_position_embeddings,
             rope_parameters=rope_parameters,
             is_neox_style=is_neox_style,