[Chore] Minor simplification for non-PP path (#24810)
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
parent 973c9d01da
commit 3e903b6cb4
@@ -86,7 +86,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin, KVConnectorOutput)
+    KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from .utils import (AttentionGroup, MultiModalBudget,
@@ -196,6 +196,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping mirco-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        self.broadcast_pp_output = (
+            self.parallel_config.distributed_executor_backend
+            == "external_launcher" and len(get_pp_group().ranks) > 0)
+
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
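The hunk above hoists the external_launcher check out of the per-step path and into __init__, so execute_model can branch on a precomputed boolean. Below is a minimal sketch of the hoisted condition, using a hypothetical stand-in for the parallel config (the real code reads self.parallel_config.distributed_executor_backend and get_pp_group().ranks):

from dataclasses import dataclass


@dataclass
class ParallelCfg:
    # Hypothetical stand-in for vLLM's ParallelConfig; only the one field
    # the check reads is modeled here.
    distributed_executor_backend: str | None = None


def should_broadcast_pp_output(cfg: ParallelCfg, pp_world_size: int) -> bool:
    # Same condition as the hoisted __init__ check: only the torchrun /
    # external_launcher backend broadcasts the last PP rank's output so
    # that every rank stays in sync.
    return (cfg.distributed_executor_backend == "external_launcher"
            and pp_world_size > 0)


# Common case: any other backend leaves the flag False, so the per-step
# broadcast path can be skipped entirely.
assert not should_broadcast_pp_output(ParallelCfg("mp"), pp_world_size=2)
assert should_broadcast_pp_output(ParallelCfg("external_launcher"), pp_world_size=2)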
@@ -1701,7 +1709,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1732,7 +1739,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            kv_connector_output=kv_connector_output,
         )
 
     def _preprocess(
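With the two hunks above, _pool no longer takes or forwards kv_connector_output; the caller attaches it to the returned output instead. A rough illustration of that calling convention with hypothetical stand-in types (not vLLM's actual ModelRunnerOutput):

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class PoolingOutput:
    # Hypothetical stand-in for ModelRunnerOutput with only the fields
    # this sketch touches.
    pooler_output: list[float] = field(default_factory=list)
    kv_connector_output: Optional[Any] = None


def pool(hidden_states: list[float]) -> PoolingOutput:
    # Stand-in for GPUModelRunner._pool: pooling itself no longer needs
    # to know anything about the KV connector.
    return PoolingOutput(pooler_output=[sum(hidden_states)])


kv_connector_output = object()  # placeholder for whatever the connector produced
output = pool([0.1, 0.2, 0.3])
output.kv_connector_output = kv_connector_output  # attached by the caller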
@@ -2073,39 +2079,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         with record_function_or_nullcontext("Postprocess"):
             if self.use_aux_hidden_state_outputs:
+                # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
             else:
+                # Common case.
                 hidden_states = model_output
                 aux_hidden_states = None
 
-            # Broadcast PP output for external_launcher (torchrun)
-            # to make sure we are synced across pp ranks
-            # TODO: Support overlapping mirco-batches
-            # https://github.com/vllm-project/vllm/issues/18019
-            broadcast_pp_output = \
-                self.parallel_config.distributed_executor_backend \
-                == "external_launcher" and len(get_pp_group().ranks) > 0
-            if not get_pp_group().is_last_rank:
-                # For mid-pipeline stages, return the hidden states.
-                assert isinstance(hidden_states, IntermediateTensors)
-                if not broadcast_pp_output:
+            if not self.broadcast_pp_output:
+                # Common case.
+                if not get_pp_group().is_last_rank:
+                    # Return the intermediate tensors.
+                    assert isinstance(hidden_states, IntermediateTensors)
                     hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
-                logits = None
-            else:
+
                 if self.is_pooling_model:
-                    return self._pool(hidden_states, num_scheduled_tokens,
-                                      num_scheduled_tokens_np,
-                                      kv_connector_output)
+                    # Return the pooling output.
+                    output = self._pool(hidden_states, num_scheduled_tokens,
+                                        num_scheduled_tokens_np)
+                    output.kv_connector_output = kv_connector_output
+                    return output
 
                 sample_hidden_states = hidden_states[logits_indices]
                 logits = self.model.compute_logits(sample_hidden_states, None)
-            if broadcast_pp_output:
-                model_output_broadcast_data = {
-                    "logits": logits.contiguous(),
-                } if logits is not None else {}
+            else:
+                # Rare case.
+                assert not self.is_pooling_model
+
+                if not get_pp_group().is_last_rank:
+                    get_pp_group().send_tensor_dict(
+                        hidden_states.tensors, all_gather_group=get_tp_group())
+                    logits = None
+                else:
+                    sample_hidden_states = hidden_states[logits_indices]
+                    logits = self.model.compute_logits(sample_hidden_states,
+                                                       None)
+
+                model_output_broadcast_data = {}
+                if logits is not None:
+                    model_output_broadcast_data["logits"] = logits.contiguous()
+
                 model_output_broadcast_data = get_pp_group(
                 ).broadcast_tensor_dict(model_output_broadcast_data,
                                         src=len(get_pp_group().ranks) - 1)
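Condensed, the restructured postprocess branches on the precomputed flag first, keeping the common non-PP / non-broadcast path free of any collective communication. Below is a heavily simplified, self-contained sketch of that control flow; Runner, PPGroup, and the return values are illustrative stand-ins, not vLLM's actual API:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PPGroup:
    # Hypothetical stand-in for the pipeline-parallel process group.
    is_last_rank: bool = True

    def send_tensor_dict(self, tensors: dict) -> None:
        pass  # would ship intermediate tensors to the next PP stage

    def broadcast_tensor_dict(self, data: dict, src: int) -> dict:
        return data  # would broadcast from the last PP rank to all ranks


@dataclass
class Runner:
    # Hypothetical, heavily trimmed stand-in for GPUModelRunner.
    broadcast_pp_output: bool
    is_pooling_model: bool
    pp_group: PPGroup

    def compute_logits(self, hidden_states: list[float]) -> list[float]:
        return hidden_states  # placeholder for self.model.compute_logits


def postprocess(runner: Runner, hidden_states: list[float],
                kv_connector_output: Optional[Any]) -> Any:
    # Compressed version of the branch structure introduced by the hunk above.
    if not runner.broadcast_pp_output:
        # Common case: early returns only, no PP communication.
        if not runner.pp_group.is_last_rank:
            return ("intermediate", hidden_states, kv_connector_output)
        if runner.is_pooling_model:
            return ("pooled", hidden_states, kv_connector_output)
        return runner.compute_logits(hidden_states)
    # Rare case: external_launcher (torchrun) keeps every PP rank in sync
    # by broadcasting the last rank's logits.
    assert not runner.is_pooling_model
    if not runner.pp_group.is_last_rank:
        runner.pp_group.send_tensor_dict({"hidden_states": hidden_states})
        logits = None
    else:
        logits = runner.compute_logits(hidden_states)
    data = {} if logits is None else {"logits": logits}
    return runner.pp_group.broadcast_tensor_dict(data, src=0)


# The common single-backend path never touches the PP group:
print(postprocess(Runner(False, False, PPGroup()), [0.1, 0.2], None))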