Fix gpt oss weight loading with EP + bf16 (#28765)

Signed-off-by: ashors1 <ashors@nvidia.com>
This commit is contained in:
Anna Shors 2025-11-16 05:12:45 -08:00 committed by GitHub
parent 3bc1175798
commit 8d259fad6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -494,8 +494,8 @@ class GptOssModel(nn.Module):
def _load_weights_other(
self,
ep_rank_start: int,
ep_rank_end: int,
ep_rank_start: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],