diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 2f5d9ddd9054f..cd93f0ef1e310 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -27,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .utils import extract_layer_index, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + maybe_prefix) class OAIAttention(nn.Module): @@ -203,6 +204,7 @@ class GptOssModel(nn.Module): super().__init__() self.config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config + self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, @@ -225,8 +227,364 @@ class GptOssModel(nn.Module): x = self.norm(x) return x + def _load_weights_mxfp4( + self, + ep_rank_end: int, + ep_rank_start: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + mxfp4_block = 32 + use_ep = self.parallel_config.enable_expert_parallel + num_experts = self.config.num_local_experts + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + intermediate_size_block = intermediate_size // mxfp4_block + per_rank_intermediate_size_block = cdiv(intermediate_size_block, + tp_size) + per_rank_intermediate_size = (per_rank_intermediate_size_block * + mxfp4_block) + + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + # FIXME(woosuk): Remove this after testing. + weight = weight.cuda() + + if ".w13_weight_scale" in name: + # Handle MLP gate and up projection weights scale + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight_scale" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., tp_rank_start // + mxfp4_block:tp_rank_end // + mxfp4_block] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_weight" in name: + # Handle MLP gate and up projection weights + # flat weight from (E, 2 * N, block_size, entry_per_block) + # to (E, 2 * N, -1), shouldn't trigger copy for contiguous + weight = weight.view(num_experts, 2 * intermediate_size, + -1).contiguous() + + # Extract gate and up projection parts + # since the weight is shuffled, we can slice directly + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end, + ...] 
+ + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + # same flatten here, but since 2 mx4 value are packed in 1 + # uint8, divide by 2 + weight = weight.view(num_experts, -1, + intermediate_size // 2).contiguous() + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[..., + tp_rank_start // 2:tp_rank_end // 2] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + weight_loader(param, + weight, + weight_name=name, + shard_id=None, + expert_id=None) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def _load_weights_other( + self, + ep_rank_start: int, + ep_rank_end: int, + heads_per_rank: int, + head_start: int, + weights: Iterable[tuple[str, torch.Tensor]], + stacked_params_mapping: list[tuple[str, ...]], + ) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + use_ep = self.parallel_config.enable_expert_parallel + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.intermediate_size + per_rank_intermediate_size = cdiv(intermediate_size, tp_size) + # Calculate common slicing bounds for current rank + tp_rank_start = tp_rank * per_rank_intermediate_size + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, + intermediate_size) + + for name, weight in weights: + if ".w13_weight" in name: + 
# Handle MLP gate and up projection weights + # Extract gate and up projection parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, :, + 2 * tp_rank_start:2 * tp_rank_end] + + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_weight" in name: + # Handle MLP down projection weights + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] + narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + param = params_dict[name] + + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w13_bias" in name: + # Handle MLP gate and up projection biases + # Extract gate and up projection bias parts + if use_ep: + narrow_weight = weight[ep_rank_start:ep_rank_end, ...] + else: + narrow_weight = weight[:, + 2 * tp_rank_start:2 * tp_rank_end] + + param = params_dict[name] + param.copy_(narrow_weight) + loaded_params.add(name) + continue + elif ".w2_bias" in name: + # Handle MLP down projection bias + if use_ep: + weight = weight[ep_rank_start:ep_rank_end, ...] + else: + # (only load on rank 0 to avoid duplication) + if tp_rank != 0: + weight.zero_() + param = params_dict[name] + param.copy_(weight) + loaded_params.add(name) + continue + elif "sinks" in name: + # Handle attention sinks (distributed across ranks) + param = params_dict[name] + narrow_weight = weight.narrow(0, head_start, heads_per_rank) + param.data.copy_(narrow_weight) + loaded_params.add(name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, weight) + else: + weight_loader(param, weight, shard_id) + break + else: + # Handle all other weights with potential renaming + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + loaded_params.add(name) + return loaded_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv", ".q_proj", "q"), + (".qkv", ".k_proj", "k"), + (".qkv", ".v_proj", "v"), + ] + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # Attention heads per rank + heads_per_rank = self.config.num_attention_heads // tp_size + head_start = tp_rank * heads_per_rank + + ep_size = get_ep_group().world_size + ep_rank = get_ep_group().rank + num_experts = self.config.num_local_experts + experts_per_rank = num_experts // ep_size + ep_rank_start = ep_rank * experts_per_rank + ep_rank_end = (ep_rank + 1) * experts_per_rank + + quant_method = (self.config.quantization_config['quant_method'] if + hasattr(self.config, "quantization_config") else None) + if quant_method == "mxfp4": + return self._load_weights_mxfp4(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + else: + return self._load_weights_other(ep_rank_end, ep_rank_start, + heads_per_rank, head_start, + weights, stacked_params_mapping) + class GptOssForCausalLM(nn.Module): + packed_modules_mapping = {"qkv": 
["q_proj", "k_proj", "v_proj"]} + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".self_attn.": ".attn.", + ".post_attention_layernorm.": ".mlp.norm.", + }, + orig_to_new_suffix={ + ".embed_tokens.weight": ".embedding.weight", + ".input_layernorm.weight": ".attn.norm.weight", + ".post_attention_layernorm.weight": ".mlp.norm.weight", + + # MoE MXFP4 weights + ".gate_up_proj_blocks": ".w13_weight", + ".down_proj_blocks": ".w2_weight", + ".gate_up_proj_scales": ".w13_weight_scale", + ".down_proj_scales": ".w2_weight_scale", + + # MoE other weights + ".gate_up_proj": ".w13_weight", + ".down_proj": ".w2_weight", + + # MoE Bias + ".gate_up_proj_bias": ".w13_bias", + ".down_proj_bias": ".w2_bias", + }, + ) def __init__( self, @@ -235,16 +593,17 @@ class GptOssForCausalLM(nn.Module): ): super().__init__() self.vllm_config = vllm_config - self.model_config = vllm_config.model_config.hf_config + self.config = vllm_config.model_config.hf_config + self.model = GptOssModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"), ) self.lm_head = ParallelLMHead( - self.model_config.vocab_size, - self.model_config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, ) - self.logits_processor = LogitsProcessor(self.model_config.vocab_size) + self.logits_processor = LogitsProcessor(self.config.vocab_size) def forward(self, input_ids: torch.Tensor, @@ -261,354 +620,11 @@ class GptOssForCausalLM(nn.Module): sampling_metadata) return logits - def _load_weights_mxfp4( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - "post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - mxfp4_block = 32 - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size - intermediate_size_block = intermediate_size // mxfp4_block - per_rank_intermediate_size_block = cdiv(intermediate_size_block, - tp_size) - per_rank_intermediate_size = (per_rank_intermediate_size_block * - mxfp4_block) - - # Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) - - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - - for name, weight in weights: - # FIXME(woosuk): Remove this after testing. 
- weight = weight.cuda() - - if "gate_up_proj_blocks" in name: - # Handle MLP gate and up projection weights - new_name = name.replace("gate_up_proj_blocks", "w13_weight") - - # flat weight from (E, 2 * N, block_size, entry_per_block) - # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(num_experts, 2 * intermediate_size, - -1).contiguous() - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_blocks" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_blocks", "w2_weight") - # same flatten here, but since 2 mx4 value are packed in 1 - # uint8, divide by 2 - weight = weight.view(num_experts, -1, - intermediate_size // 2).contiguous() - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[..., - tp_rank_start // 2:tp_rank_end // 2] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "gate_up_proj_scales" in name: - # Handle MLP gate and up projection weights scale - new_name = name.replace("gate_up_proj_scales", - "w13_weight_scale") - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_scales" in name: - # Handle MLP down projection weights - new_name = name.replace("down_proj_scales", "w2_weight_scale") - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[..., tp_rank_start // - mxfp4_block:tp_rank_end // - mxfp4_block] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] - - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - param = params_dict[new_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - if use_ep: - weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - weight_loader(param, - weight, - weight_name=new_name, - shard_id=None, - expert_id=None) - loaded_params.add(new_name) - elif "sinks" in name: - # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") - param = params_dict[name] - narrow_weight = weight.narrow(0, head_start, heads_per_rank) - param.data.copy_(narrow_weight) - loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) - else: - # Handle all other weights with potential renaming - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: - continue - param = params_dict[renamed_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - loaded_params.add(renamed_name) - - return loaded_params - - def _load_weights_other( - self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - rename_mapping = { - "self_attn": "attn", - "input_layernorm.weight": "attn.norm.weight", - "post_attention_layernorm.weight": "mlp.norm.weight", - "embed_tokens": "embedding", - } - - def maybe_rename(name: str) -> str: - for remap_name, new_name in rename_mapping.items(): - if remap_name in name: - return name.replace(remap_name, new_name) - return name - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - intermediate_size = self.model_config.intermediate_size - - per_rank_intermediate_size = cdiv(intermediate_size, tp_size) - # Calculate common slicing bounds for current rank - tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) - - # Attention heads per rank - heads_per_rank = self.model_config.num_attention_heads // tp_size - head_start = tp_rank * heads_per_rank - - use_ep = self.vllm_config.parallel_config.enable_expert_parallel - ep_size = get_ep_group().world_size - ep_rank = get_ep_group().rank - num_experts = self.model_config.num_local_experts - experts_per_rank = num_experts // ep_size - ep_rank_start = ep_rank * experts_per_rank - ep_rank_end = (ep_rank + 1) * experts_per_rank - - for name, weight in weights: - if ".experts.gate_up_proj" in name and "bias" not in name: - # Handle MLP gate and up projection weights - new_name = name.replace(".experts.gate_up_proj", - ".experts.w13_weight") - - # Extract gate and up projection parts - # since the weight is shuffled, we can slice directly - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, :, - 2 * tp_rank_start:2 * tp_rank_end] - - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif ".experts.down_proj" in name and "bias" not in name: - # Handle MLP down projection weights - new_name = name.replace(".experts.down_proj", - ".experts.w2_weight") - - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "gate_up_proj_bias" in name: - # Handle MLP gate and up projection biases - new_name = name.replace("gate_up_proj_bias", "w13_bias") - - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] - - param = params_dict[new_name] - - param.copy_(narrow_weight) - loaded_params.add(new_name) - - elif "down_proj_bias" in name: - # Handle MLP down projection bias - new_name = name.replace("down_proj_bias", "w2_bias") - - if use_ep: - weight = weight[ep_rank_start:ep_rank_end, ...] - else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - param = params_dict[new_name] - param.copy_(weight) - loaded_params.add(new_name) - elif "sinks" in name: - # Handle attention sinks (distributed across ranks) - name = name.replace("self_attn", "attn") - param = params_dict[name] - narrow_weight = weight.narrow(0, head_start, heads_per_rank) - param.data.copy_(narrow_weight) - loaded_params.add(name) - elif "q_proj" in name or "k_proj" in name or "v_proj" in name: - shard_id = ("q" if "q_proj" in name else - "k" if "k_proj" in name else "v") - name = name.replace("self_attn", "attn") - param_name = name.replace(f"{shard_id}_proj", "qkv") - param = params_dict[param_name] - weight_loader = param.weight_loader - weight_loader(param, weight, loaded_shard_id=shard_id) - loaded_params.add(param_name) - else: - # Handle all other weights with potential renaming - - renamed_name = maybe_rename(name) - if renamed_name not in params_dict: - continue - param = params_dict[renamed_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - loaded_params.add(renamed_name) - - return loaded_params - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - quant_method = (self.model_config.quantization_config['quant_method'] - if hasattr(self.model_config, "quantization_config") - else None) - if quant_method == "mxfp4": - return self._load_weights_mxfp4(weights) - else: - return self._load_weights_other(weights) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
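For reference, the renaming that hf_to_vllm_mapper is configured to perform in the diff above can be pictured as plain suffix and substring substitution over checkpoint tensor names. The sketch below is illustrative only, not vLLM's WeightsMapper implementation; the rename helper and the sample tensor names are hypothetical, while the mapping pairs themselves are taken from the diff.

# Illustrative sketch: mimics the suffix/substring renaming configured on
# hf_to_vllm_mapper in the diff above. Not vLLM's WeightsMapper; the helper
# and the sample names are hypothetical.

ORIG_TO_NEW_SUBSTR = {
    ".self_attn.": ".attn.",
    ".post_attention_layernorm.": ".mlp.norm.",
}
ORIG_TO_NEW_SUFFIX = {
    ".embed_tokens.weight": ".embedding.weight",
    ".input_layernorm.weight": ".attn.norm.weight",
    ".post_attention_layernorm.weight": ".mlp.norm.weight",
    # MoE MXFP4 weights
    ".gate_up_proj_blocks": ".w13_weight",
    ".down_proj_blocks": ".w2_weight",
    ".gate_up_proj_scales": ".w13_weight_scale",
    ".down_proj_scales": ".w2_weight_scale",
    # MoE other weights
    ".gate_up_proj": ".w13_weight",
    ".down_proj": ".w2_weight",
    # MoE bias
    ".gate_up_proj_bias": ".w13_bias",
    ".down_proj_bias": ".w2_bias",
}

def rename(name: str) -> str:
    # Apply the first matching suffix rule, then the substring rules.
    # The rules above do not overlap, so the order is not critical here.
    for old, new in ORIG_TO_NEW_SUFFIX.items():
        if name.endswith(old):
            name = name[: -len(old)] + new
            break
    for old, new in ORIG_TO_NEW_SUBSTR.items():
        name = name.replace(old, new)
    return name

# Hypothetical checkpoint names, shown only to trace the mapping:
assert rename("model.layers.0.self_attn.sinks") == "model.layers.0.attn.sinks"
assert rename("model.layers.0.mlp.experts.gate_up_proj_blocks") == \
    "model.layers.0.mlp.experts.w13_weight"
assert rename("model.embed_tokens.weight") == "model.embedding.weight"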