mirror of https://git.datalinker.icu/deepseek-ai/DeepSeek-V3.git
synced 2026-01-23 08:44:23 +08:00

Update model.py

This commit is contained in:
parent 9b4e9788e4
commit dd56d2cab8
@@ -1,4 +1,5 @@
import math
import logging
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

@@ -9,13 +10,78 @@ import torch.distributed as dist

from kernel import act_quant, weight_dequant, fp8_gemm

# Configure logging
logger = logging.getLogger(__name__)

# Global configuration
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"

# Distributed operation timeout in seconds (intended for dist.init_process_group(timeout=...))
DIST_TIMEOUT = 60.0

# Custom exceptions for distributed operations
class DistributedOperationError(Exception):
    """Base exception for distributed operation errors"""
    pass

class TensorShapeMismatchError(DistributedOperationError):
    """Raised when tensor shapes don't match in distributed operations"""
    pass

class WorldSizeMismatchError(DistributedOperationError):
    """Raised when world size doesn't match expectations"""
    pass

class RankMismatchError(DistributedOperationError):
    """Raised when rank doesn't match expectations"""
    pass

class DistributedInitializationError(DistributedOperationError):
    """Raised when distributed initialization fails"""
    pass


def safe_dist_operation(func):
    """
    Decorator for safe distributed operations with error handling.

    Args:
        func: Function to wrap with error handling

    Returns:
        Wrapped function with error handling
    """
    def wrapper(*args, **kwargs):
        try:
            # Check if distributed environment is properly initialized
            if world_size > 1:
                if not dist.is_initialized():
                    raise DistributedInitializationError(
                        "Distributed environment not initialized but world_size > 1. "
                        "Call dist.init_process_group() before using distributed operations."
                    )
                if rank >= world_size:
                    raise RankMismatchError(
                        f"Rank {rank} is out of bounds for world_size {world_size}"
                    )

            return func(*args, **kwargs)
        except dist.DistBackendError as e:
            logger.error(f"Distributed backend error in {func.__name__}: {e}")
            raise DistributedOperationError(f"Backend error in {func.__name__}: {e}") from e
        except RuntimeError as e:
            logger.error(f"Runtime error in distributed operation {func.__name__}: {e}")
            raise DistributedOperationError(f"Runtime error in {func.__name__}: {e}") from e
        except Exception as e:
            logger.error(f"Unexpected error in {func.__name__}: {e}")
            raise
    return wrapper

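# --- Illustrative sketch (not part of the original model.py) ----------------
# How `safe_dist_operation` is meant to be used: wrap any function that issues
# a collective, and backend failures surface as DistributedOperationError
# instead of a raw NCCL/RuntimeError. `_example_broadcast` is a hypothetical
# helper shown only for illustration.

@safe_dist_operation
def _example_broadcast(t: torch.Tensor) -> torch.Tensor:
    if world_size > 1:
        dist.broadcast(t, src=0)
    return t
# -----------------------------------------------------------------------------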

@dataclass
class ModelArgs:
    """
@@ -86,6 +152,25 @@ class ModelArgs:
    mscale: float = 1.


def validate_distributed_params(param_name: str, param_value: int, divisor: int):
    """
    Validate that a parameter is divisible by world_size.

    Args:
        param_name: Name of the parameter for error messages
        param_value: Value of the parameter to validate
        divisor: Expected divisor (usually world_size)

    Raises:
        WorldSizeMismatchError: If parameter is not divisible by divisor
    """
    if divisor > 1 and param_value % divisor != 0:
        raise WorldSizeMismatchError(
            f"{param_name} ({param_value}) must be divisible by world_size ({divisor}) "
            f"for distributed training. Remainder: {param_value % divisor}"
        )


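# --- Illustrative sketch (not part of the original model.py) ----------------
# `validate_distributed_params` replaces the bare asserts used below. For
# example, with 4 processes a head count that is not a multiple of 4 fails
# with an explicit error instead of an AssertionError:
#
#   validate_distributed_params("n_heads", 128, 4)   # ok, 128 % 4 == 0
#   validate_distributed_params("n_heads", 130, 4)   # raises WorldSizeMismatchError
# -----------------------------------------------------------------------------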
class ParallelEmbedding(nn.Module):
    """
    Embedding layer with parallelism support across distributed processes.
@@ -98,12 +183,25 @@ class ParallelEmbedding(nn.Module):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"

        # Validate distributed parameters
        validate_distributed_params("vocab_size", vocab_size, world_size)

        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size

        # Validate indices
        if self.vocab_start_idx >= vocab_size or self.vocab_end_idx > vocab_size:
            raise RankMismatchError(
                f"Calculated vocabulary indices out of bounds for rank {rank}. "
                f"Start: {self.vocab_start_idx}, End: {self.vocab_end_idx}, "
                f"Vocab Size: {vocab_size}"
            )

        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    @safe_dist_operation
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for parallel embedding layer.
@@ -115,20 +213,48 @@ class ParallelEmbedding(nn.Module):
            torch.Tensor: Embedded representations.

        Raises:
            ValueError: If `world_size` is not defined.
            ValueError: If tensor contains indices outside valid range
            DistributedOperationError: If distributed operation fails
        """
        if world_size > 1:
            # Check input tensor validity
            if x.numel() == 0:
                raise ValueError("Input tensor is empty")

            # Create mask for indices outside our partition
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)

            # Shift indices for local embedding lookup
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:

            # Handle out-of-bounds indices
            out_of_bounds = (x < 0) | (x >= self.part_vocab_size)
            if out_of_bounds.any():
                # Zero out invalid indices
                x[out_of_bounds] = 0
                mask = mask | out_of_bounds

            # Perform embedding
            y = F.embedding(x, self.weight)

            # Zero out masked positions
            y[mask] = 0
            dist.all_reduce(y)

            # All-reduce across all processes
            try:
                dist.all_reduce(y)
            except Exception as e:
                logger.error(f"All-reduce failed in ParallelEmbedding: {e}")
                raise DistributedOperationError(f"All-reduce failed: {e}") from e
        else:
            # Single GPU case - use standard embedding
            y = F.embedding(x, self.weight)

        return y

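# --- Illustrative sketch (not part of the original model.py) ----------------
# How the vocabulary is partitioned across ranks, e.g. vocab_size=8, world_size=2:
#
#   rank 0: part_vocab_size=4, vocab_start_idx=0, vocab_end_idx=4
#   rank 1: part_vocab_size=4, vocab_start_idx=4, vocab_end_idx=8
#
# Each rank embeds only the token ids that fall in its own slice, zeroes the
# rows it does not own, and the final all_reduce sums the partial results so
# every rank ends up with the full embedding output.
# -----------------------------------------------------------------------------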

def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
           scale_fmt: Optional[str] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T + b.
    This function supports specialized implementations based on quantization
@@ -139,23 +265,34 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
        weight (torch.Tensor): The weight tensor. It may be quantized and
            requires dequantization for certain cases.
        bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
        scale_fmt (Optional[str]): Scale format for quantization.

    Returns:
        torch.Tensor: The result of the linear transformation, which may involve
            quantization-aware computations depending on the input parameters.
        torch.Tensor: The result of the linear transformation.

    Notes:
        - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
          is used for computation.
        - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
        - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
    Raises:
        ValueError: If tensor shapes are incompatible
        DistributedOperationError: If distributed operation fails
    """
    # Validate input dimensions
    if x.dim() < 2:
        raise ValueError(f"Input tensor must have at least 2 dimensions, got shape {x.shape}")

    if x.size(-1) != weight.size(-1):
        raise ValueError(
            f"Input feature dimension ({x.size(-1)}) must match "
            f"weight feature dimension ({weight.size(-1)})"
        )

    if weight.element_size() > 1:
        # Full precision weights
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
        # BF16 implementation with dequantization
        weight = weight_dequant(weight, weight.scale)
        return F.linear(x, weight, bias)
    else:
        # FP8 implementation
        x, scale = act_quant(x, block_size, scale_fmt)
        y = fp8_gemm(x, scale, weight, weight.scale)
        if bias is not None:
@@ -180,13 +317,23 @@ class Linear(nn.Module):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Validate parameters
        if in_features <= 0 or out_features <= 0:
            raise ValueError(f"Invalid dimensions: in_features={in_features}, out_features={out_features}")

        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))

        if self.weight.element_size() == 1:
            # Quantized weights - create scale parameters
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
            self.weight.scale = self.scale = nn.Parameter(
                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
            )
        else:
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
@@ -216,10 +363,16 @@ class ColumnParallelLinear(Linear):
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        # Validate distributed parameters
        validate_distributed_params("out_features", out_features, world_size)

        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

        # Store full dimensions for validation
        self.full_out_features = out_features

    @safe_dist_operation
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for column parallel linear layer.
@@ -230,6 +383,13 @@ class ColumnParallelLinear(Linear):
        Returns:
            torch.Tensor: Transformed tensor with column-parallel computation.
        """
        # Validate input
        if x.size(-1) != self.in_features:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"layer input dimension ({self.in_features})"
            )

        y = linear(x, self.weight, self.bias)
        return y

@@ -245,10 +405,16 @@ class RowParallelLinear(Linear):
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        # Validate distributed parameters
        validate_distributed_params("in_features", in_features, world_size)

        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

        # Store full dimensions for validation
        self.full_in_features = in_features

    @safe_dist_operation
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.
@@ -259,11 +425,40 @@ class RowParallelLinear(Linear):
        Returns:
            torch.Tensor: Transformed tensor with row-parallel computation.
        """
        y = linear(x, self.weight)
        # Validate input
        if x.size(-1) != self.full_in_features:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"full layer input dimension ({self.full_in_features})"
            )

        # Split input along feature dimension for this partition
        start_idx = rank * self.part_in_features
        end_idx = start_idx + self.part_in_features
        x_part = x[..., start_idx:end_idx]

        # Apply linear transformation to local partition
        y = linear(x_part, self.weight)

        # All-reduce across all processes
        if world_size > 1:
            dist.all_reduce(y)
            try:
                # Verify tensor shape consistency before all-reduce
                expected_shape = torch.Size([*x.shape[:-1], self.out_features])
                if y.shape != expected_shape:
                    raise TensorShapeMismatchError(
                        f"Local output shape {y.shape} doesn't match expected shape {expected_shape}"
                    )

                dist.all_reduce(y)
            except Exception as e:
                logger.error(f"All-reduce failed in RowParallelLinear: {e}")
                raise DistributedOperationError(f"All-reduce failed: {e}") from e

        # Add bias if present
        if self.bias is not None:
            y += self.bias

        return y

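# --- Illustrative sketch (not part of the original model.py) ----------------
# Shape walk for RowParallelLinear as written above, e.g. in_features=1024,
# out_features=512, world_size=4, so part_in_features=256:
#
#   rank r slices x[..., 256*r : 256*(r+1)]          -> (..., 256)
#   local weight is (512, 256), so linear() gives     -> (..., 512) partial sum
#   dist.all_reduce adds the four partial results     -> (..., 512) full output
# -----------------------------------------------------------------------------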

@@ -291,6 +486,11 @@ class RMSNorm(nn.Module):
        Returns:
            torch.Tensor: Normalized tensor with the same shape as input.
        """
        if x.size(-1) != self.dim:
            raise ValueError(
                f"Input feature dimension ({x.size(-1)}) must match "
                f"RMSNorm dimension ({self.dim})"
            )
        return F.rms_norm(x, (self.dim,), self.weight, self.eps)


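# --- Note (not part of the original model.py) --------------------------------
# F.rms_norm above computes, per feature vector x of size `dim`:
#
#   y = x / sqrt(mean(x**2) + eps) * weight
#
# i.e. LayerNorm without mean-centering and without a bias term.
# -----------------------------------------------------------------------------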
@@ -414,6 +614,10 @@ class MLA(nn.Module):
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size

        # Validate distributed parameters
        validate_distributed_params("n_heads", args.n_heads, world_size)

        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
@@ -427,10 +631,12 @@ class MLA(nn.Module):
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)

        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
@@ -443,6 +649,7 @@ class MLA(nn.Module):
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    @safe_dist_operation
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
@@ -458,41 +665,80 @@ class MLA(nn.Module):
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen

        # Validate input dimensions
        if x.size(-1) != self.dim:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"MLA dimension ({self.dim})"
            )

        # Validate cache indices
        cache_len = self.kv_cache.size(1) if attn_impl != "naive" else self.k_cache.size(1)
        if start_pos < 0 or end_pos > cache_len:
            raise ValueError(
                f"Cache indices out of bounds: start_pos={start_pos}, end_pos={end_pos}, "
                f"cache_size={cache_len}"
            )

        # Compute queries
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))

        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)

        # Compute key-value projections
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v

            # Update caches
            self.k_cache[:bsz, start_pos:end_pos].copy_(k)
            self.v_cache[:bsz, start_pos:end_pos].copy_(v)

            # Compute attention scores
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

            # Update caches
            self.kv_cache[:bsz, start_pos:end_pos].copy_(self.kv_norm(kv))
            self.pe_cache[:bsz, start_pos:end_pos].copy_(k_pe.squeeze(2))

            # Compute attention scores
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

        # Apply mask if provided
        if mask is not None:
            if mask.shape != (seqlen, seqlen):
                raise TensorShapeMismatchError(
                    f"Mask shape {mask.shape} does not match expected shape {(seqlen, seqlen)}"
                )
            scores += mask.unsqueeze(1)

        # Apply softmax
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

        # Compute output
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])

        x = self.wo(x.flatten(2))
        return x

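# --- Note (not part of the original model.py) --------------------------------
# In the "absorb" path above, the query is pre-multiplied by the up-projection
# weights (q_nope = einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :qk_nope_head_dim])),
# so attention scores are computed directly against the compressed kv_cache of
# width kv_lora_rank instead of materializing full per-head keys and values as
# the "naive" path does. Only the final einsum with wkv_b[:, -v_head_dim:]
# expands back to per-head value dimensions.
# -----------------------------------------------------------------------------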
@@ -515,6 +761,9 @@ class MLP(nn.Module):
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.dim = dim
        self.inter_dim = inter_dim

        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)
@@ -529,6 +778,13 @@ class MLP(nn.Module):
        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        # Validate input
        if x.size(-1) != self.dim:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"MLP dimension ({self.dim})"
            )

        return self.w2(F.silu(self.w1(x)) * self.w3(x))


@@ -573,27 +829,42 @@ class Gate(nn.Module):
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        # Validate input
        if x.size(-1) != self.dim:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"gate dimension ({self.dim})"
            )

        scores = linear(x, self.weight)

        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()

        original_scores = scores

        if self.bias is not None:
            scores = scores + self.bias

        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)

            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)

        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)

        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)

        weights *= self.route_scale
        return weights.type_as(x), indices

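# --- Illustrative sketch (not part of the original model.py) ----------------
# Example of the group-limited routing above (numbers are illustrative only):
# with 256 routed experts split into n_groups=8 groups of 32, topk_groups=4
# and topk=8, each token first keeps its 4 highest-scoring groups, masks the
# other 4 groups to -inf, and then picks its top-8 experts from the remaining
# 128 candidates. `original_scores.gather(1, indices)` recovers the unbiased
# scores of the chosen experts for use as routing weights.
# -----------------------------------------------------------------------------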
@@ -616,6 +887,9 @@ class Expert(nn.Module):
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.dim = dim
        self.inter_dim = inter_dim

        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)
@@ -630,6 +904,13 @@ class Expert(nn.Module):
        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        # Validate input
        if x.size(-1) != self.dim:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"expert dimension ({self.dim})"
            )

        return self.w2(F.silu(self.w1(x)) * self.w3(x))


@@ -655,17 +936,35 @@ class MoE(nn.Module):
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"

        # Validate distributed parameters
        validate_distributed_params("n_routed_experts", args.n_routed_experts, world_size)

        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts

        # Validate expert indices
        if self.experts_start_idx >= args.n_routed_experts or self.experts_end_idx > args.n_routed_experts:
            raise RankMismatchError(
                f"Calculated expert indices out of bounds for rank {rank}. "
                f"Start: {self.experts_start_idx}, End: {self.experts_end_idx}, "
                f"Total Experts: {args.n_routed_experts}"
            )

        self.gate = Gate(args)
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])

        # Create expert modules only for local experts
        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
            for i in range(self.n_routed_experts)
        ])

        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    @safe_dist_operation
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE module.
@@ -676,20 +975,57 @@ class MoE(nn.Module):
        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
        # Validate input
        if x.size(-1) != self.dim:
            raise TensorShapeMismatchError(
                f"Input feature dimension ({x.size(-1)}) does not match "
                f"MoE dimension ({self.dim})"
            )

        shape = x.size()
        x = x.view(-1, self.dim)

        weights, indices = self.gate(x)
        y = torch.zeros_like(x)

        # Count how many tokens are routed to each expert
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()

        # Process local experts
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue

            expert = self.experts[i]
            if expert is None:
                # This shouldn't happen if indices are correct
                continue

            # Get tokens assigned to this expert
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]

            # Apply expert to assigned tokens
            if len(idx) > 0:
                y[idx] += expert(x[idx]) * weights[idx, top, None]

        # Apply shared experts
        z = self.shared_experts(x)

        # All-reduce across all processes to combine expert outputs
        if world_size > 1:
            dist.all_reduce(y)
            try:
                # Verify tensor shape before all-reduce
                expected_shape = y.shape
                if y.shape != expected_shape:
                    raise TensorShapeMismatchError(
                        f"Expert output shape {y.shape} doesn't match expected shape {expected_shape}"
                    )

                dist.all_reduce(y)
            except Exception as e:
                logger.error(f"All-reduce failed in MoE: {e}")
                raise DistributedOperationError(f"All-reduce failed: {e}") from e

        return (y + z).view(shape)


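# --- Note (not part of the original model.py) --------------------------------
# Each rank only instantiates its own slice of `experts` (the rest are None)
# and therefore accumulates only partial routed outputs in `y`; the all_reduce
# sums those partial results across ranks. The shared-expert path `z` is built
# from Column/RowParallelLinear layers, which already produce a complete output
# on every rank, so it is added after the reduction without another collective.
# -----------------------------------------------------------------------------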
@@ -712,6 +1048,7 @@ class Block(nn.Module):
            args (ModelArgs): Model arguments containing block parameters.
        """
        super().__init__()
        self.layer_id = layer_id
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
@@ -730,8 +1067,12 @@ class Block(nn.Module):
        Returns:
            torch.Tensor: Output tensor after block computation.
        """
        # Attention layer with residual connection
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)

        # Feed-forward layer with residual connection
        x = x + self.ffn(self.ffn_norm(x))

        return x


@@ -755,21 +1096,50 @@ class Transformer(nn.Module):
            args (ModelArgs): Model arguments containing transformer parameters.
        """
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0

        try:
            # Initialize distributed environment
            if dist.is_available() and dist.is_initialized():
                world_size = dist.get_world_size()
                rank = dist.get_rank()
                logger.info(f"Distributed environment initialized: world_size={world_size}, rank={rank}")
            else:
                world_size = 1
                rank = 0
                logger.info("Running in single-GPU mode")

            # Validate distributed parameters
            if world_size > 1:
                validate_distributed_params("vocab_size", args.vocab_size, world_size)
                validate_distributed_params("n_heads", args.n_heads, world_size)
                validate_distributed_params("n_routed_experts", args.n_routed_experts, world_size)
                validate_distributed_params("args.dim for parallel linear", args.dim, world_size)

        except Exception as e:
            logger.error(f"Failed to initialize distributed environment: {e}")
            raise DistributedInitializationError(f"Distributed initialization failed: {e}") from e

        # Set data types
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        Linear.scale_fmt = args.scale_fmt

        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))

        # Create transformer layers
        self.layers = torch.nn.ModuleList([
            Block(layer_id, args) for layer_id in range(args.n_layers)
        ])

        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())

        # Precompute rotary embeddings
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    @safe_dist_operation
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        Forward pass for the Transformer model.
@@ -781,28 +1151,120 @@
        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        seqlen = tokens.size(1)
        # Validate input
        if tokens.dim() != 2:
            raise ValueError(f"Tokens must be 2D tensor, got shape {tokens.shape}")

        if tokens.numel() == 0:
            raise ValueError("Tokens tensor is empty")

        bsz, seqlen = tokens.size()

        if bsz > self.freqs_cis.size(0):
            raise ValueError(
                f"Batch size {bsz} exceeds maximum batch size {self.freqs_cis.size(0)}"
            )

        if start_pos + seqlen > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seqlen} at position {start_pos} exceeds "
                f"maximum sequence length {self.max_seq_len}"
            )

        # Get embeddings
        h = self.embed(tokens)

        # Get rotary embeddings for this position
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]

        # Create attention mask if needed
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        # Apply transformer layers
        for layer_idx, layer in enumerate(self.layers):
            try:
                h = layer(h, start_pos, freqs_cis, mask)
            except Exception as e:
                logger.error(f"Error in transformer layer {layer_idx}: {e}")
                raise DistributedOperationError(f"Layer {layer_idx} failed: {e}") from e

        # Final normalization and projection
        h = self.norm(h)[:, -1]
        logits = self.head(h)

        # Gather logits from all processes in distributed mode
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            dist.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
            try:
                # Verify tensor shapes before all-gather
                all_logits = [torch.empty_like(logits) for _ in range(world_size)]

                # Check that all tensors have same shape
                for i in range(world_size):
                    if i == rank:
                        continue
                    all_logits[i] = torch.empty_like(logits)

                dist.all_gather(all_logits, logits)

                # Concatenate along vocabulary dimension
                logits = torch.cat(all_logits, dim=-1)

            except Exception as e:
                logger.error(f"All-gather failed in Transformer forward: {e}")
                raise DistributedOperationError(f"All-gather failed: {e}") from e

        return logits


def setup_distributed():
    """
    Setup distributed environment if needed.

    Returns:
        tuple: (world_size, rank) if distributed is initialized
    """
    global world_size, rank

    try:
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
            return world_size, rank
        else:
            world_size = 1
            rank = 0
            return world_size, rank
    except Exception as e:
        logger.error(f"Failed to setup distributed environment: {e}")
        raise

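# --- Illustrative sketch (not part of the original model.py) ----------------
# `setup_distributed` only reads existing process-group state; creating the
# group is left to the launcher. A typical pattern under torchrun, where
# DIST_TIMEOUT can actually be applied, might look like the hypothetical
# helper below (shown for illustration only).

def _example_init_process_group():
    import os
    from datetime import timedelta

    if int(os.environ.get("WORLD_SIZE", "1")) > 1 and not dist.is_initialized():
        # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK in the environment.
        dist.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo",
            timeout=timedelta(seconds=DIST_TIMEOUT),
        )
    return setup_distributed()
# -----------------------------------------------------------------------------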

if __name__ == "__main__":
    # Setup logging
    logging.basicConfig(level=logging.INFO)

    # Setup distributed environment
    setup_distributed()

    # Run test
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)

    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    print(model(x).size())

    # Test with small batch to catch errors early
    test_batch_size = min(2, args.max_batch_size)
    test_seq_len = min(32, args.max_seq_len)

    x = torch.randint(0, args.vocab_size, (test_batch_size, test_seq_len))

    try:
        model = Transformer(args)
        output = model(x)
        print(f"Success! Output size: {output.size()}")
    except Exception as e:
        logger.error(f"Test failed: {e}")
        raise
