mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Chore] Minor simplification for non-PP path (#24810)
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
parent
973c9d01da
commit
3e903b6cb4
@ -86,7 +86,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin, KVConnectorOutput)
|
||||
KVConnectorModelRunnerMixin)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from .utils import (AttentionGroup, MultiModalBudget,
|
||||
@ -196,6 +196,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
# TODO: Support overlapping mirco-batches
|
||||
# https://github.com/vllm-project/vllm/issues/18019
|
||||
self.broadcast_pp_output = (
|
||||
self.parallel_config.distributed_executor_backend
|
||||
== "external_launcher" and len(get_pp_group().ranks) > 0)
|
||||
|
||||
# Model-related.
|
||||
self.num_query_heads = model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
@ -1701,7 +1709,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
kv_connector_output: Optional[KVConnectorOutput],
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
@ -1732,7 +1739,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
kv_connector_output=kv_connector_output,
|
||||
)
|
||||
|
||||
def _preprocess(
|
||||
@ -2073,39 +2079,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
with record_function_or_nullcontext("Postprocess"):
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
# True when EAGLE 3 is used.
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
# Common case.
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
# TODO: Support overlapping mirco-batches
|
||||
# https://github.com/vllm-project/vllm/issues/18019
|
||||
broadcast_pp_output = \
|
||||
self.parallel_config.distributed_executor_backend \
|
||||
== "external_launcher" and len(get_pp_group().ranks) > 0
|
||||
if not get_pp_group().is_last_rank:
|
||||
# For mid-pipeline stages, return the hidden states.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
if not broadcast_pp_output:
|
||||
if not self.broadcast_pp_output:
|
||||
# Common case.
|
||||
if not get_pp_group().is_last_rank:
|
||||
# Return the intermediate tensors.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
return hidden_states
|
||||
get_pp_group().send_tensor_dict(
|
||||
hidden_states.tensors, all_gather_group=get_tp_group())
|
||||
logits = None
|
||||
else:
|
||||
|
||||
if self.is_pooling_model:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np,
|
||||
kv_connector_output)
|
||||
# Return the pooling output.
|
||||
output = self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np)
|
||||
output.kv_connector_output = kv_connector_output
|
||||
return output
|
||||
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
model_output_broadcast_data = {
|
||||
"logits": logits.contiguous(),
|
||||
} if logits is not None else {}
|
||||
else:
|
||||
# Rare case.
|
||||
assert not self.is_pooling_model
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
get_pp_group().send_tensor_dict(
|
||||
hidden_states.tensors, all_gather_group=get_tp_group())
|
||||
logits = None
|
||||
else:
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states,
|
||||
None)
|
||||
|
||||
model_output_broadcast_data = {}
|
||||
if logits is not None:
|
||||
model_output_broadcast_data["logits"] = logits.contiguous()
|
||||
|
||||
model_output_broadcast_data = get_pp_group(
|
||||
).broadcast_tensor_dict(model_output_broadcast_data,
|
||||
src=len(get_pp_group().ranks) - 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user