mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 07:01:50 +08:00
[Attention] Clean up iRoPE in V1 (#21188)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
6ece16c4fe
commit
304dce7ec0
@ -137,6 +137,13 @@ class Attention(nn.Module):
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
# For v1 we have backend agnostic iRoPE (local chunked attention)
|
||||
# we have to store the flag on the layer so gpu model runner can
|
||||
# set KVSpec appropriately (and pop it so it doesnt get passed to
|
||||
# the backends)
|
||||
if envs.VLLM_USE_V1:
|
||||
self.use_irope = extra_impl_args.pop("use_irope", False)
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None and not isinstance(
|
||||
|
||||
@ -446,17 +446,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in Torch SPDA is not supported yet, it will fall"
|
||||
" back to global attention for long context.")
|
||||
self.paged_attn_impl = _get_paged_attn_impl()
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
|
||||
@ -352,7 +352,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@ -381,7 +380,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
self.use_irope = use_irope
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||
and not flash_attn_supports_fp8():
|
||||
|
||||
@ -493,7 +493,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@ -509,7 +508,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
self.use_irope = use_irope
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
|
||||
@ -148,12 +148,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
||||
@ -337,7 +337,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@ -367,7 +366,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
self.use_irope = use_irope
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"AiterFlashAttention does not support fp8 kv-cache on this "
|
||||
|
||||
@ -72,9 +72,6 @@ class TritonAttentionMetadataBuilder(
|
||||
vllm_config.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
|
||||
self.attention_chunk_size = getattr(vllm_config.scheduler_config,
|
||||
'attention_chunk_size', None)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> TritonAttentionMetadata:
|
||||
@ -208,7 +205,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@ -228,8 +224,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.use_irope = use_irope
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
TritonAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
@ -2702,8 +2702,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO: Support other attention modules, e.g., cross-attention
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
use_local_attention = (self.attention_chunk_size is not None
|
||||
and getattr(attn_module.impl,
|
||||
"use_irope", False))
|
||||
and attn_module.use_irope)
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
@ -2716,13 +2715,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
"attention module can not be with ",
|
||||
"both local attention and sliding window")
|
||||
elif use_local_attention:
|
||||
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
|
||||
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
use_mla=use_mla))
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
|
||||
@ -519,6 +519,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
continue
|
||||
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if attn_module.use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it "
|
||||
"will fall back to global attention for long context.")
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user