Merge dd56d2cab82de2cc59a2ac6ad02dcf20c09cc3fe into 9b4e9788e4a3a731f7567338ed15d3ec549ce03b

This commit is contained in:
Avi Sinha 2025-12-15 04:52:09 +00:00 committed by GitHub
commit 2f9215febe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194


@ -1,4 +1,5 @@
import math
import logging
from datetime import timedelta
from functools import wraps
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
@ -9,13 +10,78 @@ import torch.distributed as dist
from kernel import act_quant, weight_dequant, fp8_gemm
# Configure logging
logger = logging.getLogger(__name__)
# Global configuration
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"
# Distributed operation timeout (seconds)
DIST_TIMEOUT = 60.0
# Custom exceptions for distributed operations
class DistributedOperationError(Exception):
"""Base exception for distributed operation errors"""
pass
class TensorShapeMismatchError(DistributedOperationError):
"""Raised when tensor shapes don't match in distributed operations"""
pass
class WorldSizeMismatchError(DistributedOperationError):
"""Raised when world size doesn't match expectations"""
pass
class RankMismatchError(DistributedOperationError):
"""Raised when rank doesn't match expectations"""
pass
class DistributedInitializationError(DistributedOperationError):
"""Raised when distributed initialization fails"""
pass
def safe_dist_operation(func):
"""
Decorator for safe distributed operations with error handling.
Args:
func: Function to wrap with error handling
Returns:
Wrapped function with error handling
"""
@wraps(func)
def wrapper(*args, **kwargs):
try:
# Check if distributed environment is properly initialized
if world_size > 1:
if not dist.is_initialized():
raise DistributedInitializationError(
"Distributed environment not initialized but world_size > 1. "
"Call dist.init_process_group() before using distributed operations."
)
if rank >= world_size:
raise RankMismatchError(
f"Rank {rank} is out of bounds for world_size {world_size}"
)
return func(*args, **kwargs)
except dist.DistBackendError as e:
logger.error(f"Distributed backend error in {func.__name__}: {e}")
raise DistributedOperationError(f"Backend error in {func.__name__}: {e}") from e
except RuntimeError as e:
logger.error(f"Runtime error in distributed operation {func.__name__}: {e}")
raise DistributedOperationError(f"Runtime error in {func.__name__}: {e}") from e
except Exception as e:
logger.error(f"Unexpected error in {func.__name__}: {e}")
raise
return wrapper
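# Usage sketch: safe_dist_operation is applied below to forward() methods that
# issue collectives, e.g.
#
#     @safe_dist_operation
#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         ...                  # local computation
#         dist.all_reduce(y)   # backend/runtime failures surface as DistributedOperationError
#         return y
#
# so that collective failures are logged with the failing function's name.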
@dataclass
class ModelArgs:
"""
@ -86,6 +152,25 @@ class ModelArgs:
mscale: float = 1.
def validate_distributed_params(param_name: str, param_value: int, divisor: int):
"""
Validate that a parameter is divisible by world_size.
Args:
param_name: Name of the parameter for error messages
param_value: Value of the parameter to validate
divisor: Expected divisor (usually world_size)
Raises:
WorldSizeMismatchError: If parameter is not divisible by divisor
"""
if divisor > 1 and param_value % divisor != 0:
raise WorldSizeMismatchError(
f"{param_name} ({param_value}) must be divisible by world_size ({divisor}) "
f"for distributed training. Remainder: {param_value % divisor}"
)
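# A minimal sketch with hypothetical values: with world_size = 8, a parameter of
# 1024 passes the divisibility check, while 100 does not.
#
#     validate_distributed_params("vocab_size", 1024, 8)  # ok: 1024 % 8 == 0
#     validate_distributed_params("n_heads", 100, 8)       # raises WorldSizeMismatchError (remainder 4)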
class ParallelEmbedding(nn.Module):
"""
Embedding layer with parallelism support across distributed processes.
@ -98,12 +183,25 @@ class ParallelEmbedding(nn.Module):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
# Validate distributed parameters
validate_distributed_params("vocab_size", vocab_size, world_size)
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
# Validate indices
if self.vocab_start_idx >= vocab_size or self.vocab_end_idx > vocab_size:
raise RankMismatchError(
f"Calculated vocabulary indices out of bounds for rank {rank}. "
f"Start: {self.vocab_start_idx}, End: {self.vocab_end_idx}, "
f"Vocab Size: {vocab_size}"
)
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
@safe_dist_operation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for parallel embedding layer.
@ -115,20 +213,48 @@ class ParallelEmbedding(nn.Module):
torch.Tensor: Embedded representations.
Raises:
ValueError: If the input tensor is empty.
DistributedOperationError: If distributed operation fails
"""
if world_size > 1:
# Check input tensor validity
if x.numel() == 0:
raise ValueError("Input tensor is empty")
# Create mask for indices outside our partition
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
# Shift indices for local embedding lookup
x = x - self.vocab_start_idx
x[mask] = 0
# Handle out-of-bounds indices
out_of_bounds = (x < 0) | (x >= self.part_vocab_size)
if out_of_bounds.any():
# Zero out invalid indices
x[out_of_bounds] = 0
mask = mask | out_of_bounds
# Perform embedding
y = F.embedding(x, self.weight)
# Zero out masked positions
y[mask] = 0
# All-reduce across all processes
try:
work = dist.all_reduce(y, async_op=True)
work.wait(timedelta(seconds=DIST_TIMEOUT))
except Exception as e:
logger.error(f"All-reduce failed in ParallelEmbedding: {e}")
raise DistributedOperationError(f"All-reduce failed: {e}") from e
else:
# Single GPU case - use standard embedding
y = F.embedding(x, self.weight)
return y
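# Partitioning sketch (hypothetical sizes): with vocab_size = 1024 and
# world_size = 4, each rank owns part_vocab_size = 256 rows of the embedding
# table; rank 1 covers indices [256, 512). Tokens outside that slice are
# zeroed locally, and their embeddings are filled in by the all-reduce, since
# exactly one rank contributes a nonzero row for every index.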
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
@ -139,23 +265,34 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
weight (torch.Tensor): The weight tensor. It may be quantized and
requires dequantization for certain cases.
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
scale_fmt (Optional[str]): Scale format for quantization.
Returns:
torch.Tensor: The result of the linear transformation.
Notes:
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
is used for computation.
- If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
Raises:
ValueError: If tensor shapes are incompatible
DistributedOperationError: If distributed operation fails
"""
# Validate input dimensions
if x.dim() < 2:
raise ValueError(f"Input tensor must have at least 2 dimensions, got shape {x.shape}")
if x.size(-1) != weight.size(-1):
raise ValueError(
f"Input feature dimension ({x.size(-1)}) must match "
f"weight feature dimension ({weight.size(-1)})"
)
if weight.element_size() > 1:
# Full precision weights
return F.linear(x, weight, bias)
elif gemm_impl == "bf16":
# BF16 implementation with dequantization
weight = weight_dequant(weight, weight.scale)
return F.linear(x, weight, bias)
else:
# FP8 implementation
x, scale = act_quant(x, block_size, scale_fmt)
y = fp8_gemm(x, scale, weight, weight.scale)
if bias is not None:
@ -180,13 +317,23 @@ class Linear(nn.Module):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# Validate parameters
if in_features <= 0 or out_features <= 0:
raise ValueError(f"Invalid dimensions: in_features={in_features}, out_features={out_features}")
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
if self.weight.element_size() == 1:
# Quantized weights - create scale parameters
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
)
else:
self.register_parameter("scale", None)
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
@ -216,10 +363,16 @@ class ColumnParallelLinear(Linear):
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
# Validate distributed parameters
validate_distributed_params("out_features", out_features, world_size)
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
# Store full dimensions for validation
self.full_out_features = out_features
@safe_dist_operation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for column parallel linear layer.
@ -230,6 +383,13 @@ class ColumnParallelLinear(Linear):
Returns:
torch.Tensor: Transformed tensor with column-parallel computation.
"""
# Validate input
if x.size(-1) != self.in_features:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"layer input dimension ({self.in_features})"
)
y = linear(x, self.weight, self.bias)
return y
@ -245,10 +405,16 @@ class RowParallelLinear(Linear):
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
# Validate distributed parameters
validate_distributed_params("in_features", in_features, world_size)
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
# Store full dimensions for validation
self.full_in_features = in_features
@safe_dist_operation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for row parallel linear layer.
@ -259,11 +425,40 @@ class RowParallelLinear(Linear):
Returns:
torch.Tensor: Transformed tensor with row-parallel computation.
"""
# Validate input: the incoming activation is already the local partition
# produced by the preceding column-parallel layer, so its feature dimension
# equals part_in_features, not the full in_features.
if x.size(-1) != self.part_in_features:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"partitioned layer input dimension ({self.part_in_features})"
)
# Apply linear transformation to the local partition
y = linear(x, self.weight)
# All-reduce across all processes
if world_size > 1:
try:
# Verify tensor shape consistency before all-reduce
expected_shape = torch.Size([*x.shape[:-1], self.out_features])
if y.shape != expected_shape:
raise TensorShapeMismatchError(
f"Local output shape {y.shape} doesn't match expected shape {expected_shape}"
)
work = dist.all_reduce(y, async_op=True)
work.wait(timedelta(seconds=DIST_TIMEOUT))
except Exception as e:
logger.error(f"All-reduce failed in RowParallelLinear: {e}")
raise DistributedOperationError(f"All-reduce failed: {e}") from e
# Add bias if present
if self.bias is not None:
y += self.bias
return y
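# Pairing sketch (hypothetical sizes): with dim = 1024, inter_dim = 4096 and
# world_size = 4, ColumnParallelLinear(1024, 4096) emits a local shard of
# width 4096 / 4 = 1024, and RowParallelLinear(4096, 1024) consumes exactly
# that shard (part_in_features = 1024); the all-reduce then sums the four
# partial (..., 1024) outputs into the full result. This is the w1/w3 -> w2
# pattern used by MLP below.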
@ -291,6 +486,11 @@ class RMSNorm(nn.Module):
Returns:
torch.Tensor: Normalized tensor with the same shape as input.
"""
if x.size(-1) != self.dim:
raise ValueError(
f"Input feature dimension ({x.size(-1)}) must match "
f"RMSNorm dimension ({self.dim})"
)
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
@ -414,6 +614,10 @@ class MLA(nn.Module):
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
# Validate distributed parameters
validate_distributed_params("n_heads", args.n_heads, world_size)
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.qk_nope_head_dim = args.qk_nope_head_dim
@ -427,10 +631,12 @@ class MLA(nn.Module):
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
@ -443,6 +649,7 @@ class MLA(nn.Module):
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
@safe_dist_operation
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.
@ -458,41 +665,80 @@ class MLA(nn.Module):
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
# Validate input dimensions
if x.size(-1) != self.dim:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"MLA dimension ({self.dim})"
)
# Validate cache indices
cache_len = self.kv_cache.size(1) if attn_impl != "naive" else self.k_cache.size(1)
if start_pos < 0 or end_pos > cache_len:
raise ValueError(
f"Cache indices out of bounds: start_pos={start_pos}, end_pos={end_pos}, "
f"cache_size={cache_len}"
)
# Compute queries
if self.q_lora_rank == 0:
q = self.wq(x)
else:
q = self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
# Compute key-value projections
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
if attn_impl == "naive":
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(self.kv_norm(kv))
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
# Update caches
self.k_cache[:bsz, start_pos:end_pos].copy_(k)
self.v_cache[:bsz, start_pos:end_pos].copy_(v)
# Compute attention scores
scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
else:
wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
# Update caches
self.kv_cache[:bsz, start_pos:end_pos].copy_(self.kv_norm(kv))
self.pe_cache[:bsz, start_pos:end_pos].copy_(k_pe.squeeze(2))
# Compute attention scores
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
# Apply mask if provided
if mask is not None:
if mask.shape != (seqlen, seqlen):
raise TensorShapeMismatchError(
f"Mask shape {mask.shape} does not match expected shape {(seqlen, seqlen)}"
)
scores += mask.unsqueeze(1)
# Apply softmax
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
# Compute output
if attn_impl == "naive":
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
else:
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
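# Shape walk-through for the "absorb" path (read off the einsum strings above),
# with b = batch, s = query length, t = end_pos (cached length), h = n_local_heads,
# c = kv_lora_rank, r = qk_rope_head_dim:
#   q_nope after "bshd,hdc->bshc": (b, s, h, c)
#   kv_cache[:b, :t]: (b, t, c);  pe_cache[:b, :t]: (b, t, r)
#   scores: (b, s, h, t)
#   output before wo: (b, s, h, v_head_dim), flattened to (b, s, h * v_head_dim)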
@ -515,6 +761,9 @@ class MLP(nn.Module):
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.dim = dim
self.inter_dim = inter_dim
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim)
self.w3 = ColumnParallelLinear(dim, inter_dim)
@ -529,6 +778,13 @@ class MLP(nn.Module):
Returns:
torch.Tensor: Output tensor after MLP computation.
"""
# Validate input
if x.size(-1) != self.dim:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"MLP dimension ({self.dim})"
)
return self.w2(F.silu(self.w1(x)) * self.w3(x))
@ -573,27 +829,42 @@ class Gate(nn.Module):
Returns:
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
"""
# Validate input
if x.size(-1) != self.dim:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"gate dimension ({self.dim})"
)
scores = linear(x, self.weight)
if self.score_func == "softmax":
scores = scores.softmax(dim=-1, dtype=torch.float32)
else:
scores = scores.sigmoid()
original_scores = scores
if self.bias is not None:
scores = scores + self.bias
if self.n_groups > 1:
scores = scores.view(x.size(0), self.n_groups, -1)
if self.bias is None:
group_scores = scores.amax(dim=-1)
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
indices = torch.topk(scores, self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights.type_as(x), indices
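# Routing sketch (hypothetical config): with n_routed_experts = 256, n_groups = 8,
# topk_groups = 4 and topk = 8, the 256 scores per token are viewed as (tokens, 8, 32);
# the 4 groups with the highest group score survive, the other 4 are masked to -inf,
# and the final top-8 experts are drawn only from the surviving 128 scores. The mixing
# weights are gathered from original_scores, so the bias term only influences which
# experts are selected, not how their outputs are weighted.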
@ -616,6 +887,9 @@ class Expert(nn.Module):
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.dim = dim
self.inter_dim = inter_dim
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim)
@ -630,6 +904,13 @@ class Expert(nn.Module):
Returns:
torch.Tensor: Output tensor after expert computation.
"""
# Validate input
if x.size(-1) != self.dim:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"expert dimension ({self.dim})"
)
return self.w2(F.silu(self.w1(x)) * self.w3(x))
@ -655,17 +936,35 @@ class MoE(nn.Module):
"""
super().__init__()
self.dim = args.dim
# Validate distributed parameters
validate_distributed_params("n_routed_experts", args.n_routed_experts, world_size)
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
# Validate expert indices
if self.experts_start_idx >= args.n_routed_experts or self.experts_end_idx > args.n_routed_experts:
raise RankMismatchError(
f"Calculated expert indices out of bounds for rank {rank}. "
f"Start: {self.experts_start_idx}, End: {self.experts_end_idx}, "
f"Total Experts: {args.n_routed_experts}"
)
self.gate = Gate(args)
# Create expert modules only for local experts
self.experts = nn.ModuleList([
Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)
])
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
@safe_dist_operation
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
@ -676,20 +975,57 @@ class MoE(nn.Module):
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
# Validate input
if x.size(-1) != self.dim:
raise TensorShapeMismatchError(
f"Input feature dimension ({x.size(-1)}) does not match "
f"MoE dimension ({self.dim})"
)
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x)
# Count how many tokens are routed to each expert
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
# Process local experts
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
if expert is None:
# This shouldn't happen if indices are correct
continue
# Get tokens assigned to this expert
idx, top = torch.where(indices == i)
# Apply expert to assigned tokens
if len(idx) > 0:
y[idx] += expert(x[idx]) * weights[idx, top, None]
# Apply shared experts
z = self.shared_experts(x)
# All-reduce across all processes to combine expert outputs
if world_size > 1:
try:
# Combine the partial expert outputs computed on each rank
work = dist.all_reduce(y, async_op=True)
work.wait(timedelta(seconds=DIST_TIMEOUT))
except Exception as e:
logger.error(f"All-reduce failed in MoE: {e}")
raise DistributedOperationError(f"All-reduce failed: {e}") from e
return (y + z).view(shape)
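# Partitioning sketch (hypothetical sizes): with n_routed_experts = 64 and
# world_size = 4, each rank instantiates n_local_experts = 16 experts; rank 2
# owns indices [32, 48) and leaves the other entries of self.experts as None.
# Each rank only accumulates outputs for tokens routed to its own experts, and
# the all-reduce sums these partial results before the shared-expert path z is added.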
@ -712,6 +1048,7 @@ class Block(nn.Module):
args (ModelArgs): Model arguments containing block parameters.
"""
super().__init__()
self.layer_id = layer_id
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
self.attn_norm = RMSNorm(args.dim)
@ -730,8 +1067,12 @@ class Block(nn.Module):
Returns:
torch.Tensor: Output tensor after block computation.
"""
# Attention layer with residual connection
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
# Feed-forward layer with residual connection
x = x + self.ffn(self.ffn_norm(x))
return x
@ -755,21 +1096,50 @@ class Transformer(nn.Module):
args (ModelArgs): Model arguments containing transformer parameters.
"""
global world_size, rank
try:
# Initialize distributed environment
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
logger.info(f"Distributed environment initialized: world_size={world_size}, rank={rank}")
else:
world_size = 1
rank = 0
logger.info("Running in single-GPU mode")
# Validate distributed parameters
if world_size > 1:
validate_distributed_params("vocab_size", args.vocab_size, world_size)
validate_distributed_params("n_heads", args.n_heads, world_size)
validate_distributed_params("n_routed_experts", args.n_routed_experts, world_size)
validate_distributed_params("args.dim for parallel linear", args.dim, world_size)
except Exception as e:
logger.error(f"Failed to initialize distributed environment: {e}")
raise DistributedInitializationError(f"Distributed initialization failed: {e}") from e
# Set data types
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
Linear.scale_fmt = args.scale_fmt
super().__init__()
self.max_seq_len = args.max_seq_len
self.max_batch_size = args.max_batch_size
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
self.layers = torch.nn.ModuleList()
# Create transformer layers
self.layers = torch.nn.ModuleList([
Block(layer_id, args) for layer_id in range(args.n_layers)
])
self.norm = RMSNorm(args.dim)
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
# Precompute rotary embeddings
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
@torch.inference_mode()
@safe_dist_operation
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
"""
Forward pass for the Transformer model.
@ -781,28 +1151,120 @@ class Transformer(nn.Module):
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
# Validate input
if tokens.dim() != 2:
raise ValueError(f"Tokens must be 2D tensor, got shape {tokens.shape}")
if tokens.numel() == 0:
raise ValueError("Tokens tensor is empty")
bsz, seqlen = tokens.size()
if bsz > self.max_batch_size:
raise ValueError(
f"Batch size {bsz} exceeds maximum batch size {self.max_batch_size}"
)
if start_pos + seqlen > self.max_seq_len:
raise ValueError(
f"Sequence length {seqlen} at position {start_pos} exceeds "
f"maximum sequence length {self.max_seq_len}"
)
# Get embeddings
h = self.embed(tokens)
# Get rotary embeddings for this position
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
# Create attention mask if needed
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
# Apply transformer layers
for layer_idx, layer in enumerate(self.layers):
try:
h = layer(h, start_pos, freqs_cis, mask)
except Exception as e:
logger.error(f"Error in transformer layer {layer_idx}: {e}")
raise DistributedOperationError(f"Layer {layer_idx} failed: {e}") from e
# Final normalization and projection
h = self.norm(h)[:, -1]
logits = self.head(h)
# Gather logits from all processes in distributed mode
if world_size > 1:
try:
# Gather every rank's logits shard and concatenate along the vocabulary dimension
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
work = dist.all_gather(all_logits, logits, async_op=True)
work.wait(timedelta(seconds=DIST_TIMEOUT))
logits = torch.cat(all_logits, dim=-1)
except Exception as e:
logger.error(f"All-gather failed in Transformer forward: {e}")
raise DistributedOperationError(f"All-gather failed: {e}") from e
return logits
def setup_distributed():
"""
Setup distributed environment if needed.
Returns:
tuple: (world_size, rank) if distributed is initialized
"""
global world_size, rank
try:
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
return world_size, rank
else:
world_size = 1
rank = 0
return world_size, rank
except Exception as e:
logger.error(f"Failed to setup distributed environment: {e}")
raise
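# Initialization sketch: DIST_TIMEOUT is intended as the collective timeout.
# Assuming the usual torchrun environment variables are set, the process group
# could be created before model construction with, e.g.:
#
#     dist.init_process_group(backend="nccl", timeout=timedelta(seconds=DIST_TIMEOUT))
#     setup_distributed()
#
# setup_distributed() then simply caches world_size and rank in the module globals.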
if __name__ == "__main__":
# Setup logging
logging.basicConfig(level=logging.INFO)
# Setup distributed environment
setup_distributed()
# Run test
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
args = ModelArgs()
# Test with small batch to catch errors early
test_batch_size = min(2, args.max_batch_size)
test_seq_len = min(32, args.max_seq_len)
x = torch.randint(0, args.vocab_size, (test_batch_size, test_seq_len))
try:
model = Transformer(args)
output = model(x)
print(f"Success! Output size: {output.size()}")
except Exception as e:
logger.error(f"Test failed: {e}")
raise
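# Launch sketch (hypothetical file name): the multi-GPU path assumes one process
# per GPU, e.g.
#
#     torchrun --nproc_per_node=8 model.py
#
# with dist.init_process_group() called before constructing Transformer, as
# required by safe_dist_operation when world_size > 1.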