[misc][distributed] add seed to dummy weights (#6491)
commit ce37be7ba0
parent 7f62077af5
@@ -440,6 +440,7 @@ def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
     high: float = 1e-3,
+    seed: int = 1234,
 ) -> None:
     """Initialize model weights with random values.
 
@@ -447,17 +448,25 @@ def initialize_dummy_weights(
     measurements. Additionally, the model weights should not cause NaNs in the
     forward pass. We empirically found that initializing the weights with
     values between -1e-3 and 1e-3 works well for most models.
+
+    We use per-parameter random seed, so that dummy weights are consistent,
+    even if the model is partitioned across multiple devices. When the seed
+    is fixed, the random values generated by this function only depends on
+    the parameter's number of elements and its data type.
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (FP8)
                 dtype = param.data.dtype
                 tmp_param = param.data.to(torch.float16)
-                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                tmp_param = tmp_param.uniform_(low, high,
+                                               generator=generator).to(dtype)
                 param.data.copy_(tmp_param)
             else:
-                param.uniform_(low, high)
+                param.uniform_(low, high, generator=generator)
 
 
 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
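
A minimal sketch (plain PyTorch, not vLLM code) of the behavior this commit introduces: seeding a fresh torch.Generator for every parameter makes the dummy values deterministic and independent of whatever else has touched the default RNG state. The helper name fill_dummy below is illustrative only, not part of vLLM.

import torch

def fill_dummy(param: torch.Tensor,
               low: float = -1e-3,
               high: float = 1e-3,
               seed: int = 1234) -> torch.Tensor:
    # Fresh generator per parameter, always seeded with the same value,
    # mirroring the change in this commit.
    generator = torch.Generator(device=param.device)
    generator.manual_seed(seed)
    return param.uniform_(low, high, generator=generator)

a = fill_dummy(torch.empty(4, 8))
torch.rand(1000)                  # unrelated use of the default RNG in between
b = fill_dummy(torch.empty(4, 8))
assert torch.equal(a, b)          # identical: values depend only on seed, numel, dtype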
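
For completeness, a hedged usage sketch of the updated signature. The import path below is an assumption (based on this file also defining maybe_remap_kv_scale_name) and may differ across vLLM versions.

import torch
# Assumed import path; adjust to your vLLM version.
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights

model = torch.nn.Linear(16, 16)
initialize_dummy_weights(model, low=-1e-3, high=1e-3, seed=1234)
# Re-running with the same seed reproduces the same dummy weights.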