[XPU]Fix xpu spec decoding UTs, avoid using cuda graph (#25847)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 219cfbe7f6
commit 143844fa43
@@ -42,7 +42,7 @@ docker run \
 pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
 pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
 pytest -v -s v1/structured_output
-pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
+pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py
 pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
 pytest -v -s v1/test_metrics
 pytest -v -s v1/test_serial_utils.py
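The updated CI line above stops ignoring v1/spec_decode/test_eagle.py, so the Eagle spec-decode tests now run on XPU. As a hedged aside (not part of the commit), the same test selection can be reproduced programmatically from the vLLM tests/ directory via pytest.main; the __main__ wrapper below is illustrative only:

# Hedged sketch: run the same spec_decode selection as the updated CI
# line, assuming it is executed from the vLLM tests/ directory with
# pytest installed. Not part of the commit.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main([
        "-v", "-s", "v1/spec_decode",
        "--ignore=v1/spec_decode/test_max_len.py",
        "--ignore=v1/spec_decode/test_tree_attention.py",
    ]))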
@@ -1143,6 +1143,8 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
             print("Skip FLASH_ATTN on ROCm as aiter is not installed")
 
         return attn_backend_list
+    elif current_platform.is_xpu():
+        return ["FLASH_ATTN", "TRITON_ATTN"]
     else:
         raise ValueError("Unsupported platform")
 
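With the two added lines, get_attn_backend_list_based_on_platform() reports FLASH_ATTN and TRITON_ATTN on XPU instead of falling through to the unsupported-platform error. A hedged sketch of how a test module could consume such a helper to parametrize per-platform backend tests; the stand-in function (hard-coded to the XPU result) and the trivial test body are illustrative, not code from the repository:

# Hedged sketch: parametrizing tests over the platform's backend list.
# The helper below is a stand-in, hard-coded to the XPU result this
# commit adds; the real helper lives in the repository's test utilities.
import pytest


def get_attn_backend_list_based_on_platform() -> list[str]:
    return ["FLASH_ATTN", "TRITON_ATTN"]


@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_backend_name_is_known(attn_backend: str):
    # Placeholder assertion; a real test would build a model runner with
    # the selected backend.
    assert attn_backend in {"FLASH_ATTN", "TRITON_ATTN", "TREE_ATTN"}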
@@ -72,12 +72,13 @@ class EagleProposer:
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
 
-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+        self.use_cuda_graph = (not current_platform.is_xpu()
+                               and self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
         self.cudagraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+            reversed(self.vllm_config.compilation_config.
+                     cudagraph_capture_sizes)) if self.use_cuda_graph else []
 
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(self.max_num_tokens,
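The EagleProposer hunk forces use_cuda_graph to False on XPU and leaves cudagraph_batch_sizes empty whenever graphs are disabled, so the later capture-size setup becomes a no-op. A self-contained sketch of that gating logic under simplified assumptions; PlatformInfo and the PIECEWISE constant below are stand-ins, not the real vLLM config objects:

# Hedged sketch of the gating introduced by this hunk. PlatformInfo and
# PIECEWISE are simplified stand-ins for vLLM's config objects.
from dataclasses import dataclass

PIECEWISE = 3  # stand-in for CompilationLevel.PIECEWISE


@dataclass
class PlatformInfo:
    is_xpu: bool
    compilation_level: int
    enforce_eager: bool
    cudagraph_capture_sizes: tuple[int, ...] = (1, 2, 4, 8)


def resolve_cudagraph_settings(p: PlatformInfo) -> tuple[bool, list[int]]:
    # CUDA graphs are never used on XPU, regardless of the other knobs.
    use_cuda_graph = (not p.is_xpu
                      and p.compilation_level == PIECEWISE
                      and not p.enforce_eager)
    # Capture sizes are only materialized when graphs are actually used.
    batch_sizes = (list(reversed(p.cudagraph_capture_sizes))
                   if use_cuda_graph else [])
    return use_cuda_graph, batch_sizes


if __name__ == "__main__":
    # On XPU the flag is forced off and the capture list stays empty.
    print(resolve_cudagraph_settings(
        PlatformInfo(is_xpu=True, compilation_level=PIECEWISE,
                     enforce_eager=False)))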