diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 254cd2e10b8fb..e16fc13c945cf 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -874,6 +874,14 @@ class FusedMoE(torch.nn.Module):
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
+    def _load_w13_weight_scale(self, shard_dim: int,
+                               loaded_weight: torch.Tensor,
+                               param: torch.Tensor, tp_rank: int):
+        shard_size = param.shape[shard_dim]
+        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
+                                             shard_size)
+        param.copy_(loaded_weight)
+
     def _load_model_weight_or_group_weight_scale(self,
                                                  shard_dim: int,
                                                  expert_data: torch.Tensor,
@@ -1123,7 +1131,12 @@ class FusedMoE(torch.nn.Module):
                 "weight_scale_2" in weight_name if uses_weight_scale_2 else
                 "weight_scale" in weight_name) or "input_scale" in weight_name
 
-            if per_tensor_conditions:
+            if "w13_weight_scale" in weight_name:
+                self._load_w13_weight_scale(shard_dim=shard_dim,
+                                            loaded_weight=loaded_weight,
+                                            param=param,
+                                            tp_rank=self.tp_rank)
+            elif per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 38866586ae29e..8fbc3231d86c3 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -778,8 +778,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         # Swizzle the weight blockscale.
         # contracting dimension is input dimension
         # block_size = 16;
-        assert (layer.weight_scale.shape[1] % 16 == 0), (
-            "Expected weight_scale.dim(1) to be divisible by 16")
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
         swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index fab1c163ac288..470e701d98013 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -342,34 +342,94 @@ class Llama4Model(LlamaModel):
         expert_params_mapping: list[tuple[str, str, int, str]],
         fused: bool = True,
     ) -> bool:
+        """
+        Load MoE expert weights.
+
+        Args:
+            name: The name of the weight to load.
+            loaded_weight: The weight to load.
+            params_dict: The dictionary of module parameters.
+            loaded_params: The set of already loaded parameters.
+            expert_params_mapping: The mapping of expert parameters. Must be
+                generated by FusedMoE.make_expert_params_mapping().
+            fused: Whether the expert weights are fused into a single weight
+                tensor or are separate weight tensors for each expert.
+                When fused is True, loaded_weight should have a shape of
+                [num_experts, hidden_in, hidden_out] for gate/up/down proj and
+                [hidden_out, hidden_in] for the others, such as the router.
+                When fused is False, loaded_weight should have a shape of
+                [hidden_out, hidden_in].
+
+        Returns:
+            True if loaded_weight is one of the MoE weights and the expert
+            weights were loaded successfully, False otherwise.
+        """
+
+        # Whether the MoE expert weights are loaded successfully.
         expert_param_loaded = False
-        if "experts.gate_up_proj" in name:
-            loaded_weight = loaded_weight.chunk(2, dim=-1)
+
+        # If fused is True, the loaded weight has the layout
+        # [num_experts, hidden_in, hidden_out], so we must transpose the last
+        # two dimensions to match the expected layout of the parameters.
+        if fused and loaded_weight.ndim == 3:
+            loaded_weight = loaded_weight.transpose(-1, -2)
+
+        # If the gate_proj and up_proj weights are fused into a single
+        # weight tensor, we need to split the weight tensor into a tuple
+        # of two weight tensors along the hidden_out dimension.
+        if "experts.gate_up_proj" in name:
+            loaded_weight = loaded_weight.chunk(2, dim=-2)
+
+        # Iterate over all the expert parameters and load the weight if we
+        # find a match in the weight name.
         for (param_name, weight_name, expert_id,
              shard_id) in expert_params_mapping:
+
+            # Use a separate reference so that reassignments below do not
+            # modify loaded_weight across iterations.
             new_loaded_weight = loaded_weight
+
+            # If expert weights are fused into a single weight tensor, remove
+            # the expert index from the expected weight name.
             if fused:
+                # The component between e_str and proj_str is the expert index.
                 e_str, _, proj_str, _ = weight_name.split('.')
                 weight_name = f"{e_str}.{proj_str}"
                 param_name = f"{param_name}weight"
+
+            # Skip if the current weight is not one of the MoE weights.
             if weight_name not in name:
                 continue
+
+            # Replace the weight name with the parameter name.
             full_param_name = name.replace(weight_name, param_name)
-            # Skip layers on other devices.
+
+            # Skip if the current weight corresponds to a parameter that
+            # does not exist on the current PP (pipeline parallel) rank.
             if is_pp_missing_parameter(name, self):
                 continue
+
+            # Skip if the current weight is for the bias.
             if ((name.endswith(".bias") or name.endswith("_bias"))
                     and name not in params_dict):
                 continue
+
             param = params_dict[full_param_name]
             weight_loader = param.weight_loader
+
             if fused:
+                # If the parameter is the fused w13 parameter, the loaded
+                # weight is a tuple, so we must select the correct chunk
+                # depending on the shard id, which is either "w1" or "w3".
                 if "w13" in full_param_name:
+                    assert shard_id in ["w1", "w3"]
                     shard_idx = 0 if shard_id == "w1" else 1
                     new_loaded_weight = new_loaded_weight[shard_idx]
-                new_loaded_weight = new_loaded_weight.transpose(-1, -2)
+
+                # If EP (expert parallel) is enabled, update expert_id to the
+                # starting expert index for the current EP rank and extract the
+                # corresponding expert weights.
                 layer_idx = extract_layer_index(name)
-                # EP mapping
                 expert_map = self.layers[
                     layer_idx].feed_forward.experts.expert_map
                 if expert_map is not None:
@@ -382,6 +442,9 @@ class Llama4Model(LlamaModel):
             else:
                 # TODO: add EP support for non fused weights
                 pass
+
+            # Load the weight into the module parameter with the corresponding
+            # shard id and expert id.
             weight_loader(param,
                           new_loaded_weight,
                           full_param_name,
@@ -390,10 +453,13 @@ class Llama4Model(LlamaModel):
 
             loaded_params.add(full_param_name)
             expert_param_loaded = True
+
         return expert_param_loaded
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        # Mapping from the checkpoint weight name to the stacked parameter
+        # name and the corresponding shard id.
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -402,26 +468,43 @@ class Llama4Model(LlamaModel):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+        # Indicates whether the expert weights are fused into a single weight
+        # tensor.
         fused_experts_params = False
+        # Expert parameter mapping for the case where the expert weights are
+        # not fused into a single weight tensor.
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.num_experts)
+        # Expert parameter mapping for the case where the expert weights are
+        # fused into a single weight tensor.
         expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_up_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="gate_up_proj",
             num_experts=1)
+        # All the module parameters.
         params_dict = dict(self.named_parameters())
+        # The module parameters that have been loaded.
         loaded_params: set[str] = set()
+
+        # Iterate over all the weights and load them into module parameters.
         for name, loaded_weight in weights:
+
+            # If the name contains "experts.gate_up_proj" or "experts.down_proj"
+            # without the expert indices, it means the expert weights are fused
+            # into a single weight tensor across all experts.
             if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                 fused_experts_params = True
                 expert_params_mapping = expert_params_mapping_fused
+
+            # If the quant config defines kv cache scales and the weight name
+            # corresponds to one of them, load the scale here and move on to
+            # the next weight.
             if (self.quant_config is not None and
                 (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -430,84 +513,119 @@ class Llama4Model(LlamaModel):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
+
+            # Iterate over stacked_params_mapping to check if the current
+            # weight is one of the stacked parameters. If so, load the weight
+            # with the corresponding shard id. Note that MoE weights are
+            # handled separately in the for-else block below.
             for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip if the current weight is not one of the stacked
+                # parameters or if the current weight is a MoE weight.
                 if weight_name not in name or "experts" in name:
                     continue
-                # This check is for ModelOpt ckpts with kv cache quant enabled
+
+                # For ModelOpt checkpoints, we need to rename the self_attn
+                # weight/weight_scale names, except for kv cache scales.
                 if not (name.endswith(
                     (".k_scale", ".v_scale")) and "self_attn" in name):
                     name = name.replace(weight_name, param_name)
+
+                # Skip if the current weight corresponds to a parameter that
+                # does not exist on the current PP (pipeline parallel) rank.
                 if is_pp_missing_parameter(name, self):
                     continue
-                if name.endswith("scale") and "expert" not in name:
-                    # Remapping the name of FP8 kv-scale.
+
+                # Remap kv cache scale names for ModelOpt checkpoints.
+                # TODO: ModelOpt should implement get_cache_scale() such that
+                # kv cache scale name remapping can be done there.
+                if name.endswith("scale"):
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
                         continue
+
+                # Load the weight into the module parameter with the
+                # corresponding shard id, then break (skipping the for-else).
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
+
                 if weight_loader == default_weight_loader:
                     weight_loader(param, loaded_weight)
                 else:
                     weight_loader(param, loaded_weight, shard_id)
+
                 loaded_params.add(name)
                 break
+
+            # Handle normal (non-stacked) weights and MoE weights.
             else:
-                moe_loaded = self.load_moe_expert_weights(
-                    name,
-                    loaded_weight,
-                    params_dict,
-                    loaded_params,
-                    expert_params_mapping,
-                    fused=fused_experts_params)
+                # First, try to load MoE weights via load_moe_expert_weights;
+                # if that succeeds, move on to the next loaded weight.
+                if self.load_moe_expert_weights(name,
+                                                loaded_weight,
+                                                params_dict,
+                                                loaded_params,
+                                                expert_params_mapping,
+                                                fused=fused_experts_params):
+                    continue
 
-                if not moe_loaded:
-                    if is_pp_missing_parameter(name, self):
-                        continue
+                # Skip if the current weight corresponds to a parameter that
+                # does not exist on the current PP (pipeline parallel) rank.
+                if is_pp_missing_parameter(name, self):
+                    continue
 
-                    # Handle flat expert scale parameters that
-                    # don't match per-expert patterns
-                    if ("experts." in name and ("w13_input_scale" in name
-                                                or "w13_weight_scale" in name
-                                                or "w2_input_scale" in name
-                                                or "w2_weight_scale" in name)):
-                        # These are flat expert scales that apply to all experts
-                        param = params_dict[name]
-                        weight_loader = getattr(param, "weight_loader",
-                                                default_weight_loader)
-
-                        # Check for MoE-specific loading support via
-                        # attribute instead of expensive runtime reflection
-                        supports_moe = getattr(weight_loader,
-                                               'supports_moe_loading', False)
-
-                        if supports_moe:
-                            # This is a MoE weight loader
-                            if "w13_" in name:
-                                shard_id = "w1"
-                            elif "w2_" in name:
-                                shard_id = "w2"
-                            else:
-                                shard_id = "w1"
-
-                            weight_loader(param,
-                                          loaded_weight,
-                                          name,
-                                          shard_id=shard_id,
-                                          expert_id=0)
-                        else:
-                            # Regular weight loader (handles both
-                            # param.weight_loader and default_weight_loader)
-                            weight_loader(param, loaded_weight)
-                        loaded_params.add(name)
-                        continue
+                # Handle flat expert scale parameters that don't match the
+                # per-expert patterns, i.e. a single scale tensor shared by
+                # all experts.
+                scale_names = [
+                    "w13_input_scale", "w13_weight_scale", "w2_input_scale",
+                    "w2_weight_scale"
+                ]
+                if ("experts." in name and any(scale_name in name
+                                               for scale_name in scale_names)):
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
-                    weight_loader(param, loaded_weight)
+
+                    # If the weight loader supports MoE-specific loading, use
+                    # it; the attribute check avoids expensive runtime reflection.
+                    if getattr(weight_loader, 'supports_moe_loading', False):
+                        # Map the weight name to the corresponding shard id.
+                        shard_id = "w2" if "w2_" in name else "w1"
+
+                        # Transpose if the weight scales are FP8 block scales
+                        # with three dimensions:
+                        # [num_experts, hidden_in, hidden_out].
+                        if name.endswith("weight_scale") \
+                                and loaded_weight.dtype == torch.float8_e4m3fn \
+                                and loaded_weight.ndim == 3:
+                            loaded_weight = loaded_weight.transpose(-1, -2)
+
+                        # Load the weight into the module parameter with the
+                        # corresponding shard id and expert id.
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=0)
+
+                    else:
+                        # Regular weight loader (handles both
+                        # param.weight_loader and default_weight_loader)
+                        weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+                    continue
+
+                # Handle normal (non-stacked, non-MoE) weights.
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+
+        # Finally, return the set of loaded parameters.
         return loaded_params
 
@@ -560,23 +678,43 @@ class Llama4ForCausalLM(LlamaForCausalLM):
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        # Helper function to permute the weight's channels.
+        def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
+
+            # Calculate the expected shape of the weight.
+            # Do not rely on w's shape, as it may be in another layout.
             attn_in = self.config.head_dim * n_heads
             attn_out = self.config.hidden_size
+            # If the weight is FP4 packed as uint8, we need to divide attn_out
+            # by 2.
+            if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
+                attn_out = attn_out // 2
+
+            # If the weight is a weight scale, we need to divide attn_out by
+            # the block size, which is currently 16.
+            elif w.dtype == torch.float8_e4m3fn and is_weight_scale \
+                    and w.shape[1] * 16 == attn_out:
+                attn_out = attn_out // 16
+
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
 
         modules = name.split(".")
 
-        # rotary embeds should be sliced
-        if ("wk" in modules or "k_proj" in modules) \
-            and modules[-1] == "weight":
-            loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
-        elif ("wq" in modules or "q_proj" in modules) \
-            and modules[-1] == "weight":
-            loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+        # Permute Q/K weights and weight block scales for rotary embeddings.
+        is_weight = modules[-1] == "weight"
+        is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
+                                 loaded_weight.dtype == torch.float8_e4m3fn)
+
+        if is_weight or is_nvfp4_weight_scale:
+            if ("wk" in modules or "k_proj" in modules):
+                loaded_weight = permute(loaded_weight,
+                                        self.config.num_key_value_heads,
+                                        is_nvfp4_weight_scale)
+            elif ("wq" in modules or "q_proj" in modules):
+                loaded_weight = permute(loaded_weight,
+                                        self.config.num_attention_heads,
+                                        is_nvfp4_weight_scale)
 
         return name, loaded_weight
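For reviewers who want to sanity-check the narrow-and-copy sharding done by the new FusedMoE._load_w13_weight_scale helper outside of vLLM, here is a minimal standalone sketch. It is not part of the patch; the function name shard_w13_weight_scale, the tensor shapes, and the tp_rank value are illustrative assumptions only.

import torch


def shard_w13_weight_scale(full_scale: torch.Tensor, param: torch.Tensor,
                           shard_dim: int, tp_rank: int) -> None:
    # Each tensor-parallel rank keeps a contiguous slice of size
    # param.shape[shard_dim] along the sharded dimension, offset by its rank.
    # This mirrors the narrow() + copy_() logic of _load_w13_weight_scale.
    shard_size = param.shape[shard_dim]
    shard = full_scale.narrow(shard_dim, shard_size * tp_rank, shard_size)
    param.copy_(shard)


# Example with assumed shapes: 4 experts, a fused w1/w3 scale with 16 scale
# groups along the sharded dimension, tensor parallel size 2 -> 8 groups/rank.
full_scale = torch.randn(4, 16, 8)
param = torch.empty(4, 8, 8)
shard_w13_weight_scale(full_scale, param, shard_dim=1, tp_rank=1)
assert torch.equal(param, full_scale[:, 8:16, :])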