[ROCm] Disable chunked prefill/prefix caching when running MLA on non-cuda platforms (#13844)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent e656f638de
commit 1d35662e6d
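Note on the user-visible effect, as a rough sketch with the high-level vllm.LLM API. The model name and flags below are illustrative assumptions, not taken from the commit; they assume an MLA model (for example a DeepSeek-V2 variant) running on a ROCm or other non-CUDA host.

# Hedged sketch only: illustrates the behavior added by this commit.
from vllm import LLM

# On a non-CUDA platform, an MLA model has chunked prefill and prefix caching
# forced off at engine construction (with an info log), even if requested here.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical choice of MLA model
    enable_chunked_prefill=True,           # overridden to False on non-CUDA + MLA
    enable_prefix_caching=True,            # overridden to False on non-CUDA + MLA
)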
@@ -232,6 +232,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 
 try:
@@ -1371,18 +1372,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=prefill_metadata.query_start_loc,
-            cu_seqlens_k=prefill_metadata.query_start_loc,
-            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if has_context:
+            if not current_platform.is_cuda():
+                raise NotImplementedError(
+                    "Chunked Prefill for MLA is not currently supported on"
+                    "non-cuda platforms")
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=True,
+            )
+        else:
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output
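A minimal sketch of the control flow this hunk introduces, with simplified names (prefill_attention and attn_fn are stand-ins, not vLLM APIs). Only the chunked-context path needs the softmax LSE returned alongside the output, so it can later be combined with the prefix context, and that is the path gated off on non-CUDA platforms.

# Minimal sketch, assuming attn_fn behaves like the varlen flash-attention
# wrapper above: it returns (output, lse) when return_softmax_lse=True and
# just the output otherwise.
def prefill_attention(q, k, v_padded, *, has_context, is_cuda, attn_fn, **meta):
    if has_context:
        # The chunked-context path needs the log-sum-exp so the suffix
        # attention can later be merged with the prefix context; it is the
        # only path gated to CUDA by this change.
        if not is_cuda:
            raise NotImplementedError(
                "Chunked Prefill for MLA is not currently supported on "
                "non-cuda platforms")
        suffix_output, suffix_lse = attn_fn(q=q, k=k, v=v_padded,
                                            return_softmax_lse=True, **meta)
        return suffix_output, suffix_lse
    # No prior context: a single call returning only the attention output.
    return attn_fn(q=q, k=k, v=v_padded, **meta)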
@@ -3422,6 +3422,20 @@ class VllmConfig:
                 "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.model_config and self.model_config.use_mla and \
+            not current_platform.is_cuda():
+            logger.info(
+                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            self.scheduler_config.enable_chunked_prefill = False
+            self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
+            if self.cache_config is not None:
+                self.cache_config.enable_prefix_caching = False
+
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:
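A simplified, self-contained sketch of the config guard above, using stand-in dataclasses rather than vLLM's real SchedulerConfig and CacheConfig, and an assumed value for the default token budget. The point it illustrates: once chunked prefill is off, max_num_batched_tokens is raised to at least max_model_len so a full prompt still fits in a single batch.

# Hedged sketch only; names and the default constant are assumptions.
from dataclasses import dataclass

_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # assumed default for illustration

@dataclass
class SchedulerCfg:
    enable_chunked_prefill: bool
    chunked_prefill_enabled: bool
    max_num_batched_tokens: int
    max_model_len: int

@dataclass
class CacheCfg:
    enable_prefix_caching: bool

def disable_mla_incompatible_features(use_mla: bool, is_cuda: bool,
                                      sched: SchedulerCfg, cache: CacheCfg):
    # Mirrors the guard above: on non-CUDA platforms with MLA, force chunked
    # prefill off, restore a full-prompt token budget, and disable prefix
    # caching when a cache config is present.
    if use_mla and not is_cuda:
        sched.enable_chunked_prefill = False
        sched.chunked_prefill_enabled = False
        # Without chunked prefill, a whole prompt must be scheduled at once.
        sched.max_num_batched_tokens = max(sched.max_model_len,
                                           _DEFAULT_MAX_NUM_BATCHED_TOKENS)
        if cache is not None:
            cache.enable_prefix_caching = False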