diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
index 6dc1f317b87d5..f85679400d1c6 100644
--- a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
+++ b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py
@@ -52,9 +52,15 @@ class P2PAFDConnector(AFDConnectorBase):
         self._need_recv_metadata: bool = True
         self._tensor_metadata_list: dict[int, TensorMetadata] = {}
         self._current_afd_connector_metadata: AFDConnectorMetadata | None = None
-        self.num_hidden_layers: int = (
-            self.config.model_config.hf_config.num_hidden_layers
-        )
+        if getattr(self.config.model_config.hf_config, "text_config", None) is not None:
+            self.num_hidden_layers: int = (
+                self.config.model_config.hf_config.text_config.num_hidden_layers
+            )
+        else:
+            self.num_hidden_layers: int = (
+                self.config.model_config.hf_config.num_hidden_layers
+            )
+
         self.recv_attn_output_counter: int = 0
         self.recv_ffn_output_counter: int = 0
         self.dp_metadata_list: dict[int, DPMetadata] = {}
@@ -175,16 +181,28 @@ class P2PAFDConnector(AFDConnectorBase):
                 tensor_metadata, self._current_afd_connector_metadata
             )
             if self.config.parallel_config.data_parallel_size > 1:
-                logger.info("jcz recv_metadata num_of_stages:{}".format(self._current_afd_connector_metadata.num_of_stages))
+                logger.info(
+                    "jcz recv_metadata num_of_stages:{}".format(
+                        self._current_afd_connector_metadata.num_of_stages
+                    )
+                )
                 for stage_idx in range(self._current_afd_connector_metadata.num_of_stages):
                     num_tokens_per_ubatch = self._tensor_metadata_list[stage_idx].size[0]
                     self.dp_metadata_list[stage_idx] = DPMetadata.make(
                         self.config.parallel_config,
                         num_tokens_per_ubatch,
-                        torch.tensor([num_tokens_per_ubatch] * self.config.parallel_config.data_parallel_size,
-                                     device="cpu", dtype=torch.int32),
+                        torch.tensor(
+                            [num_tokens_per_ubatch]
+                            * self.config.parallel_config.data_parallel_size,
+                            device="cpu",
+                            dtype=torch.int32,
+                        ),
                     )
-                logger.info("jcz recv_metadata self.dp_metadata_list:{}".format(self.dp_metadata_list))
+                logger.info(
+                    "jcz recv_metadata self.dp_metadata_list:{}".format(
+                        self.dp_metadata_list
+                    )
+                )

     def _send_hidden_states(
         self,
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 7938cff98c354..6b7842b00f54a 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1670,7 +1670,7 @@ class DeepseekV2ForCausalLM(
         return hidden_states

     def compute_ffn_output(
-        self, current_layer_idx, hidden_states
+        self, hidden_states, current_layer_idx
    ) -> torch.Tensor | IntermediateTensors:
         hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
         return hidden_states
diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py
index 7077f1a22e8d7..eb723582a0ccb 100644
--- a/vllm/model_executor/models/step3_text.py
+++ b/vllm/model_executor/models/step3_text.py
@@ -11,12 +11,16 @@ from torch import nn

 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from vllm.distributed.afd_transfer.afd_connector.metadata import (
+    AFDConnectorMetadata,
+)
+from vllm.forward_context import AFDMetadata, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
+from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield

 from .interfaces import SupportsPP
 from .utils import (
@@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module):
         config: Step3TextConfig,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
+        afd_config: AFDConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.afd_role = afd_config.afd_role if afd_config is not None else None

-        self.self_attn = Step3TextAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=1,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            norm_eps=config.rms_norm_eps,
-            max_position_embedding=config.max_position_embedding,
-            head_dim=config.head_dim,
-            share_q_dim=config.share_q_dim,
-            rope_parameters=config.rope_parameters,
-            prefix=f"{prefix}.self_attn",
-        )
-
-        layer_idx = int(prefix.split("layers.")[1].split(".")[0])
-        moe_layers_enum = getattr(config, "moe_layers_enum", None)
-        if moe_layers_enum is not None:
-            moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
-        else:
-            # Default to 1dense.
-            moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
-
-        if layer_idx in moe_layers_idx:
-            self.moe = FusedMoEBlock(
-                config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
-            )
-            self.share_expert = Step3TextMLP(
+        if self.afd_role is None or self.afd_role == "attention":
+            self.self_attn = Step3TextAttention(
                 hidden_size=self.hidden_size,
-                intermediate_size=config.share_expert_dim,
-                hidden_act="silu",
+                num_heads=config.num_attention_heads,
+                num_kv_heads=1,
+                cache_config=cache_config,
                 quant_config=quant_config,
-                prefix=f"{prefix}.share_expert",
+                norm_eps=config.rms_norm_eps,
+                max_position_embedding=config.max_position_embedding,
+                head_dim=config.head_dim,
+                share_q_dim=config.share_q_dim,
+                rope_parameters=config.rope_parameters,
+                prefix=f"{prefix}.self_attn",
             )
-            self.use_moe = True
-        else:
-            self.mlp = Step3TextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act="silu",
-                quant_config=quant_config,
-                prefix=f"{prefix}.mlp",
-            )
-            self.use_moe = False
+
+        self.layer_idx = int(prefix.split("layers.")[1].split(".")[0])
+
+        if self.afd_role is None or self.afd_role == "ffn":
+            moe_layers_enum = getattr(config, "moe_layers_enum", None)
+            if moe_layers_enum is not None:
+                moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
+            else:
+                # Default to 1dense.
+                moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
+
+            if self.layer_idx in moe_layers_idx:
+                self.moe = FusedMoEBlock(
+                    config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
+                )
+                self.share_expert = Step3TextMLP(
+                    hidden_size=self.hidden_size,
+                    intermediate_size=config.share_expert_dim,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.share_expert",
+                )
+                self.use_moe = True
+            else:
+                self.mlp = Step3TextMLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    hidden_act="silu",
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.mlp",
+                )
+                self.use_moe = False
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -300,6 +310,9 @@

         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

+        if self.afd_role == "attention":
+            return hidden_states, residual
+
         if self.use_moe:
             share_output = self.share_expert(hidden_states)
             moe_output = self.moe(hidden_states)
@@ -309,6 +322,16 @@

         return hidden_states, residual

+    def compute_ffn_output(self, hidden_states):
+        assert self.afd_role == "ffn"
+        if self.use_moe:
+            share_output = self.share_expert(hidden_states)
+            moe_output = self.moe(hidden_states)
+            hidden_states = share_output + moe_output
+        else:
+            hidden_states = self.mlp(hidden_states)
+        return hidden_states
+

 @support_torch_compile
 class Step3TextModel(nn.Module):
@@ -317,6 +340,7 @@
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        afd_config = vllm_config.afd_config

         self.vocab_size = config.vocab_size
         self.config = config
@@ -336,6 +360,7 @@
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                afd_config=afd_config,
                 prefix=prefix,
             ),
             prefix=f"{prefix}.layers",
@@ -352,6 +377,81 @@
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

+    def forward_with_afd(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        positions: torch.Tensor,
+        afd_metadata: AFDMetadata,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        forward_context = get_forward_context()
+        recv_handle = None
+
+        ubatch_hidden_states = []
+        ubatch_residual = []
+
+        start_idx = 0
+        for pos in afd_metadata.positions_list:
+            num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
+            end_idx = start_idx + num_tokens
+            ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
+            ubatch_residual.append(
+                residual[start_idx:end_idx] if residual is not None else None
+            )
+            start_idx = end_idx
+
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
+            for stage_i in range(forward_context.afd_metadata.num_of_stages):
+                afd_connector = afd_metadata.afd_connector
+                forward_context.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
+                forward_context.dp_metadata = afd_metadata.dp_metadata_list[stage_i]
+
+                residual = ubatch_residual[stage_i]
+
+                if layer.layer_idx > 0:
+                    hidden_states, recv_metadata = afd_connector.recv_ffn_output()
+                    if recv_metadata.recv_handle_list is not None:
+                        recv_handle = recv_metadata.recv_handle_list
+                else:
+                    hidden_states = ubatch_hidden_states[stage_i]
+
+                if recv_handle is not None:
+                    for work in recv_handle:
+                        work.wait()
+
+                current_positions = afd_metadata.positions_list[stage_i]
+                hidden_states, residual = layer(
+                    current_positions, hidden_states, residual
+                )
+
+                ubatch_hidden_states[stage_i] = hidden_states
+                ubatch_residual[stage_i] = residual
+                metadata = AFDConnectorMetadata.create_attention_metadata(
+                    layer_idx=layer.layer_idx,
+                    stage_idx=stage_i,
+                    seq_len=hidden_states.shape[0],
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device,
+                    num_of_stages=afd_metadata.num_of_stages,
+                    afd_tokens_lens=afd_metadata.afd_tokens_lens,
+                )
+                afd_connector.send_attn_output(hidden_states, metadata)
+
+        # Recv last layer FFN output.
+        for stage_i in range(afd_metadata.num_of_stages):
+            ubatch_hidden_states[stage_i], recv_metadata = (
+                afd_connector.recv_ffn_output()
+            )
+
+        # Re-assemble the batch.
+        hidden_states = torch.cat(ubatch_hidden_states, dim=0)
+        if any(r is not None for r in ubatch_residual):
+            residual = torch.cat(ubatch_residual, dim=0)
+        else:
+            residual = None
+
+        return hidden_states, residual
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -370,8 +470,19 @@
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for layer in islice(self.layers, self.start_layer, self.end_layer):
-            hidden_states, residual = layer(positions, hidden_states, residual)
+        forward_ctx = get_forward_context()
+        afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
+
+        if afd_metadata is not None:
+            hidden_states, residual = self.forward_with_afd(
+                hidden_states,
+                residual,
+                positions,
+                afd_metadata,
+            )
+        else:
+            for layer in islice(self.layers, self.start_layer, self.end_layer):
+                hidden_states, residual = layer(positions, hidden_states, residual)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -384,6 +495,14 @@
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    def compute_ffn_output(
+        self,
+        hidden_states,
+        layer_idx,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states)
+        return hidden_states
+

 class Step3TextForCausalLM(nn.Module, SupportsPP):
     def __init__(
@@ -398,6 +517,11 @@
         self.config = config
         self.vllm_config = vllm_config

+        self.afd_config = vllm_config.afd_config
+        self.afd_role = (
+            self.afd_config.afd_role if self.afd_config is not None else None
+        )
+
         self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)

         if get_pp_group().is_last_rank:
@@ -429,6 +553,14 @@
         )
         return hidden_states

+    def compute_ffn_output(
+        self,
+        hidden_states,
+        current_layer_idx,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
+        return hidden_states
+
     def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
@@ -477,9 +609,13 @@
         disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

         for name, loaded_weight in weights:
+            if self.afd_role == "attention" and self.is_moe_weight(name):
+                continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
                 if any(
                     disable_moe_stacked_param in name
                     for disable_moe_stacked_param in disable_moe_stacked_params
@@ -498,6 +634,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                 param_name, weight_name, shard_id = mapping
                 if weight_name not in name:
                     continue
+
+                if self.afd_role == "attention":
+                    continue
+
                 name = name.replace(weight_name, param_name)
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
@@ -521,6 +661,12 @@
                     loaded_params.add(name)
                     break
                 else:
+                    if (
+                        self.afd_role == "ffn"
+                        and not self.is_moe_weight(name)
+                        and not self.is_common_weight(name)
+                    ):
+                        continue
                     for (
                         param_name,
                         weight_name,
@@ -552,3 +698,25 @@
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
+
+    def is_moe_weight(self, name):
+        if (
+            "shared_expert" in name
+            or "experts" in name
+            or "gate" in name
+            or "up" in name
+            or "down" in name
+        ):
+            return True
+        return False
+
+    def is_common_weight(self, name):
+        if (
+            "lm_head" in name
+            or "model.norm.weight" in name
+            or "embed" in name
+            or "input_layernorm" in name
+            or "post_attention_layernorm" in name
+        ):
+            return True
+        return False
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index e5038e56a2708..e16cb53e9f194 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)

         return hidden_states

+    def compute_ffn_output(
+        self,
+        hidden_states,
+        current_layer_idx,
+    ) -> torch.Tensor | IntermediateTensors:
+        hidden_states = self.language_model.compute_ffn_output(
+            hidden_states, current_layer_idx
+        )
+        return hidden_states
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py
index cb08c9c05ae58..79ba7e6eb5c2e 100644
--- a/vllm/v1/worker/gpu_ffn_model_runner.py
+++ b/vllm/v1/worker/gpu_ffn_model_runner.py
@@ -130,15 +130,17 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):

         try:
             hidden_states, recv_metadata = self.connector.recv_attn_output()
-            if hasattr(self.connector, 'dp_metadata_list'):
-                dp_metadata = self.connector.dp_metadata_list.get(recv_metadata.stage_idx, None)
+            if hasattr(self.connector, "dp_metadata_list"):
+                dp_metadata = self.connector.dp_metadata_list.get(
+                    recv_metadata.stage_idx, None
+                )
             else:
                 dp_metadata = None
             current_layer_idx = recv_metadata.layer_idx
-            logger.info(
-                f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}"
-                f" dp_metadata: {dp_metadata}"
-            )
+            # logger.info(
+            #     f"layer {current_layer_idx} moe recv hidden states type:{type(hidden_states)}, shape:{hidden_states.shape}"
+            #     f" dp_metadata: {dp_metadata}"
+            # )
             num_tokens = hidden_states.shape[0]
             if recv_metadata is not None and recv_metadata.recv_handle_list is not None:
                 for work in recv_metadata.recv_handle_list:
@@ -220,7 +222,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
                     hidden_states, dim=0
                 )
                 ffn_output = self.model.compute_ffn_output(
-                    current_layer_idx, gathered_hidden_states
+                    gathered_hidden_states, current_layer_idx
                 )
                 # Extract the output corresponding to current rank
                 start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank()
@@ -229,7 +231,7 @@
             else:
                 # Single TP case
                 rank_ffn_output = self.model.compute_ffn_output(
-                    current_layer_idx, hidden_states
+                    hidden_states, current_layer_idx
                 )

         return rank_ffn_output
@@ -349,7 +351,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
                    hidden_states, dim=0
                 )
                 ffn_output = self.model.compute_ffn_output(
-                    current_layer_idx, gathered_hidden_states
+                    gathered_hidden_states, current_layer_idx
                 )

                 # Extract the output corresponding to current rank
@@ -359,7 +361,7 @@
             else:
                 # Single TP case
                 rank_ffn_output = self.model.compute_ffn_output(
-                    current_layer_idx, hidden_states
+                    hidden_states, current_layer_idx
                 )

         return rank_ffn_output
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 82dd80d267182..2409c9071f94b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3141,7 +3141,9 @@ class GPUModelRunner(
             # Mark KV scales as calculated after the first forward pass
             self.calculate_kv_scales = False

-        afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded)
+        afd_metadata = self._build_afd_metadata(
+            ubatch_slices_padded, num_tokens_unpadded
+        )
         self.profiler.step()

         # Run the model.
@@ -3160,8 +3162,9 @@
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
         ):
-            logger.info(f"input_ids: {input_ids.shape}")
-            if inputs_embeds:
+            if input_ids is not None:
+                logger.info(f"input_ids: {input_ids.shape}")
+            if inputs_embeds is not None:
                 logger.info(f"inputs_embeds: {inputs_embeds.shape}")
             model_output = self._model_forward(
                 input_ids=input_ids,
@@ -4274,7 +4277,9 @@
         if num_tokens_across_dp is not None:
             num_tokens_across_dp[:] = num_tokens_padded

-        afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded)
+        afd_metadata = self._build_afd_metadata(
+            ubatch_slices_padded, num_tokens_unpadded
+        )

         with (
             self.maybe_randomize_inputs(input_ids, inputs_embeds),
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 9e17c718c5513..7a44b6cbd42b9 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -405,9 +405,12 @@ class UBatchWrapper:
             afd_metadata.input_ids_list.append(sliced_input_ids)
             afd_metadata.positions_list.append(sliced_positions)
             afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
-            afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors)
+            afd_metadata.intermediate_tensors_list.append(
+                sliced_intermediate_tensors
+            )
             afd_metadata.attn_metadata_list.append(
-                attn_metadata[i] if attn_metadata is not None else None)
+                attn_metadata[i] if attn_metadata is not None else None
+            )
             afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)

         return afd_metadata
@@ -481,7 +484,7 @@
             # num_tokens, we don't have a non-ubatched one. Without this
             # check, the cudagraph wrapper will try to capture a cudagraph
             # for this shape during a normal run.
-            
+
             if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                 assert batch_descriptor is not None
                 if batch_descriptor.num_tokens in self.cudagraphs: