mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 17:24:36 +08:00
[TPU][Core] Make the error raised when loading weights exceeds HBM more instructive for customers (#20644)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
parent
ffbcc9e757
commit
fdfd409f8f
@ -1128,26 +1128,33 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"vllm.model_executor.layers.vocab_parallel_embedding."
|
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||||
"get_tensor_model_parallel_rank",
|
"get_tensor_model_parallel_rank",
|
||||||
return_value=xm_tp_rank):
|
return_value=xm_tp_rank):
|
||||||
if self.use_spmd:
|
try:
|
||||||
tpu_loader = TPUModelLoader(
|
if self.use_spmd:
|
||||||
load_config=self.vllm_config.load_config)
|
tpu_loader = TPUModelLoader(
|
||||||
model = tpu_loader.load_model(
|
load_config=self.vllm_config.load_config)
|
||||||
vllm_config=self.vllm_config,
|
model = tpu_loader.load_model(
|
||||||
model_config=self.vllm_config.model_config,
|
|
||||||
mesh=self.mesh)
|
|
||||||
else:
|
|
||||||
# model = get_model(vllm_config=self.vllm_config)
|
|
||||||
model_loader = get_model_loader(self.load_config)
|
|
||||||
if not hasattr(self, "model"):
|
|
||||||
logger.info("Loading model from scratch...")
|
|
||||||
model = model_loader.load_model(
|
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
model_config=self.model_config)
|
model_config=self.vllm_config.model_config,
|
||||||
|
mesh=self.mesh)
|
||||||
else:
|
else:
|
||||||
logger.info("Model was already initialized. \
|
model_loader = get_model_loader(self.load_config)
|
||||||
Loading weights inplace...")
|
if not hasattr(self, "model"):
|
||||||
model_loader.load_weights(self.model,
|
logger.info("Loading model from scratch...")
|
||||||
model_config=self.model_config)
|
model = model_loader.load_model(
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
model_config=self.model_config)
|
||||||
|
else:
|
||||||
|
logger.info("Model was already initialized. \
|
||||||
|
Loading weights inplace...")
|
||||||
|
model_loader.load_weights(
|
||||||
|
self.model, model_config=self.model_config)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unable to load model, a likely reason is the model is "
|
||||||
|
"too large for the current device's HBM memory. "
|
||||||
|
"Consider switching to a smaller model "
|
||||||
|
"or sharding the weights on more chips. "
|
||||||
|
f"See the detailed error: {e}") from e
|
||||||
if self.lora_config is not None:
|
if self.lora_config is not None:
|
||||||
model = self.load_lora_model(model, self.model_config,
|
model = self.load_lora_model(model, self.model_config,
|
||||||
self.scheduler_config,
|
self.scheduler_config,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user