diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 308cb3e85e27b..ba08e6f81f7fe 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -195,7 +195,9 @@ class Llama4Attention(nn.Module):
             is_neox_style=is_neox_style,
         ) if not self.nope else None
 
-        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        use_chunked_local_attn = not self.nope and config.attention_chunk_size
+        attn_cls = (ChunkedLocalAttention
+                    if use_chunked_local_attn else Attention)
         self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
@@ -206,7 +208,7 @@ class Llama4Attention(nn.Module):
             prefix=f"{prefix}.attn",
             **({
                 "attention_chunk_size": config.attention_chunk_size
-            } if not self.nope else {}))
+            } if use_chunked_local_attn else {}))
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
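
For context, here is a minimal sketch of the selection logic this diff implements, not the actual vLLM code: `ChunkedLocalAttention` is chosen only when the layer uses RoPE (`not self.nope`) and the model config actually sets `attention_chunk_size`; NoPE layers, and models whose config leaves `attention_chunk_size` unset, fall back to global `Attention`. The `Attention`/`ChunkedLocalAttention` classes below are stand-ins for the real vLLM attention layers, `build_attn` is a hypothetical helper, and `SimpleNamespace` replaces the HF text config, all purely for illustration.

```python
# Hedged sketch of the attention-class choice after this change.
# Stand-in classes and config; not the real vLLM implementation.
from types import SimpleNamespace


class Attention:  # stand-in for the global attention backend
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ChunkedLocalAttention(Attention):  # stand-in for the chunked-local variant
    pass


def build_attn(nope: bool, config) -> Attention:
    # Chunked local attention applies only to RoPE layers of models that
    # define an attention_chunk_size; everything else gets global attention.
    use_chunked_local_attn = not nope and config.attention_chunk_size
    attn_cls = ChunkedLocalAttention if use_chunked_local_attn else Attention
    extra = ({"attention_chunk_size": config.attention_chunk_size}
             if use_chunked_local_attn else {})
    return attn_cls(**extra)


# RoPE layer + chunk size configured -> chunked local attention.
cfg = SimpleNamespace(attention_chunk_size=8192)
assert isinstance(build_attn(nope=False, config=cfg), ChunkedLocalAttention)

# NoPE layer, or no chunk size configured -> plain global attention.
cfg_none = SimpleNamespace(attention_chunk_size=None)
assert not isinstance(build_attn(nope=False, config=cfg_none), ChunkedLocalAttention)
assert not isinstance(build_attn(nope=True, config=cfg), ChunkedLocalAttention)
```

The net effect of the diff is the last two assertions: before the change, any non-NoPE layer was routed to `ChunkedLocalAttention` even when `config.attention_chunk_size` was unset; the new `use_chunked_local_attn` flag gates both the class choice and the `attention_chunk_size` keyword argument on the same condition.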