[XPU]Fix xpu spec decoding UTs, avoid using cuda graph (#25847)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Kunshang Ji authored on 2025-09-29 13:15:10 +08:00; committed by GitHub
parent 219cfbe7f6
commit 143844fa43
3 changed files with 7 additions and 4 deletions


@@ -42,7 +42,7 @@ docker run \
 pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
 pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
 pytest -v -s v1/structured_output
-pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
+pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py
 pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
 pytest -v -s v1/test_metrics
 pytest -v -s v1/test_serial_utils.py

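Dropping the --ignore=v1/spec_decode/test_eagle.py flag re-enables the Eagle spec-decode tests on the XPU runner; they can pass now because the EagleProposer change below stops requesting cuda graphs on XPU. For local reproduction, the same selection can be driven from Python through pytest's documented entry point (a convenience sketch, not part of the commit):

import sys

import pytest

# Mirrors the updated CI line: run the v1 spec-decode suite with only the
# two remaining ignores, so test_eagle.py is collected again.
sys.exit(pytest.main([
    "-v", "-s", "v1/spec_decode",
    "--ignore=v1/spec_decode/test_max_len.py",
    "--ignore=v1/spec_decode/test_tree_attention.py",
]))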

@@ -1143,6 +1143,8 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
print("Skip FLASH_ATTN on ROCm as aiter is not installed")
return attn_backend_list
elif current_platform.is_xpu():
return ["FLASH_ATTN", "TRITON_ATTN"]
else:
raise ValueError("Unsupported platform")

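With the new branch, this helper returns FLASH_ATTN and TRITON_ATTN on XPU instead of falling through to the ValueError, so any test parametrized over it is now collected on XPU. A hypothetical sketch of a consuming test, assuming the helper lives in tests/utils.py (the file path is not shown in this view):

import pytest

from tests.utils import get_attn_backend_list_based_on_platform


@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_runs_on_every_supported_backend(attn_backend: str):
    # On XPU the list is ["FLASH_ATTN", "TRITON_ATTN"]; before this commit
    # the helper raised ValueError("Unsupported platform") there.
    assert isinstance(attn_backend, str)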

@@ -72,12 +72,13 @@ class EagleProposer:
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
 
-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+        self.use_cuda_graph = (not current_platform.is_xpu()
+                               and self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
         self.cudagraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+            reversed(self.vllm_config.compilation_config.
+                     cudagraph_capture_sizes)) if self.use_cuda_graph else []
 
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(self.max_num_tokens,
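
The net effect of this hunk: use_cuda_graph is forced to False on XPU regardless of compilation level, and cudagraph_batch_sizes collapses to an empty list whenever graphs are disabled, so nothing downstream attempts a graph capture. A self-contained sketch of the same gating logic with stand-in values (the PIECEWISE constant and function name here are illustrative, not vLLM API):

PIECEWISE = 3  # stand-in for vllm.config.CompilationLevel.PIECEWISE


def resolve_cudagraph_settings(is_xpu: bool, level: int, enforce_eager: bool,
                               capture_sizes: list[int]):
    # Same three-way condition as the new EagleProposer code: XPU opts out
    # of cuda graphs even when piecewise compilation is enabled.
    use_cuda_graph = (not is_xpu and level == PIECEWISE
                      and not enforce_eager)
    # Capture sizes only matter when graphs are in play; XPU now gets [].
    batch_sizes = list(reversed(capture_sizes)) if use_cuda_graph else []
    return use_cuda_graph, batch_sizes


assert resolve_cudagraph_settings(True, PIECEWISE, False, [1, 2, 4]) == (False, [])
assert resolve_cudagraph_settings(False, PIECEWISE, False, [1, 2, 4]) == (True, [4, 2, 1])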