[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)

Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Nan Qin 2025-05-19 22:21:27 -05:00 committed by GitHub
parent f07a673eb2
commit 9609327fa4
4 changed files with 138 additions and 64 deletions
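
Summary of the change: with prompt-embedding inputs enabled, non-driver tensor-parallel workers previously returned from execute_model before the sampled tokens were turned back into embeddings, so only the driver rank had the inputs it needed for the next decode step. This commit has the driver broadcast the sampled token ids to all ranks (see the model-runner hunk at the end of this diff) and parametrizes the basic-correctness tests over enable_prompt_embeds, including the tensor-parallel test. Below is a hedged usage sketch of the feature being exercised; the enable_prompt_embeds argument and the "prompt_embeds" prompt field follow the tests in this diff, the feature was v0-only at this point (the tests skip V1), and names may differ in other vLLM versions.

# Hedged usage sketch (not part of the diff): prompt embeddings with
# tensor parallelism, mirroring what the updated tests below cover.
# Requires 2 GPUs; VLLM_USE_V1=0 may be needed at the time of this change.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Build prompt embeddings from the HF embedding table, like
# HfRunner.get_prompt_embeddings in this diff.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(input_ids).squeeze(0)

llm = LLM(model=model_name,
          tensor_parallel_size=2,
          enable_prompt_embeds=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)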

View File

@@ -8,12 +8,13 @@ import weakref
 from unittest.mock import Mock
 
 import pytest
+import torch
 
-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
 
-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
@@ -78,14 +99,25 @@ def test_models(
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)
 
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
 
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
     ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
     ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
 ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")
 
     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
-            # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
@@ -147,12 +188,26 @@ def test_models_distributed(
             dtype=dtype,
             tensor_parallel_size=2,
             distributed_executor_backend=distributed_executor_backend,
+            enable_prompt_embeds=enable_prompt_embeds,
+            gpu_memory_utilization=0.7,
         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+            if enable_prompt_embeds:
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    with torch.no_grad():
+                        prompt_embeds = hf_model.get_prompt_embeddings(
+                            example_prompts)
+                    vllm_outputs = vllm_model.generate_greedy(
+                        prompt_embeds, max_tokens)
+                    vllm_outputs = _fix_prompt_embed_outputs(
+                        vllm_outputs, hf_model, example_prompts)
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
 
         check_outputs_equal(
             outputs_0_lst=hf_outputs,
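
Note on _fix_prompt_embed_outputs: when a request is driven by prompt_embeds, the engine has no real prompt token ids or prompt text to echo back, so the returned token-id list covers the prompt positions with placeholder ids rather than the HF tokenization (as the slicing in the helper suggests). The helper substitutes the HF tokenizer's ids for that prefix and prepends the prompt string so check_outputs_equal can compare against hf_outputs. A small worked example of the slicing, with made-up values:

# Made-up values illustrating _fix_prompt_embed_outputs; only the slicing and
# concatenation mirror the helper above.
prompt = "Hello, my name is"
hf_input_ids = [15496, 11, 616, 1438, 318]           # HF tokenization of the prompt (made up)
vllm_ids = [0, 0, 0, 0, 0, 402, 257, 290, 262, 995]  # placeholder prompt ids + generated ids (made up)
vllm_text = " John and I like"                       # generated text only (made up)

fixed = (hf_input_ids + vllm_ids[len(hf_input_ids):], prompt + vllm_text)
assert fixed == ([15496, 11, 616, 1438, 318, 402, 257, 290, 262, 995],
                 "Hello, my name is John and I like")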

View File

@@ -430,6 +430,15 @@ class HfRunner:
         return all_inputs
 
+    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
+        all_inputs = self.get_inputs(prompts)
+        embeddings = []
+        for inputs in all_inputs:
+            input_ids = self.wrap_device(inputs)["input_ids"]
+            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
+            embeddings.append(embedding)
+        return embeddings
+
     def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)

View File

@@ -112,12 +112,12 @@ class RequestMetrics:
                    will include model forward, block/sync across
                    workers, cpu-gpu sync time and sampling time.
         spec_token_acceptance_counts: number of accepted speculative tokens at
                                       each position; the first token is from
                                       the target model and is always accepted;
                                       e.g., when it's [10, 8, 4, 2] for a req,
                                       it means there were 10 forward passes in
                                       total, and there were 8, 4, 2 accepted
                                       tokens at 1st, 2nd, 3rd speculation step.
     """
     arrival_time: float
     last_token_time: float
@@ -714,9 +714,9 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
         draft_size: The number of speculative tokens plus one from the target
                     model; equal to max number of tokens a step can generate
                     for single-draft speculative decoding but larger than
                     that for multi-draft SD (currently not supported).
     """
@@ -1123,7 +1123,7 @@ class SequenceOutput(
             self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"output_embed.shape={output_embed_shape}"
+                f"output_embed.shape={output_embed_shape}, "
                 f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:

View File

@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         """
         # Combine and flatten intermediate data.
         input_tokens = list[int]()
-        inputs_embeds_lst = list[torch.Tensor]()
+        inputs_embeds_list = list[torch.Tensor]()
         token_types = list[int]()
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds_lst.append(
+                inputs_embeds_list.append(
                     inter_data.inputs_embeds.to(
                         dtype=self.runner.model_config.dtype,
                         device=self.runner.device))
         inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_lst) == 0:
+        if len(inputs_embeds_list) == 0:
             inputs_embeds = None
         else:
-            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
                 dtype=self.runner.model_config.dtype,
                 device=self.runner.device)
             assert len(inputs_embeds) == len(input_tokens)
@@ -1893,51 +1893,61 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)
 
+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()
+
+            # Sample the next token.
+            assert isinstance(self.sampler, Sampler)
+            orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
+            if model_input.inputs_embeds is not None:
+                self.sampler.include_gpu_probs_tensor = True
+
+            output: SamplerOutput = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time
+                    and output is not None):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
+                output.model_forward_time = (orig_model_forward_time +
+                                             model_forward_time)
+
+        if model_input.inputs_embeds is not None:
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+                    output.sampled_token_embeds = sampled_token_embeds
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
         if not self.is_driver_worker:
             return []
 
-        if model_input.async_callback is not None:
-            model_input.async_callback()
-
-        # Sample the next token.
-        assert isinstance(self.sampler, Sampler)
-        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = True
-
-        output: SamplerOutput = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_end.synchronize()
-            model_forward_time = model_forward_start.elapsed_time(
-                model_forward_end)
-            orig_model_forward_time = 0.0
-            if intermediate_tensors is not None:
-                orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = (orig_model_forward_time +
-                                         model_forward_time)
-
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = \
-                orig_include_gpu_probs_tensor
-            if output.sampled_token_ids is not None:
-                output.sampled_token_embeds = self.model.get_input_embeddings(
-                    output.sampled_token_ids.squeeze(1))
-
-                for token_embed, sequence_group_output in zip(
-                        output.sampled_token_embeds, output.outputs):
-                    assert len(sequence_group_output.samples) == 1
-                    sequence_group_output.samples[0].output_embed = token_embed
-
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
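
The core fix is in this last hunk: sampling still happens only on the driver worker, but when inputs_embeds are in use the driver now broadcasts the sampled token ids with broadcast_tensor_dict before the non-driver early return, so every tensor-parallel rank can run get_input_embeddings locally and has the embeddings it needs for the next decode step. The include_gpu_probs_tensor toggle appears to exist so that output.sampled_token_ids is actually populated on the driver in this path. A minimal, self-contained analogue of the driver-broadcast pattern, using plain torch.distributed on the gloo backend rather than vLLM's internals:

# Standalone sketch of the driver-broadcast pattern (an analogue, not vLLM's
# API): rank 0 "samples" token ids and broadcasts them so every rank can embed
# the sampled tokens locally for the next step. Uses CPU/gloo so it runs
# without GPUs.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29501",
                            rank=rank,
                            world_size=world_size)
    torch.manual_seed(0)  # every rank holds the same (replicated) embedding table
    embed = torch.nn.Embedding(100, 16)  # stand-in for model.get_input_embeddings

    if rank == 0:  # driver: sample, then broadcast
        sampled_token_ids = torch.randint(0, 100, (4, 1))
    else:  # non-driver: receive the driver's sampled ids
        sampled_token_ids = torch.empty(4, 1, dtype=torch.int64)
    dist.broadcast(sampled_token_ids, src=0)

    with torch.no_grad():
        sampled_token_embeds = embed(sampled_token_ids.squeeze(1))
    print(f"rank {rank}: sampled_token_embeds {tuple(sampled_token_embeds.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)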