mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 21:15:40 +08:00
[Bugfix][ROCm] Using device_type because on ROCm the API is still torch.cuda (#17601)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
c8386fa61d
commit
a92842454c
@ -406,12 +406,12 @@ class Platform:
|
|||||||
"""Raises if this request is unsupported on this platform"""
|
"""Raises if this request is unsupported on this platform"""
|
||||||
|
|
||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
device = getattr(torch, self.device_name, None)
|
device = getattr(torch, self.device_type, None)
|
||||||
if device is not None and hasattr(device, key):
|
if device is not None and hasattr(device, key):
|
||||||
return getattr(device, key)
|
return getattr(device, key)
|
||||||
else:
|
else:
|
||||||
logger.warning("Current platform %s does not have '%s'" \
|
logger.warning("Current platform %s does not have '%s'" \
|
||||||
" attribute.", self.device_name, key)
|
" attribute.", self.device_type, key)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user