[TPU][Core]Make load weight exceed hbm error more instructive for customers (#20644)
Signed-off-by: Chenyaaang <chenyangli@google.com>
commit fdfd409f8f
parent ffbcc9e757
@@ -1128,26 +1128,33 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
-            if self.use_spmd:
-                tpu_loader = TPUModelLoader(
-                    load_config=self.vllm_config.load_config)
-                model = tpu_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.vllm_config.model_config,
-                    mesh=self.mesh)
-            else:
-                # model = get_model(vllm_config=self.vllm_config)
-                model_loader = get_model_loader(self.load_config)
-                if not hasattr(self, "model"):
-                    logger.info("Loading model from scratch...")
-                    model = model_loader.load_model(
-                        vllm_config=self.vllm_config,
-                        model_config=self.model_config)
-                else:
-                    logger.info("Model was already initialized. \
-                        Loading weights inplace...")
-                    model_loader.load_weights(self.model,
-                                              model_config=self.model_config)
+            try:
+                if self.use_spmd:
+                    tpu_loader = TPUModelLoader(
+                        load_config=self.vllm_config.load_config)
+                    model = tpu_loader.load_model(
+                        vllm_config=self.vllm_config,
+                        model_config=self.vllm_config.model_config,
+                        mesh=self.mesh)
+                else:
+                    model_loader = get_model_loader(self.load_config)
+                    if not hasattr(self, "model"):
+                        logger.info("Loading model from scratch...")
+                        model = model_loader.load_model(
+                            vllm_config=self.vllm_config,
+                            model_config=self.model_config)
+                    else:
+                        logger.info("Model was already initialized. \
+                            Loading weights inplace...")
+                        model_loader.load_weights(
+                            self.model, model_config=self.model_config)
+            except RuntimeError as e:
+                raise RuntimeError(
+                    f"Unable to load model, a likely reason is the model is "
+                    "too large for the current device's HBM memory. "
+                    "Consider switching to a smaller model "
+                    "or sharding the weights on more chips. "
+                    f"See the detailed error: {e}") from e
         if self.lora_config is not None:
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
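In plain terms, the patch wraps both TPU loading paths (SPMD via TPUModelLoader and the default loader path) in a single try/except and re-raises any RuntimeError with a hint that the weights may not fit in the device's HBM. Below is a minimal standalone sketch of the same pattern; `load_fn` is a hypothetical stand-in for vLLM's actual loader calls (TPUModelLoader.load_model / model_loader.load_model / load_weights), not the real API.

    def load_model_with_hbm_hint(load_fn, *args, **kwargs):
        # `load_fn` is a hypothetical stand-in for the real loader calls.
        try:
            return load_fn(*args, **kwargs)
        except RuntimeError as e:
            # Re-raise with an actionable message; `from e` chains the
            # original exception so the device-level traceback stays visible.
            raise RuntimeError(
                "Unable to load model; a likely reason is that the model is "
                "too large for the current device's HBM memory. Consider "
                "switching to a smaller model or sharding the weights over "
                f"more chips. Detailed error: {e}") from e

Note the `from e`: exception chaining preserves the original runtime traceback, so the instructive message is added without hiding the underlying allocation failure.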