mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 02:41:19 +08:00
[XPU]fix cuda event used in XPU model runner (#23708)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
69244e67e6
commit
6446677839
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -22,7 +23,8 @@ class XPUModelRunner(GPUModelRunner):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
super().__init__(vllm_config, device)
|
with _torch_cuda_wrapper():
|
||||||
|
super().__init__(vllm_config, device)
|
||||||
# FIXME: To be verified.
|
# FIXME: To be verified.
|
||||||
self.cascade_attn_enabled = False
|
self.cascade_attn_enabled = False
|
||||||
|
|
||||||
@ -31,3 +33,21 @@ class XPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
def _sync_device(self) -> None:
|
def _sync_device(self) -> None:
|
||||||
torch.xpu.synchronize()
|
torch.xpu.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _torch_cuda_wrapper():
|
||||||
|
|
||||||
|
class _EventPlaceholder:
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
|
self.record = lambda: None
|
||||||
|
self.synchronize = lambda: None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# replace cuda Event with xpu Event, this should work by default
|
||||||
|
torch.cuda.Event = torch.xpu.Event
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# if anything goes wrong, just patch it with a placeholder
|
||||||
|
torch.cuda.Event = _EventPlaceholder
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user