[Bugfix][ROCm] Using device_type because on ROCm the API is still torch.cuda (#17601)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg 2025-05-03 01:25:47 -04:00 committed by GitHub
parent c8386fa61d
commit a92842454c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -406,12 +406,12 @@ class Platform:
"""Raises if this request is unsupported on this platform"""
def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None)
device = getattr(torch, self.device_type, None)
if device is not None and hasattr(device, key):
return getattr(device, key)
else:
logger.warning("Current platform %s does not have '%s'" \
" attribute.", self.device_name, key)
" attribute.", self.device_type, key)
return None
@classmethod