diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index cb83f43a2a4e2..698c59d49fe06 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -440,6 +440,7 @@ def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
     high: float = 1e-3,
+    seed: int = 1234,
 ) -> None:
     """Initialize model weights with random values.
 
@@ -447,17 +448,25 @@ def initialize_dummy_weights(
     measurements. Additionally, the model weights should not cause NaNs in the
     forward pass. We empirically found that initializing the weights with
     values between -1e-3 and 1e-3 works well for most models.
+
+    We use a per-parameter random seed, so that dummy weights are consistent,
+    even if the model is partitioned across multiple devices. When the seed
+    is fixed, the random values generated by this function depend only on
+    the parameter's number of elements and its data type.
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (FP8)
                 dtype = param.data.dtype
                 tmp_param = param.data.to(torch.float16)
-                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                tmp_param = tmp_param.uniform_(low, high,
+                                               generator=generator).to(dtype)
                 param.data.copy_(tmp_param)
             else:
-                param.uniform_(low, high)
+                param.uniform_(low, high, generator=generator)
 
 
 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
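
The sketch below (not part of the diff) illustrates the reproducibility property the new docstring describes: because a fresh generator is created and re-seeded for every parameter, the values drawn by uniform_ depend only on the seed, the tensor's number of elements, and its dtype, not on how many parameters were initialized before it. The helper fill_like_dummy_init is hypothetical and only mimics the per-parameter seeding pattern; it is not part of vLLM.

import torch

def fill_like_dummy_init(numel: int,
                         dtype: torch.dtype = torch.float32,
                         seed: int = 1234,
                         low: float = -1e-3,
                         high: float = 1e-3) -> torch.Tensor:
    # Fresh generator per parameter, always seeded with the same value,
    # mirroring the per-parameter generator introduced by the diff above.
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return torch.empty(numel, dtype=dtype).uniform_(low, high,
                                                    generator=generator)

# Two workers initializing an identically shaped shard get identical dummy
# weights, regardless of which other parameters each worker filled earlier.
shard_on_worker_0 = fill_like_dummy_init(1024)
shard_on_worker_1 = fill_like_dummy_init(1024)
assert torch.equal(shard_on_worker_0, shard_on_worker_1)

With a single shared generator (the previous behavior), the values assigned to a given parameter would depend on the order and sizes of all previously initialized parameters, so differently partitioned models would receive different dummy weights even with the same seed.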