[v1][CP] Improve DCP/PCP/MTP error messages with actionable guidance

Replace cryptic AssertionErrors with informative RuntimeErrors that:
- Explain what DCP (Decode Context Parallel) and PCP (Prefill Context
  Parallel) are
- List compatible attention backends
- Provide environment variable instructions (VLLM_ATTENTION_BACKEND)
- Include documentation links

Fixes #28407

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: yurekami <249254018+yurekami@users.noreply.github.com>
This commit is contained in:
yurekami 2025-12-18 15:58:13 +09:00
parent a85724bd6e
commit 3c8358c328

View File

@@ -21,22 +21,56 @@ def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
layer_impl = getattr(layer, "impl", None)
if layer_impl is None:
continue
if vllm_config.speculative_config is not None and interleave_size > 1:
assert layer_impl.supports_mtp_with_cp_non_trivial_interleave_size, (
"MTP with cp_kv_cache_interleave_size > 1 is not "
f"supported in {layer_impl.__class__.__name__}."
)
if dcp_size > 1:
assert layer_impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer_impl.__class__.__name__} "
"does not return the softmax lse for decode."
supports_mtp = layer_impl.supports_mtp_with_cp_non_trivial_interleave_size
if (
vllm_config.speculative_config is not None
and interleave_size > 1
and not supports_mtp
):
raise RuntimeError(
f"Multi-Token Prediction (MTP) with "
f"cp_kv_cache_interleave_size > 1 is not supported by the "
f"current attention backend "
f"'{layer_impl.__class__.__name__}'.\n\n"
f"To resolve this issue, try one of the following:\n"
f" 1. Use a different attention backend by setting:\n"
f" export VLLM_ATTENTION_BACKEND=<backend>\n"
f" 2. Set cp_kv_cache_interleave_size to 1\n"
f" 3. Disable speculative decoding"
)
if pcp_size > 1:
assert layer_impl.supports_pcp, (
"PCP requires attention impls' support, "
f"but the impl {layer_impl.__class__.__name__} "
"does not support PCP."
if dcp_size > 1 and not layer_impl.need_to_return_lse_for_decode:
raise RuntimeError(
f"Decode Context Parallel (DCP) requires an attention "
f"backend that supports returning softmax LSE (log-sum-exp) "
f"for decode operations. The current backend "
f"'{layer_impl.__class__.__name__}' does not support this "
f"feature.\n\n"
f"To resolve this issue, try one of the following:\n"
f" 1. Use a compatible attention backend by setting:\n"
f" export VLLM_ATTENTION_BACKEND=<backend>\n"
f" Compatible backends: FLASH_ATTN, FLASHINFER, "
f"TRITON_MLA, FLASH_MLA, FLASH_ATTN_MLA, CUTLASS_MLA\n"
f" 2. Disable DCP by removing the "
f"--decode-context-parallel-size flag\n\n"
f"For more information, see:\n"
f" https://docs.vllm.ai/en/latest/serving/"
f"distributed_serving.html"
)
if pcp_size > 1 and not layer_impl.supports_pcp:
raise RuntimeError(
f"Prefill Context Parallel (PCP) requires an attention "
f"backend that supports PCP. The current backend "
f"'{layer_impl.__class__.__name__}' does not support this "
f"feature.\n\n"
f"To resolve this issue, try one of the following:\n"
f" 1. Use a compatible attention backend by setting:\n"
f" export VLLM_ATTENTION_BACKEND=<backend>\n"
f" 2. Disable PCP by removing the "
f"--prefill-context-parallel-size flag\n\n"
f"For more information, see:\n"
f" https://docs.vllm.ai/en/latest/serving/"
f"distributed_serving.html"
)