diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index c5555aba1a3e3..6a78e00a90495 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -406,12 +406,12 @@ class Platform: """Raises if this request is unsupported on this platform""" def __getattr__(self, key: str): - device = getattr(torch, self.device_name, None) + device = getattr(torch, self.device_type, None) if device is not None and hasattr(device, key): return getattr(device, key) else: logger.warning("Current platform %s does not have '%s'" \ - " attribute.", self.device_name, key) + " attribute.", self.device_type, key) return None @classmethod