[Bugfix] Misaligned params in TreeAttentionImpl (#22226)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-08-05 13:40:09 +08:00, committed by GitHub
parent 4b3e4474d7
commit cdfd6871a5

@@ -4,7 +4,7 @@
 import ast
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Optional

 import torch
@@ -313,15 +313,11 @@ class TreeAttentionImpl(AttentionImpl):
         alibi_slopes: Optional[list[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
-        blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
-        if blocksparse_params is not None:
-            raise ValueError(
-                "TreeAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
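For context on the "misaligned params" in the title, here is a hypothetical minimal sketch (not code from this commit; the function names and the positional-call pattern are assumptions for illustration only). An extra blocksparse_params slot in the signature shifts any later argument a caller passes positionally, so a value intended for logits_soft_cap can silently land in blocksparse_params; dropping the unused parameter keeps the signature aligned with the other attention backends.

# Hypothetical illustration of the misalignment; not part of the commit.
def before_fix(num_heads, head_size, scale, num_kv_heads, alibi_slopes,
               sliding_window, kv_cache_dtype,
               blocksparse_params=None, logits_soft_cap=None):
    # The 8th positional argument lands in blocksparse_params here.
    return blocksparse_params, logits_soft_cap

def after_fix(num_heads, head_size, scale, num_kv_heads, alibi_slopes,
              sliding_window, kv_cache_dtype, logits_soft_cap=None):
    # With the extra slot removed, the same call binds logits_soft_cap correctly.
    return logits_soft_cap

# The caller intends 30.0 to be logits_soft_cap:
print(before_fix(8, 64, 0.125, 8, None, None, "auto", 30.0))  # (30.0, None) - wrong slot
print(after_fix(8, 64, 0.125, 8, None, None, "auto", 30.0))   # 30.0 - correct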