Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[bug fix] Fix llama4 spec decoding (#22691)
Signed-off-by: qizixi <qizixi@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
parent 31fd3265c8
commit 5bfe0dea7a
@@ -195,7 +195,9 @@ class Llama4Attention(nn.Module):
             is_neox_style=is_neox_style,
         ) if not self.nope else None
 
-        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        use_chunked_local_attn = not self.nope and config.attention_chunk_size
+        attn_cls = (ChunkedLocalAttention
+                    if use_chunked_local_attn else Attention)
         self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
@@ -206,7 +208,7 @@ class Llama4Attention(nn.Module):
             prefix=f"{prefix}.attn",
             **({
                 "attention_chunk_size": config.attention_chunk_size
-            } if not self.nope else {}))
+            } if use_chunked_local_attn else {}))
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
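For context, a minimal sketch of the selection logic this commit fixes. The class names below are stand-ins for vLLM's Attention and ChunkedLocalAttention, and select_attn_cls is a hypothetical helper introduced only to isolate the condition from the diff, not vLLM API. Before the fix, the NoPE check alone picked ChunkedLocalAttention even when config.attention_chunk_size was unset; the fixed logic requires both conditions.

# Sketch of the corrected attention-class selection, assuming stand-in
# classes in place of vLLM's Attention / ChunkedLocalAttention.
from typing import Optional


class Attention:  # stand-in for vLLM's Attention
    pass


class ChunkedLocalAttention(Attention):  # stand-in for the chunked variant
    pass


def select_attn_cls(nope: bool, attention_chunk_size: Optional[int]) -> type:
    """Hypothetical helper mirroring the fixed logic: chunked local
    attention is used only when the layer applies RoPE (not `nope`)
    AND the config actually sets an attention_chunk_size. The old
    code checked `nope` alone, so a config without a chunk size
    still selected ChunkedLocalAttention."""
    use_chunked_local_attn = not nope and bool(attention_chunk_size)
    return ChunkedLocalAttention if use_chunked_local_attn else Attention


# RoPE layer with a chunk size configured -> chunked local attention.
assert select_attn_cls(nope=False, attention_chunk_size=8192) is ChunkedLocalAttention
# RoPE layer with no chunk size (the buggy case) -> plain Attention.
assert select_attn_cls(nope=False, attention_chunk_size=None) is Attention
# NoPE layer -> plain Attention regardless of chunk size.
assert select_attn_cls(nope=True, attention_chunk_size=8192) is Attention

Gating on a single flag also keeps the class choice consistent with the constructor kwargs: in the second hunk, attention_chunk_size is forwarded only when use_chunked_local_attn holds, so the selected class and its arguments can no longer disagree.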