[XPU] Fix xpu model runner call torch.cuda APIs (#25011)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji 2025-09-17 14:45:25 +08:00 committed by GitHub
parent 43a62c51be
commit dd39baf717
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,8 +45,12 @@ def _torch_cuda_wrapper():
self.synchronize = lambda: None
try:
# replace cuda Event with xpu Event, this should work by default
# replace cuda APIs with xpu APIs, this should work by default
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.default_stream = torch.xpu.current_stream
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.stream = torch.xpu.stream
yield
finally:
# if anything goes wrong, just patch it with a placeholder