diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index a7475941c1278..1bb592f492ef2 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -658,8 +658,21 @@ def initialize_dummy_weights(
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
             if current_platform.is_tpu():
-                # XLA device does not support torch.Generator()
-                param.uniform_(low, high)
+                generator = torch.Generator(device="cpu")
+                generator.manual_seed(seed)
+                # Note: The param.uniform_ function cannot be used in this
+                # context because it demands more TPU HBM than directly copying
+                # from a CPU tensor.
+                # Note: We avoid using torch.rand_like as it doesn't currently
+                # support the generator argument.
+                param.copy_((high - low) *
+                            torch.rand(*param.shape,
+                                       generator=generator,
+                                       dtype=param.dtype,
+                                       layout=param.layout,
+                                       requires_grad=param.requires_grad,
+                                       device="cpu") + low)
+                torch._sync(param)
                 continue
 
             generator = torch.Generator(device=param.data.device)
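
For reviewers who want to try the pattern outside vLLM, below is a minimal, self-contained sketch of the same technique: seed a CPU generator, draw the uniform values with `torch.rand` (since `torch.rand_like` does not take a `generator` argument), and copy the result onto the parameter in one shot. The helper name and default bounds here are hypothetical, and the sketch omits the TPU-specific `torch._sync(param)` call and the `requires_grad` plumbing from the patch.

```python
import torch


def fill_dummy_weight_from_cpu(param: torch.Tensor,
                               low: float = -1e-3,
                               high: float = 1e-3,
                               seed: int = 0) -> None:
    """Hypothetical helper illustrating the patch's approach: draw the
    random values on the CPU with a seeded generator, then copy them
    onto the parameter in one shot instead of calling param.uniform_()
    on the accelerator."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    # torch.rand_like does not accept a generator argument, so use
    # torch.rand and mirror the parameter's shape, dtype, and layout.
    cpu_values = torch.rand(*param.shape,
                            generator=generator,
                            dtype=param.dtype,
                            layout=param.layout,
                            device="cpu")
    param.copy_((high - low) * cpu_values + low)


# Usage sketch: deterministically initialize a tensor in [-1e-3, 1e-3).
weight = torch.empty(4, 8, dtype=torch.float32)
fill_dummy_weight_from_cpu(weight, seed=42)
```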