[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)

Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Authored by Nan Qin on 2025-05-19 22:21:27 -05:00, committed by GitHub
parent f07a673eb2
commit 9609327fa4
4 changed files with 138 additions and 64 deletions
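
In short, the change lets prompt-embedding requests run under tensor parallelism: the driver worker now broadcasts the sampled token ids to the other tensor-parallel ranks so every rank can compute the input embeddings for the next decoding step (see the ModelRunner hunks at the end of the diff). For orientation only, below is a hedged usage sketch of the scenario the new tests exercise. It is not code from this commit; it assumes the `enable_prompt_embeds` engine flag and the `{"prompt_embeds": ...}` prompt form from vLLM's existing v0 prompt-embeddings support, plus two visible GPUs.

# Hedged usage sketch (not part of this commit): greedy decoding from prompt
# embeddings with tensor_parallel_size=2, the combination this fix targets.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "distilbert/distilgpt2"  # example model; any supported causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
with torch.no_grad():
    input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    # (1, seq_len, hidden_size) -> (seq_len, hidden_size)
    prompt_embeds = hf_model.get_input_embeddings()(input_ids).squeeze(0)

llm = LLM(model=model_name,
          enable_prompt_embeds=True,  # assumed flag from prompt-embeds support
          tensor_parallel_size=2)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)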

View File

@@ -8,12 +8,13 @@ import weakref
 from unittest.mock import Mock

 import pytest
+import torch

-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
@@ -78,14 +99,25 @@ def test_models(
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)

     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
     ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
     ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
 ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")

     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
-            # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
@@ -147,12 +188,26 @@ def test_models_distributed(
                 dtype=dtype,
                 tensor_parallel_size=2,
                 distributed_executor_backend=distributed_executor_backend,
+                enable_prompt_embeds=enable_prompt_embeds,
+                gpu_memory_utilization=0.7,
         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+            if enable_prompt_embeds:
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    with torch.no_grad():
+                        prompt_embeds = hf_model.get_prompt_embeddings(
+                            example_prompts)
+                    vllm_outputs = vllm_model.generate_greedy(
+                        prompt_embeds, max_tokens)
+                    vllm_outputs = _fix_prompt_embed_outputs(
+                        vllm_outputs, hf_model, example_prompts)
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)

         check_outputs_equal(
             outputs_0_lst=hf_outputs,
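
A note on the `_fix_prompt_embed_outputs` helper above: when generation is driven by prompt embeddings, vLLM has no prompt token ids or prompt text to echo back, so the helper splices the HF tokenizer's input ids and the original prompt string back into each vLLM output before `check_outputs_equal` compares it with the HF baseline. A toy illustration with made-up values:

# Made-up values, only to illustrate the splicing done by _fix_prompt_embed_outputs.
hf_input_ids = [101, 2023, 2003]        # prompt ids from the HF tokenizer
vllm_token_ids = [0, 0, 0, 7592, 2088]  # hypothetical vLLM output: prompt slots + generated ids
vllm_text = " hello world"              # generated text only, no prompt echo
prompt = "This is"

fixed_ids = hf_input_ids + vllm_token_ids[len(hf_input_ids):]
fixed_text = prompt + vllm_text
assert fixed_ids == [101, 2023, 2003, 7592, 2088]
assert fixed_text == "This is hello world"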

View File

@@ -430,6 +430,15 @@ class HfRunner:
         return all_inputs

+    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
+        all_inputs = self.get_inputs(prompts)
+        embeddings = []
+        for inputs in all_inputs:
+            input_ids = self.wrap_device(inputs)["input_ids"]
+            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
+            embeddings.append(embedding)
+        return embeddings
+
     def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)
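
Outside the test fixtures, the same per-prompt embedding tensors can be produced with a plain `transformers` model, mirroring the `get_prompt_embeddings` helper above; a minimal sketch (the model name is only an example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilbert/distilgpt2"  # example; any causal LM with an embedding table
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompts = ["Hello, my name is", "The capital of France is"]
embeddings = []
with torch.no_grad():
    for prompt in prompts:
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        # embedding lookup: (1, seq_len) ids -> (seq_len, hidden_size) embeddings
        embeddings.append(model.get_input_embeddings()(input_ids).squeeze(0))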

View File

@@ -1123,7 +1123,7 @@ class SequenceOutput(
             self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"output_embed.shape={output_embed_shape}"
+                f"output_embed.shape={output_embed_shape}, "
                 f"logprobs={self.logprobs})")

     def __eq__(self, other: object) -> bool:

View File

@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         """
         # Combine and flatten intermediate data.
         input_tokens = list[int]()
-        inputs_embeds_lst = list[torch.Tensor]()
+        inputs_embeds_list = list[torch.Tensor]()
         token_types = list[int]()
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds_lst.append(
+                inputs_embeds_list.append(
                     inter_data.inputs_embeds.to(
                         dtype=self.runner.model_config.dtype,
                         device=self.runner.device))
         inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_lst) == 0:
+        if len(inputs_embeds_list) == 0:
             inputs_embeds = None
         else:
-            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
                 dtype=self.runner.model_config.dtype,
                 device=self.runner.device)
             assert len(inputs_embeds) == len(input_tokens)
@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)

-        if not self.is_driver_worker:
-            return []
-
-        if model_input.async_callback is not None:
-            model_input.async_callback()
+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()

         # Sample the next token.
         assert isinstance(self.sampler, Sampler)
-        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
+        orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
         if model_input.inputs_embeds is not None:
             self.sampler.include_gpu_probs_tensor = True
@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 if intermediate_tensors is not None:
                     orig_model_forward_time = intermediate_tensors.tensors.get(
                         "model_forward_time", torch.tensor(0.0)).item()
-                # If there are multiple workers, we are still tracking the latency
-                # from the start time of the driver worker to the end time of the
-                # driver worker. The model forward time will then end up covering
-                # the communication time as well.
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
                 output.model_forward_time = (orig_model_forward_time +
                                              model_forward_time)

         if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = \
-                orig_include_gpu_probs_tensor
-            if output.sampled_token_ids is not None:
-                output.sampled_token_embeds = self.model.get_input_embeddings(
-                    output.sampled_token_ids.squeeze(1))
-                for token_embed, sequence_group_output in zip(
-                        output.sampled_token_embeds, output.outputs):
-                    assert len(sequence_group_output.samples) == 1
-                    sequence_group_output.samples[0].output_embed = token_embed
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+
+                    output.sampled_token_embeds = sampled_token_embeds
+
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
+        if not self.is_driver_worker:
+            return []

         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
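
The core of the fix above is the driver-to-workers handoff: the driver rank samples, then `broadcast_tensor_dict` ships the sampled token ids to every tensor-parallel rank so each one can run the embedding lookup for the next step locally. The sketch below reproduces that pattern with plain `torch.distributed` purely as an illustration; it is not vLLM's helper, and it assumes a process group is already initialized with the driver at rank 0.

# Illustrative analogue of the broadcast pattern used above, written against
# plain torch.distributed (assumes dist.init_process_group has been called
# and rank 0 is the driver).
import torch
import torch.distributed as dist


def sync_sampled_token_ids(sampled_token_ids=None, src: int = 0):
    """Rank `src` supplies the sampled token ids; every rank returns them."""
    payload = [sampled_token_ids if dist.get_rank() == src else None]
    dist.broadcast_object_list(payload, src=src)  # pickles and broadcasts the object
    return payload[0]


# On every tensor-parallel rank, conceptually:
#   sampled = sync_sampled_token_ids(output.sampled_token_ids if is_driver else None)
#   if sampled is not None:
#       next_step_embeds = model.get_input_embeddings(sampled.squeeze(1))

In the committed code, vLLM's `broadcast_tensor_dict` plays the same role inside the tensor-parallel group: the driver passes the dict of sampled token ids and the other ranks call it with no arguments, as the last hunk shows.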