diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f5f26d8fff98a..5af052e685117 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1128,26 +1128,33 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
-            if self.use_spmd:
-                tpu_loader = TPUModelLoader(
-                    load_config=self.vllm_config.load_config)
-                model = tpu_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.vllm_config.model_config,
-                    mesh=self.mesh)
-            else:
-                # model = get_model(vllm_config=self.vllm_config)
-                model_loader = get_model_loader(self.load_config)
-                if not hasattr(self, "model"):
-                    logger.info("Loading model from scratch...")
-                    model = model_loader.load_model(
+            try:
+                if self.use_spmd:
+                    tpu_loader = TPUModelLoader(
+                        load_config=self.vllm_config.load_config)
+                    model = tpu_loader.load_model(
                         vllm_config=self.vllm_config,
-                        model_config=self.model_config)
+                        model_config=self.vllm_config.model_config,
+                        mesh=self.mesh)
                 else:
-                    logger.info("Model was already initialized. \
-                        Loading weights inplace...")
-                    model_loader.load_weights(self.model,
-                                              model_config=self.model_config)
+                    model_loader = get_model_loader(self.load_config)
+                    if not hasattr(self, "model"):
+                        logger.info("Loading model from scratch...")
+                        model = model_loader.load_model(
+                            vllm_config=self.vllm_config,
+                            model_config=self.model_config)
+                    else:
+                        logger.info("Model was already initialized. \
+                            Loading weights inplace...")
+                        model_loader.load_weights(
+                            self.model, model_config=self.model_config)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    f"Unable to load model, a likely reason is the model is "
+                    "too large for the current device's HBM memory. "
+                    "Consider switching to a smaller model "
+                    "or sharding the weights on more chips. "
+                    f"See the detailed error: {e}") from e
         if self.lora_config is not None:
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
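
For reference, the essence of this change is the error-wrapping pattern: catch the low-level RuntimeError raised during weight loading and re-raise it with an actionable hint, chaining the original exception via `raise ... from e` so the detailed device error stays in the traceback. Below is a minimal, self-contained sketch of that pattern outside the vLLM codebase; `load_weights_onto_device` is a hypothetical stand-in for the actual loader call and is not part of the diff above.

    # Minimal sketch of the try/except re-raise pattern introduced by this diff.
    # `load_weights_onto_device` is a hypothetical placeholder for the real
    # model-loading call, assumed to raise RuntimeError on device OOM.

    def load_weights_onto_device() -> None:
        # Placeholder loader: simulate running out of device (HBM) memory.
        raise RuntimeError("RESOURCE_EXHAUSTED: out of HBM")


    def load_model_with_hint() -> None:
        try:
            load_weights_onto_device()
        except RuntimeError as e:
            # `raise ... from e` records the original error as __cause__,
            # so the detailed device failure is still printed under the hint.
            raise RuntimeError(
                "Unable to load model; it may be too large for the device's "
                "HBM memory. Consider a smaller model or sharding the "
                f"weights across more chips. Detailed error: {e}") from e


    if __name__ == "__main__":
        load_model_with_hint()

Running the sketch shows both messages: Python prints the original "RESOURCE_EXHAUSTED" traceback first, followed by "The above exception was the direct cause of the following exception:" and the user-facing hint, which is the behavior the diff is after.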