mirror of
https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2025-12-10 05:15:02 +08:00
Merge abaadd9b3e746fc91fcd19dbb79ff4ed615580e3 into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b
This commit is contained in:
commit
982c69ea41
@ -117,17 +117,24 @@ class ParallelEmbedding(nn.Module):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `world_size` is not defined.
|
ValueError: If `world_size` is not defined.
|
||||||
"""
|
"""
|
||||||
if world_size > 1:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.world_size < 1:
|
||||||
|
raise ValueError("world_size must be >= 1")
|
||||||
|
|
||||||
|
if self.world_size == 1:
|
||||||
|
return F.embedding(x, self.weight)
|
||||||
|
|
||||||
|
# For world_size > 1
|
||||||
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
||||||
x = x - self.vocab_start_idx
|
x = x - self.vocab_start_idx
|
||||||
x[mask] = 0
|
x[mask] = 0
|
||||||
|
|
||||||
y = F.embedding(x, self.weight)
|
y = F.embedding(x, self.weight)
|
||||||
if world_size > 1:
|
|
||||||
y[mask] = 0
|
y[mask] = 0
|
||||||
|
|
||||||
dist.all_reduce(y)
|
dist.all_reduce(y)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
|
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Applies a linear transformation to the incoming data: y = xA^T + b.
|
Applies a linear transformation to the incoming data: y = xA^T + b.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user