Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 05:15:01 +08:00)
[XPU]Fix xpu spec decoding UTs, avoid using cuda graph (#25847)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
commit 143844fa43 (parent 219cfbe7f6)
@@ -42,7 +42,7 @@ docker run \
   pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
   pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
   pytest -v -s v1/structured_output
-  pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
+  pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py
   pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
   pytest -v -s v1/test_metrics
   pytest -v -s v1/test_serial_utils.py
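This first hunk updates the XPU CI script: the --ignore=v1/spec_decode/test_eagle.py flag is dropped, so the Eagle speculative-decoding unit tests run on XPU again. A minimal sketch of the same selection driven from Python, assuming it is run from the tests/ directory (pytest.main accepts the usual pytest CLI arguments):

import sys

import pytest

# test_eagle.py is deliberately absent from this list: this commit re-enables
# it on XPU, while test_max_len.py and test_tree_attention.py stay skipped.
ignored = [
    "v1/spec_decode/test_max_len.py",
    "v1/spec_decode/test_tree_attention.py",
]

args = ["-v", "-s", "v1/spec_decode"] + [f"--ignore={p}" for p in ignored]
sys.exit(pytest.main(args))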
@@ -1143,6 +1143,8 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
             print("Skip FLASH_ATTN on ROCm as aiter is not installed")

         return attn_backend_list
+    elif current_platform.is_xpu():
+        return ["FLASH_ATTN", "TRITON_ATTN"]
     else:
         raise ValueError("Unsupported platform")

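The second hunk teaches the test helper about XPU, returning FLASH_ATTN and TRITON_ATTN as the backends to exercise there. A usage sketch (the import path and the test body are assumptions, not part of this diff; VLLM_ATTENTION_BACKEND is vLLM's standard backend-override environment variable):

import pytest

from tests.utils import get_attn_backend_list_based_on_platform


@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_backend_selected(attn_backend: str, monkeypatch):
    # On XPU this now parametrizes over ["FLASH_ATTN", "TRITON_ATTN"];
    # unsupported platforms fail fast via the ValueError above.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert attn_backend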
@@ -72,12 +72,13 @@ class EagleProposer:

         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None

-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+        self.use_cuda_graph = (not current_platform.is_xpu()
+                               and self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
-        self.cudagraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+        self.cudagraph_batch_sizes = list(
+            reversed(self.vllm_config.compilation_config.
+                     cudagraph_capture_sizes)) if self.use_cuda_graph else []

         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(self.max_num_tokens,
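The third hunk is the substance of the fix: EagleProposer never enables CUDA graphs on XPU, and cudagraph_batch_sizes collapses to an empty list whenever graphs are off. A standalone sketch of the new gate (the helper name is hypothetical; the config fields are the ones used in the diff above):

from vllm.config import CompilationLevel, VllmConfig
from vllm.platforms import current_platform


def should_use_cuda_graph(vllm_config: VllmConfig) -> bool:
    # CUDA graphs require piecewise compilation, no enforce_eager, and a
    # platform that actually supports graph capture, which XPU does not.
    return (not current_platform.is_xpu()
            and vllm_config.compilation_config.level
            == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)

Guarding the capture-size list with "if self.use_cuda_graph else []" also keeps the proposer from padding batches toward capture sizes that will never actually be captured.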