diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index a9f56d5d8f7ac..c8bff8b7c80b6 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -176,6 +176,32 @@ class UsageMessage:
         self._report_usage_once(model_architecture, usage_context, extra_kvs)
         self._report_continuous_usage()
 
+    def _report_tpu_inference_usage(self) -> bool:
+        try:
+            from tpu_inference import tpu_info, utils
+
+            self.gpu_count = tpu_info.get_num_chips()
+            self.gpu_type = tpu_info.get_tpu_type()
+            self.gpu_memory_per_device = utils.get_device_hbm_limit()
+            self.cuda_runtime = "tpu_inference"
+            return True
+        except Exception:
+            return False
+
+    def _report_torch_xla_usage(self) -> bool:
+        try:
+            import torch_xla
+
+            self.gpu_count = torch_xla.runtime.world_size()
+            self.gpu_type = torch_xla.tpu.get_tpu_type()
+            self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
+                "bytes_limit"
+            ]
+            self.cuda_runtime = "torch_xla"
+            return True
+        except Exception:
+            return False
+
     def _report_usage_once(
         self,
         model_architecture: str,
@@ -192,16 +218,10 @@ class UsageMessage:
         )
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
-        if current_platform.is_tpu():
-            try:
-                import torch_xla
-
-                self.gpu_count = torch_xla.runtime.world_size()
-                self.gpu_type = torch_xla.tpu.get_tpu_type()
-                self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
-                    "bytes_limit"
-                ]
-            except Exception:
+        if current_platform.is_tpu():  # noqa: SIM102
+            if (not self._report_tpu_inference_usage()) and (
+                not self._report_torch_xla_usage()
+            ):
                 logger.exception("Failed to collect TPU information")
         self.provider = _detect_cloud_provider()
         self.architecture = platform.machine()
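
For review context: the diff replaces the inline torch_xla probe with two boolean helpers that are tried in order, tpu_inference first and torch_xla as a fallback, with a single logger.exception only if both return False. Below is a minimal, self-contained sketch of that fallback pattern, not the vLLM implementation itself; the class name UsageStats, the collector method names, and the placeholder values are hypothetical stand-ins.

# Sketch of the "try collectors in order, first success wins" pattern from the diff.
# All names and values here are illustrative placeholders, not vLLM APIs.
import logging

logger = logging.getLogger(__name__)


class UsageStats:
    """Hypothetical stand-in for vLLM's UsageMessage."""

    def __init__(self) -> None:
        self.gpu_count = None
        self.gpu_type = None
        self.gpu_memory_per_device = None
        self.cuda_runtime = None

    def _collect_from_tpu_inference(self) -> bool:
        # Stand-in for _report_tpu_inference_usage(): import the backend and
        # read chip count / type / HBM limit; any failure means "not available".
        try:
            raise ImportError("tpu_inference is not installed in this sketch")
        except Exception:
            return False

    def _collect_from_torch_xla(self) -> bool:
        # Stand-in for _report_torch_xla_usage(): same boolean contract,
        # filled with placeholder values instead of real torch_xla queries.
        try:
            self.gpu_count = 1
            self.gpu_type = "tpu-placeholder"
            self.gpu_memory_per_device = 0
            self.cuda_runtime = "torch_xla"
            return True
        except Exception:
            return False

    def collect_tpu_info(self) -> None:
        # Mirror of the new call site: the first collector that succeeds wins,
        # and the failure is logged only when every collector returns False.
        if not self._collect_from_tpu_inference() and not self._collect_from_torch_xla():
            logger.exception("Failed to collect TPU information")


if __name__ == "__main__":
    stats = UsageStats()
    stats.collect_tpu_info()
    print(stats.cuda_runtime, stats.gpu_count, stats.gpu_type)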