[misc][distributed] add seed to dummy weights (#6491)

Author: youkaichao, 2024-07-16 19:16:34 -07:00 (committed by GitHub)
parent 7f62077af5
commit ce37be7ba0


@@ -440,6 +440,7 @@ def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
     high: float = 1e-3,
+    seed: int = 1234,
 ) -> None:
     """Initialize model weights with random values.
@@ -447,17 +448,25 @@ def initialize_dummy_weights(
     measurements. Additionally, the model weights should not cause NaNs in the
     forward pass. We empirically found that initializing the weights with
     values between -1e-3 and 1e-3 works well for most models.
+
+    We use a per-parameter random seed, so that the dummy weights are
+    consistent even if the model is partitioned across multiple devices.
+    When the seed is fixed, the random values generated by this function
+    depend only on the parameter's number of elements and its data type.
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (e.g. FP8)
                 dtype = param.data.dtype
                 tmp_param = param.data.to(torch.float16)
-                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                tmp_param = tmp_param.uniform_(low, high,
+                                               generator=generator).to(dtype)
                 param.data.copy_(tmp_param)
             else:
-                param.uniform_(low, high)
+                param.uniform_(low, high, generator=generator)
 
 
 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
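To see the effect of the per-parameter seeding in isolation, here is a minimal, self-contained sketch; TinyModel and init_dummy_weights below are illustrative stand-ins written for this note, not vLLM's actual module or import paths. It mirrors the patched logic and checks that two independently constructed models receive bit-identical dummy weights:

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Hypothetical toy model, not part of vLLM. The optional FP8 buffer
    # exercises the sub-16-bit branch on PyTorch builds (2.1+) that ship
    # torch.float8_e4m3fn.
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 4)
        if hasattr(torch, "float8_e4m3fn"):
            self.register_buffer(
                "fp8_buf", torch.empty(4, dtype=torch.float8_e4m3fn))


def init_dummy_weights(model: nn.Module,
                       low: float = -1e-3,
                       high: float = 1e-3,
                       seed: int = 1234) -> None:
    # Mirrors the patched initialize_dummy_weights above.
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            # Fresh generator per parameter: the values drawn for a tensor
            # depend only on the seed, its numel, and its dtype, so every
            # device holding a shard can regenerate them independently.
            generator = torch.Generator(device=param.data.device)
            generator.manual_seed(seed)
            if torch.finfo(param.data.dtype).bits < 16:
                # uniform_ doesn't support < 16-bit dtypes: fill a float16
                # copy, then cast back and write through.
                dtype = param.data.dtype
                tmp_param = param.data.to(torch.float16)
                tmp_param = tmp_param.uniform_(
                    low, high, generator=generator).to(dtype)
                param.data.copy_(tmp_param)
            else:
                param.uniform_(low, high, generator=generator)


model_a, model_b = TinyModel(), TinyModel()
init_dummy_weights(model_a)
init_dummy_weights(model_b)
for name, p_a in model_a.state_dict().items():
    p_b = model_b.state_dict()[name]
    # Cast before comparing so the check also covers FP8 tensors.
    assert torch.equal(p_a.to(torch.float32), p_b.to(torch.float32)), name
print("dummy weights are bit-identical across independent initializations")

Because the generator is re-seeded for every tensor, the values drawn for one parameter never depend on how many parameters were filled before it, which is what keeps shards on different devices consistent with each other.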