[Chore] resolve bugs introduced by the merge

This commit is contained in:
i-yuanyukun 2025-12-18 15:56:20 +08:00
parent d306d01dd7
commit cd16bcff1e
3 changed files with 5 additions and 8 deletions

View File

@ -220,7 +220,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
hidden_states, dim=0
)
ffn_output = self.model.compute_ffn_output(
current_layer_idx, gathered_hidden_states
gathered_hidden_states, current_layer_idx
)
# Extract the output corresponding to current rank
start_idx = hidden_states.shape[0] * get_tensor_model_parallel_rank()
@ -229,7 +229,7 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
else:
# Single TP case
rank_ffn_output = self.model.compute_ffn_output(
current_layer_idx, hidden_states
hidden_states, current_layer_idx
)
return rank_ffn_output

View File

@ -3211,8 +3211,9 @@ class GPUModelRunner(
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
):
logger.info(f"input_ids: {input_ids.shape}")
if inputs_embeds:
if input_ids is not None:
logger.info(f"input_ids: {input_ids.shape}")
if inputs_embeds is not None:
logger.info(f"inputs_embeds: {inputs_embeds.shape}")
model_output = self._model_forward(
input_ids=input_ids,

View File

@ -127,10 +127,6 @@ class UBatchWrapper:
comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if (
vllm_config.parallel_config.enable_expert_parallel
and not vllm_config.afd_config
):
if (
vllm_config.parallel_config.enable_expert_parallel
and not vllm_config.afd_config