Fix typos

Signed-off-by: yuantao <2422264527@qq.com>
yuantao 2025-12-13 17:59:28 +08:00
parent a0563e7368
commit a7430ab479
3 changed files with 3 additions and 8 deletions

View File

@@ -66,9 +66,6 @@ def create_static_sink_attention_backend(
         common_attn_metadata.seq_lens[:] = (
             common_attn_metadata.seq_lens + self.sink_len
         )
-        common_attn_metadata.seq_lens_cpu = (
-            common_attn_metadata.seq_lens_cpu + self.sink_len
-        )
         common_attn_metadata.max_seq_len = (
             common_attn_metadata.max_seq_len + self.sink_len
         )
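
Review note: the surviving updates use slice assignment (seq_lens[:] = ...), which mutates the tensor in place so every holder of a reference sees the new lengths, whereas the removed lines rebound seq_lens_cpu to a freshly allocated tensor. The commit doesn't say why the CPU-side update was dropped. A minimal sketch of the in-place vs. rebind distinction, using standalone tensors rather than the real metadata object:

    import torch

    lens = torch.tensor([3, 5])
    alias = lens            # another component holding the same tensor
    lens[:] = lens + 2      # in-place: the alias observes the update
    assert alias.tolist() == [5, 7]
    lens = lens + 2         # rebinds the local name only; the alias is unchanged
    assert alias.tolist() == [5, 7]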
@@ -152,7 +149,7 @@ class StaticSinkAttention(Attention):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        output_shape: torch.size | None = None,
+        output_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         assert self.sink_key is not None and self.sink_value is not None, (
             "sink_key and sink_value have not been prepared"

View File

@@ -238,7 +238,7 @@ def reshape_and_cache_kernel_flash_diffkv(
     # [TILE_SIZE]
     value_load = tl.load(
-        value_ptr + src_value_idx + tile_offs, mask=tile_offs * head_size_v
+        value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
     )
     if FP8_KV_CACHE:
         if value_load.dtype.is_fp8():
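
Review note: this one is a bug fix rather than a typo. tl.load expects a lane-wise boolean mask; tile_offs * head_size_v is an integer product, not a comparison, so lanes past head_size_v were not masked off and could read out of bounds. The corrected tile_offs < head_size_v masks the tail. A minimal sketch of the masked-load pattern, with a hypothetical kernel name and arguments rather than the one in this file:

    import triton
    import triton.language as tl

    @triton.jit
    def masked_copy_kernel(src_ptr, dst_ptr, n, TILE_SIZE: tl.constexpr):
        offs = tl.arange(0, TILE_SIZE)
        # Only lanes with offs < n touch memory; tail lanes are masked off.
        vals = tl.load(src_ptr + offs, mask=offs < n)
        tl.store(dst_ptr + offs, vals, mask=offs < n)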
@@ -322,8 +322,6 @@ def triton_reshape_and_cache_flash_diffkv(
     else:  # cuda
         num_stages = 10
         num_warps = 16
-        if torch.cuda.get_device_capability(key.device)[0] < 9:
-            TILE_SIZE = min(512, TILE_SIZE)
     # TODO(ngl): maybe replace with static launch grid to avoid overhead if
     # using cudagraphs
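
Review note: torch.cuda.get_device_capability returns a (major, minor) tuple, so the removed guard capped TILE_SIZE at 512 on GPUs with compute capability below 9.0 (pre-Hopper). The commit doesn't say why the cap is no longer needed. For reference, the probe looks like:

    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        print(f"compute capability {major}.{minor}")  # e.g. 8.0 on A100, 9.0 on H100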

View File

@@ -778,10 +778,10 @@ class OpenPanguSinkAttention(nn.Module):
         quant_config: QuantizationConfig | None,
     ) -> None:
         is_neox_style = False
         rope_parameters = {"partial_rotary_factor": self.qk_rope_dim / self.head_dim}
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.qk_rope_dim,
             max_position=self.max_position_embeddings,
             rope_parameters=rope_parameters,
             is_neox_style=is_neox_style,
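
Review note: the rope_parameters dict encodes a partial-rotary setup: with partial_rotary_factor = qk_rope_dim / head_dim, rotary embeddings apply to only qk_rope_dim of the head_dim channels per head, and the rest pass through unrotated. Worked numbers, with both dimensions assumed for illustration:

    head_dim = 128                                      # assumed example value
    qk_rope_dim = 64                                    # assumed example value
    partial_rotary_factor = qk_rope_dim / head_dim      # 0.5: rotate half of each head
    rotary_dim = int(head_dim * partial_rotary_factor)
    assert rotary_dim == qk_rope_dim                    # consistent with rotary_dim=self.qk_rope_dim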