[TPU] Fix dummy loading OOM (#16372)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
commit 1621b25288
parent a564797151
@@ -658,8 +658,21 @@ def initialize_dummy_weights(
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
             if current_platform.is_tpu():
-                # XLA device does not support torch.Generator()
-                param.uniform_(low, high)
+                generator = torch.Generator(device="cpu")
+                generator.manual_seed(seed)
+                # Note: The param.uniform_ function cannot be used in this
+                # context because it demands more TPU HBM than directly copying
+                # from a CPU tensor.
+                # Note: We avoid using torch.rand_like as it doesn't currently
+                # support the generator argument.
+                param.copy_((high - low) *
+                            torch.rand(*param.shape,
+                                       generator=generator,
+                                       dtype=param.dtype,
+                                       layout=param.layout,
+                                       requires_grad=param.requires_grad,
+                                       device="cpu") + low)
+                torch._sync(param)
                 continue

             generator = torch.Generator(device=param.data.device)
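For context on the fix: filling an XLA parameter in place with param.uniform_ spikes TPU HBM, so the commit instead generates the random values on the CPU with a seeded generator and copies them onto the device in one step. Below is a minimal self-contained sketch of that pattern; the helper name fill_dummy_weight and the default low/high/seed values are illustrative assumptions, not part of the vLLM API.

    import torch

    def fill_dummy_weight(param: torch.Tensor,
                          low: float = -1e-3,
                          high: float = 1e-3,
                          seed: int = 1234) -> None:
        # Hypothetical helper illustrating the commit's pattern.
        # XLA devices do not support torch.Generator(), so the noise is
        # generated on the CPU with a seeded generator for reproducibility.
        generator = torch.Generator(device="cpu")
        generator.manual_seed(seed)
        noise = torch.rand(*param.shape,
                           generator=generator,
                           dtype=param.dtype,
                           device="cpu")
        # Scale [0, 1) noise into [low, high) and copy to the device in one
        # step; peak device memory stays near one extra copy of `param`,
        # unlike an in-place uniform_ on the XLA tensor.
        param.copy_((high - low) * noise + low)

The diff additionally calls torch._sync(param) after each copy, presumably to force the lazily staged copy to materialize on the XLA backend before the loop moves on to the next parameter.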