From dd56d2cab82de2cc59a2ac6ad02dcf20c09cc3fe Mon Sep 17 00:00:00 2001 From: Avi Sinha Date: Mon, 15 Dec 2025 10:20:59 +0530 Subject: [PATCH] Update model.py --- inference/model.py | 548 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 505 insertions(+), 43 deletions(-) diff --git a/inference/model.py b/inference/model.py index 8868499..326a113 100644 --- a/inference/model.py +++ b/inference/model.py @@ -1,4 +1,5 @@ import math +import logging from dataclasses import dataclass from typing import Tuple, Optional, Literal @@ -9,13 +10,78 @@ import torch.distributed as dist from kernel import act_quant, weight_dequant, fp8_gemm +# Configure logging +logger = logging.getLogger(__name__) +# Global configuration world_size = 1 rank = 0 block_size = 128 gemm_impl: Literal["bf16", "fp8"] = "bf16" attn_impl: Literal["naive", "absorb"] = "absorb" +# Distributed operation timeout (seconds) +DIST_TIMEOUT = 60.0 + +# Custom exceptions for distributed operations +class DistributedOperationError(Exception): + """Base exception for distributed operation errors""" + pass + +class TensorShapeMismatchError(DistributedOperationError): + """Raised when tensor shapes don't match in distributed operations""" + pass + +class WorldSizeMismatchError(DistributedOperationError): + """Raised when world size doesn't match expectations""" + pass + +class RankMismatchError(DistributedOperationError): + """Raised when rank doesn't match expectations""" + pass + +class DistributedInitializationError(DistributedOperationError): + """Raised when distributed initialization fails""" + pass + + +def safe_dist_operation(func): + """ + Decorator for safe distributed operations with error handling. + + Args: + func: Function to wrap with error handling + + Returns: + Wrapped function with error handling + """ + def wrapper(*args, **kwargs): + try: + # Check if distributed environment is properly initialized + if world_size > 1: + if not dist.is_initialized(): + raise DistributedInitializationError( + "Distributed environment not initialized but world_size > 1. " + "Call dist.init_process_group() before using distributed operations." + ) + if rank >= world_size: + raise RankMismatchError( + f"Rank {rank} is out of bounds for world_size {world_size}" + ) + + return func(*args, **kwargs) + except dist.DistBackendError as e: + logger.error(f"Distributed backend error in {func.__name__}: {e}") + raise DistributedOperationError(f"Backend error in {func.__name__}: {e}") from e + except RuntimeError as e: + logger.error(f"Runtime error in distributed operation {func.__name__}: {e}") + raise DistributedOperationError(f"Runtime error in {func.__name__}: {e}") from e + except Exception as e: + logger.error(f"Unexpected error in {func.__name__}: {e}") + raise + return wrapper + + @dataclass class ModelArgs: """ @@ -86,6 +152,25 @@ class ModelArgs: mscale: float = 1. +def validate_distributed_params(param_name: str, param_value: int, divisor: int): + """ + Validate that a parameter is divisible by world_size. + + Args: + param_name: Name of the parameter for error messages + param_value: Value of the parameter to validate + divisor: Expected divisor (usually world_size) + + Raises: + WorldSizeMismatchError: If parameter is not divisible by divisor + """ + if divisor > 1 and param_value % divisor != 0: + raise WorldSizeMismatchError( + f"{param_name} ({param_value}) must be divisible by world_size ({divisor}) " + f"for distributed training. 
Remainder: {param_value % divisor}" + ) + + class ParallelEmbedding(nn.Module): """ Embedding layer with parallelism support across distributed processes. @@ -98,12 +183,25 @@ class ParallelEmbedding(nn.Module): super().__init__() self.vocab_size = vocab_size self.dim = dim - assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" + + # Validate distributed parameters + validate_distributed_params("vocab_size", vocab_size, world_size) + self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + + # Validate indices + if self.vocab_start_idx >= vocab_size or self.vocab_end_idx > vocab_size: + raise RankMismatchError( + f"Calculated vocabulary indices out of bounds for rank {rank}. " + f"Start: {self.vocab_start_idx}, End: {self.vocab_end_idx}, " + f"Vocab Size: {vocab_size}" + ) + self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) + @safe_dist_operation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for parallel embedding layer. @@ -115,20 +213,48 @@ class ParallelEmbedding(nn.Module): torch.Tensor: Embedded representations. Raises: - ValueError: If `world_size` is not defined. + ValueError: If tensor contains indices outside valid range + DistributedOperationError: If distributed operation fails """ if world_size > 1: + # Check input tensor validity + if x.numel() == 0: + raise ValueError("Input tensor is empty") + + # Create mask for indices outside our partition mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) + + # Shift indices for local embedding lookup x = x - self.vocab_start_idx - x[mask] = 0 - y = F.embedding(x, self.weight) - if world_size > 1: + + # Handle out-of-bounds indices + out_of_bounds = (x < 0) | (x >= self.part_vocab_size) + if out_of_bounds.any(): + # Zero out invalid indices + x[out_of_bounds] = 0 + mask = mask | out_of_bounds + + # Perform embedding + y = F.embedding(x, self.weight) + + # Zero out masked positions y[mask] = 0 - dist.all_reduce(y) + + # All-reduce across all processes + try: + dist.all_reduce(y, timeout=DIST_TIMEOUT) + except Exception as e: + logger.error(f"All-reduce failed in ParallelEmbedding: {e}") + raise DistributedOperationError(f"All-reduce failed: {e}") from e + else: + # Single GPU case - use standard embedding + y = F.embedding(x, self.weight) + return y -def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor: +def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, + scale_fmt: Optional[str] = None) -> torch.Tensor: """ Applies a linear transformation to the incoming data: y = xA^T + b. This function supports specialized implementations based on quantization @@ -139,23 +265,34 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = weight (torch.Tensor): The weight tensor. It may be quantized and requires dequantization for certain cases. bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None. + scale_fmt (Optional[str]): Scale format for quantization. Returns: - torch.Tensor: The result of the linear transformation, which may involve - quantization-aware computations depending on the input parameters. + torch.Tensor: The result of the linear transformation. 
- Notes: - - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version - is used for computation. - - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied. - - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. + Raises: + ValueError: If tensor shapes are incompatible + DistributedOperationError: If distributed operation fails """ + # Validate input dimensions + if x.dim() != 2: + raise ValueError(f"Input tensor must be 2D, got shape {x.shape}") + + if x.size(-1) != weight.size(-1): + raise ValueError( + f"Input feature dimension ({x.size(-1)}) must match " + f"weight feature dimension ({weight.size(-1)})" + ) + if weight.element_size() > 1: + # Full precision weights return F.linear(x, weight, bias) elif gemm_impl == "bf16": + # BF16 implementation with dequantization weight = weight_dequant(weight, weight.scale) return F.linear(x, weight, bias) else: + # FP8 implementation x, scale = act_quant(x, block_size, scale_fmt) y = fp8_gemm(x, scale, weight, weight.scale) if bias is not None: @@ -180,13 +317,23 @@ class Linear(nn.Module): super().__init__() self.in_features = in_features self.out_features = out_features + + # Validate parameters + if in_features <= 0 or out_features <= 0: + raise ValueError(f"Invalid dimensions: in_features={in_features}, out_features={out_features}") + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + if self.weight.element_size() == 1: + # Quantized weights - create scale parameters scale_out_features = (out_features + block_size - 1) // block_size scale_in_features = (in_features + block_size - 1) // block_size - self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) + self.weight.scale = self.scale = nn.Parameter( + torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) + ) else: self.register_parameter("scale", None) + if bias: self.bias = nn.Parameter(torch.empty(out_features)) else: @@ -216,10 +363,16 @@ class ColumnParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" + # Validate distributed parameters + validate_distributed_params("out_features", out_features, world_size) + self.part_out_features = out_features // world_size super().__init__(in_features, self.part_out_features, bias, dtype) + + # Store full dimensions for validation + self.full_out_features = out_features + @safe_dist_operation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for column parallel linear layer. @@ -230,6 +383,13 @@ class ColumnParallelLinear(Linear): Returns: torch.Tensor: Transformed tensor with column-parallel computation. """ + # Validate input + if x.size(-1) != self.in_features: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"layer input dimension ({self.in_features})" + ) + y = linear(x, self.weight, self.bias) return y @@ -245,10 +405,16 @@ class RowParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. 
""" def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" + # Validate distributed parameters + validate_distributed_params("in_features", in_features, world_size) + self.part_in_features = in_features // world_size super().__init__(self.part_in_features, out_features, bias, dtype) + + # Store full dimensions for validation + self.full_in_features = in_features + @safe_dist_operation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for row parallel linear layer. @@ -259,11 +425,40 @@ class RowParallelLinear(Linear): Returns: torch.Tensor: Transformed tensor with row-parallel computation. """ - y = linear(x, self.weight) + # Validate input + if x.size(-1) != self.full_in_features: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"full layer input dimension ({self.full_in_features})" + ) + + # Split input along feature dimension for this partition + start_idx = rank * self.part_in_features + end_idx = start_idx + self.part_in_features + x_part = x[..., start_idx:end_idx] + + # Apply linear transformation to local partition + y = linear(x_part, self.weight) + + # All-reduce across all processes if world_size > 1: - dist.all_reduce(y) + try: + # Verify tensor shape consistency before all-reduce + expected_shape = torch.Size([*x.shape[:-1], self.out_features]) + if y.shape != expected_shape: + raise TensorShapeMismatchError( + f"Local output shape {y.shape} doesn't match expected shape {expected_shape}" + ) + + dist.all_reduce(y, timeout=DIST_TIMEOUT) + except Exception as e: + logger.error(f"All-reduce failed in RowParallelLinear: {e}") + raise DistributedOperationError(f"All-reduce failed: {e}") from e + + # Add bias if present if self.bias is not None: y += self.bias + return y @@ -291,6 +486,11 @@ class RMSNorm(nn.Module): Returns: torch.Tensor: Normalized tensor with the same shape as input. 
""" + if x.size(-1) != self.dim: + raise ValueError( + f"Input feature dimension ({x.size(-1)}) must match " + f"RMSNorm dimension ({self.dim})" + ) return F.rms_norm(x, (self.dim,), self.weight, self.eps) @@ -414,6 +614,10 @@ class MLA(nn.Module): self.dim = args.dim self.n_heads = args.n_heads self.n_local_heads = args.n_heads // world_size + + # Validate distributed parameters + validate_distributed_params("n_heads", args.n_heads, world_size) + self.q_lora_rank = args.q_lora_rank self.kv_lora_rank = args.kv_lora_rank self.qk_nope_head_dim = args.qk_nope_head_dim @@ -427,10 +631,12 @@ class MLA(nn.Module): self.wq_a = Linear(self.dim, self.q_lora_rank) self.q_norm = RMSNorm(self.q_lora_rank) self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim) + self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) self.kv_norm = RMSNorm(self.kv_lora_rank) self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) + self.softmax_scale = self.qk_head_dim ** -0.5 if args.max_seq_len > args.original_seq_len: mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 @@ -443,6 +649,7 @@ class MLA(nn.Module): self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False) self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False) + @safe_dist_operation def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. @@ -458,41 +665,80 @@ class MLA(nn.Module): """ bsz, seqlen, _ = x.size() end_pos = start_pos + seqlen + + # Validate input dimensions + if x.size(-1) != self.dim: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"MLA dimension ({self.dim})" + ) + + # Validate cache indices + if start_pos < 0 or end_pos > self.kv_cache.size(1) if attn_impl != "naive" else self.k_cache.size(1): + raise ValueError( + f"Cache indices out of bounds: start_pos={start_pos}, end_pos={end_pos}, " + f"cache_size={self.kv_cache.size(1) if attn_impl != 'naive' else self.k_cache.size(1)}" + ) + + # Compute queries if self.q_lora_rank == 0: q = self.wq(x) else: q = self.wq_b(self.q_norm(self.wq_a(x))) + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_pe = apply_rotary_emb(q_pe, freqs_cis) + + # Compute key-value projections kv = self.wkv_a(x) kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + if attn_impl == "naive": q = torch.cat([q_nope, q_pe], dim=-1) kv = self.wkv_b(self.kv_norm(kv)) kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) - self.k_cache[:bsz, start_pos:end_pos] = k - self.v_cache[:bsz, start_pos:end_pos] = v + + # Update caches + self.k_cache[:bsz, start_pos:end_pos].copy_(k) + self.v_cache[:bsz, start_pos:end_pos].copy_(v) + + # Compute attention scores scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale else: wkv_b = self.wkv_b.weight if 
self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) - self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) - self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + + # Update caches + self.kv_cache[:bsz, start_pos:end_pos].copy_(self.kv_norm(kv)) + self.pe_cache[:bsz, start_pos:end_pos].copy_(k_pe.squeeze(2)) + + # Compute attention scores scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale + + # Apply mask if provided if mask is not None: + if mask.shape != (seqlen, seqlen): + raise TensorShapeMismatchError( + f"Mask shape {mask.shape} does not match expected shape {(seqlen, seqlen)}" + ) scores += mask.unsqueeze(1) + + # Apply softmax scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) + + # Compute output if attn_impl == "naive": x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) else: x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) + x = self.wo(x.flatten(2)) return x @@ -515,6 +761,9 @@ class MLP(nn.Module): inter_dim (int): Hidden layer dimensionality. """ super().__init__() + self.dim = dim + self.inter_dim = inter_dim + self.w1 = ColumnParallelLinear(dim, inter_dim) self.w2 = RowParallelLinear(inter_dim, dim) self.w3 = ColumnParallelLinear(dim, inter_dim) @@ -529,6 +778,13 @@ class MLP(nn.Module): Returns: torch.Tensor: Output tensor after MLP computation. """ + # Validate input + if x.size(-1) != self.dim: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"MLP dimension ({self.dim})" + ) + return self.w2(F.silu(self.w1(x)) * self.w3(x)) @@ -573,27 +829,42 @@ class Gate(nn.Module): Returns: Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices. """ + # Validate input + if x.size(-1) != self.dim: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"gate dimension ({self.dim})" + ) + scores = linear(x, self.weight) + if self.score_func == "softmax": scores = scores.softmax(dim=-1, dtype=torch.float32) else: scores = scores.sigmoid() + original_scores = scores + if self.bias is not None: scores = scores + self.bias + if self.n_groups > 1: scores = scores.view(x.size(0), self.n_groups, -1) if self.bias is None: group_scores = scores.amax(dim=-1) else: group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) + indices = group_scores.topk(self.topk_groups, dim=-1)[1] mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) + indices = torch.topk(scores, self.topk, dim=-1)[1] weights = original_scores.gather(1, indices) + if self.score_func == "sigmoid": weights /= weights.sum(dim=-1, keepdim=True) + weights *= self.route_scale return weights.type_as(x), indices @@ -616,6 +887,9 @@ class Expert(nn.Module): inter_dim (int): Hidden layer dimensionality. """ super().__init__() + self.dim = dim + self.inter_dim = inter_dim + self.w1 = Linear(dim, inter_dim) self.w2 = Linear(inter_dim, dim) self.w3 = Linear(dim, inter_dim) @@ -630,6 +904,13 @@ class Expert(nn.Module): Returns: torch.Tensor: Output tensor after expert computation. 
""" + # Validate input + if x.size(-1) != self.dim: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"expert dimension ({self.dim})" + ) + return self.w2(F.silu(self.w1(x)) * self.w3(x)) @@ -655,17 +936,35 @@ class MoE(nn.Module): """ super().__init__() self.dim = args.dim - assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" + + # Validate distributed parameters + validate_distributed_params("n_routed_experts", args.n_routed_experts, world_size) + self.n_routed_experts = args.n_routed_experts self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts self.experts_start_idx = rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts + + # Validate expert indices + if self.experts_start_idx >= args.n_routed_experts or self.experts_end_idx > args.n_routed_experts: + raise RankMismatchError( + f"Calculated expert indices out of bounds for rank {rank}. " + f"Start: {self.experts_start_idx}, End: {self.experts_end_idx}, " + f"Total Experts: {args.n_routed_experts}" + ) + self.gate = Gate(args) - self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None - for i in range(self.n_routed_experts)]) + + # Create expert modules only for local experts + self.experts = nn.ModuleList([ + Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None + for i in range(self.n_routed_experts) + ]) + self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim) + @safe_dist_operation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the MoE module. @@ -676,20 +975,57 @@ class MoE(nn.Module): Returns: torch.Tensor: Output tensor after expert routing and computation. """ + # Validate input + if x.size(-1) != self.dim: + raise TensorShapeMismatchError( + f"Input feature dimension ({x.size(-1)}) does not match " + f"MoE dimension ({self.dim})" + ) + shape = x.size() x = x.view(-1, self.dim) + weights, indices = self.gate(x) y = torch.zeros_like(x) + + # Count how many tokens are routed to each expert counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() + + # Process local experts for i in range(self.experts_start_idx, self.experts_end_idx): if counts[i] == 0: continue + expert = self.experts[i] + if expert is None: + # This shouldn't happen if indices are correct + continue + + # Get tokens assigned to this expert idx, top = torch.where(indices == i) - y[idx] += expert(x[idx]) * weights[idx, top, None] + + # Apply expert to assigned tokens + if len(idx) > 0: + y[idx] += expert(x[idx]) * weights[idx, top, None] + + # Apply shared experts z = self.shared_experts(x) + + # All-reduce across all processes to combine expert outputs if world_size > 1: - dist.all_reduce(y) + try: + # Verify tensor shape before all-reduce + expected_shape = y.shape + if y.shape != expected_shape: + raise TensorShapeMismatchError( + f"Expert output shape {y.shape} doesn't match expected shape {expected_shape}" + ) + + dist.all_reduce(y, timeout=DIST_TIMEOUT) + except Exception as e: + logger.error(f"All-reduce failed in MoE: {e}") + raise DistributedOperationError(f"All-reduce failed: {e}") from e + return (y + z).view(shape) @@ -712,6 +1048,7 @@ class Block(nn.Module): args (ModelArgs): Model arguments containing block parameters. 
""" super().__init__() + self.layer_id = layer_id self.attn = MLA(args) self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) self.attn_norm = RMSNorm(args.dim) @@ -730,8 +1067,12 @@ class Block(nn.Module): Returns: torch.Tensor: Output tensor after block computation. """ + # Attention layer with residual connection x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask) + + # Feed-forward layer with residual connection x = x + self.ffn(self.ffn_norm(x)) + return x @@ -755,21 +1096,50 @@ class Transformer(nn.Module): args (ModelArgs): Model arguments containing transformer parameters. """ global world_size, rank - world_size = dist.get_world_size() if dist.is_initialized() else 1 - rank = dist.get_rank() if dist.is_initialized() else 0 + + try: + # Initialize distributed environment + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + logger.info(f"Distributed environment initialized: world_size={world_size}, rank={rank}") + else: + world_size = 1 + rank = 0 + logger.info("Running in single-GPU mode") + + # Validate distributed parameters + if world_size > 1: + validate_distributed_params("vocab_size", args.vocab_size, world_size) + validate_distributed_params("n_heads", args.n_heads, world_size) + validate_distributed_params("n_routed_experts", args.n_routed_experts, world_size) + validate_distributed_params("args.dim for parallel linear", args.dim, world_size) + + except Exception as e: + logger.error(f"Failed to initialize distributed environment: {e}") + raise DistributedInitializationError(f"Distributed initialization failed: {e}") from e + + # Set data types Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 Linear.scale_fmt = args.scale_fmt + super().__init__() self.max_seq_len = args.max_seq_len self.embed = ParallelEmbedding(args.vocab_size, args.dim) - self.layers = torch.nn.ModuleList() - for layer_id in range(args.n_layers): - self.layers.append(Block(layer_id, args)) + + # Create transformer layers + self.layers = torch.nn.ModuleList([ + Block(layer_id, args) for layer_id in range(args.n_layers) + ]) + self.norm = RMSNorm(args.dim) self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype()) + + # Precompute rotary embeddings self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) @torch.inference_mode() + @safe_dist_operation def forward(self, tokens: torch.Tensor, start_pos: int = 0): """ Forward pass for the Transformer model. @@ -781,28 +1151,120 @@ class Transformer(nn.Module): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
""" - seqlen = tokens.size(1) + # Validate input + if tokens.dim() != 2: + raise ValueError(f"Tokens must be 2D tensor, got shape {tokens.shape}") + + if tokens.numel() == 0: + raise ValueError("Tokens tensor is empty") + + bsz, seqlen = tokens.size() + + if bsz > self.freqs_cis.size(0): + raise ValueError( + f"Batch size {bsz} exceeds maximum batch size {self.freqs_cis.size(0)}" + ) + + if start_pos + seqlen > self.max_seq_len: + raise ValueError( + f"Sequence length {seqlen} at position {start_pos} exceeds " + f"maximum sequence length {self.max_seq_len}" + ) + + # Get embeddings h = self.embed(tokens) + + # Get rotary embeddings for this position freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] + + # Create attention mask if needed mask = None if seqlen > 1: mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) + + # Apply transformer layers + for layer_idx, layer in enumerate(self.layers): + try: + h = layer(h, start_pos, freqs_cis, mask) + except Exception as e: + logger.error(f"Error in transformer layer {layer_idx}: {e}") + raise DistributedOperationError(f"Layer {layer_idx} failed: {e}") from e + + # Final normalization and projection h = self.norm(h)[:, -1] logits = self.head(h) + + # Gather logits from all processes in distributed mode if world_size > 1: - all_logits = [torch.empty_like(logits) for _ in range(world_size)] - dist.all_gather(all_logits, logits) - logits = torch.cat(all_logits, dim=-1) + try: + # Verify tensor shapes before all-gather + all_logits = [torch.empty_like(logits) for _ in range(world_size)] + + # Check that all tensors have same shape + for i in range(world_size): + if i == rank: + continue + all_logits[i] = torch.empty_like(logits) + + dist.all_gather(all_logits, logits, timeout=DIST_TIMEOUT) + + # Concatenate along vocabulary dimension + logits = torch.cat(all_logits, dim=-1) + + except Exception as e: + logger.error(f"All-gather failed in Transformer forward: {e}") + raise DistributedOperationError(f"All-gather failed: {e}") from e + return logits +def setup_distributed(): + """ + Setup distributed environment if needed. + + Returns: + tuple: (world_size, rank) if distributed is initialized + """ + global world_size, rank + + try: + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + return world_size, rank + else: + world_size = 1 + rank = 0 + return world_size, rank + except Exception as e: + logger.error(f"Failed to setup distributed environment: {e}") + raise + + if __name__ == "__main__": + # Setup logging + logging.basicConfig(level=logging.INFO) + + # Setup distributed environment + setup_distributed() + + # Run test torch.set_default_dtype(torch.bfloat16) - torch.set_default_device("cuda") + torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu") torch.manual_seed(0) + args = ModelArgs() - x = torch.randint(0, args.vocab_size, (2, 128)) - model = Transformer(args) - print(model(x).size()) + + # Test with small batch to catch errors early + test_batch_size = min(2, args.max_batch_size) + test_seq_len = min(32, args.max_seq_len) + + x = torch.randint(0, args.vocab_size, (test_batch_size, test_seq_len)) + + try: + model = Transformer(args) + output = model(x) + print(f"Success! Output size: {output.size()}") + except Exception as e: + logger.error(f"Test failed: {e}") + raise