Merge abaadd9b3e746fc91fcd19dbb79ff4ed615580e3 into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b

This commit is contained in:
Saro 2025-09-01 14:01:26 +00:00 committed by GitHub
commit 982c69ea41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -117,16 +117,23 @@ class ParallelEmbedding(nn.Module):
Raises:
ValueError: If `world_size` is not defined.
"""
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Look up embeddings for the token ids in ``x``.

    With a single rank the full vocabulary is local and the lookup is a
    plain embedding. With ``world_size > 1`` each rank owns the id range
    ``[vocab_start_idx, vocab_end_idx)``: out-of-range ids are looked up
    at a dummy slot, their rows are zeroed, and an all-reduce sums the
    partial results so every rank ends up with the complete embeddings.

    Args:
        x: Tensor of token indices.

    Returns:
        Tensor of embedded token vectors.

    Raises:
        ValueError: If ``world_size`` is not at least 1.
    """
    if self.world_size < 1:
        raise ValueError("world_size must be >= 1")
    if self.world_size == 1:
        # Whole vocabulary lives on this rank; no masking required.
        return F.embedding(x, self.weight)
    # Tensor-parallel path: identify ids owned by other ranks.
    oov = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
    # Shift into local index space; park foreign ids at slot 0.
    local_ids = (x - self.vocab_start_idx).masked_fill(oov, 0)
    out = F.embedding(local_ids, self.weight)
    # Zero the rows for foreign ids so the sum across ranks is exact.
    out[oov] = 0
    dist.all_reduce(out)
    return out
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
"""