[Core] [Bugfix]: tensor parallel with prompt embeds (#18171)

Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>

commit 9609327fa4 (parent f07a673eb2)
tests/basic_correctness/test_basic_correctness.py
@@ -8,12 +8,13 @@ import weakref
 from unittest.mock import Mock
 
 import pytest
+import torch
 
-from vllm import LLM
+from vllm import LLM, envs
 from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
 
-from ..conftest import VllmRunner
+from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
+def _fix_prompt_embed_outputs(
+        vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
+        example_prompts: list[str]) -> list[tuple[list[int], str]]:
+    fixed_vllm_outputs = []
+    for vllm_output, hf_input, prompt in zip(
+            vllm_outputs, hf_model.get_inputs(example_prompts),
+            example_prompts):
+        hf_input_ids = hf_input["input_ids"].tolist()[0]
+        fixed_vllm_outputs.append(
+            (hf_input_ids + vllm_output[0][len(hf_input_ids):],
+             prompt + vllm_output[1]))
+    return fixed_vllm_outputs
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
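A note on the helper added above: when generation is driven by prompt embeddings, vLLM cannot recover the prompt's token ids or text from the embeddings, so the test splices the HuggingFace tokenizer's ids and the original prompt string back into each output before comparison. A minimal sketch with made-up values (all literals below are hypothetical, not taken from the diff):

# Toy illustration of what _fix_prompt_embed_outputs reconstructs.
hf_input_ids = [101, 2023, 2003]                  # prompt ids from the HF tokenizer
vllm_output = ([0, 0, 0, 7592, 2088], " world")   # prompt-length prefix is replaced below
prompt = "hello"

fixed = (hf_input_ids + vllm_output[0][len(hf_input_ids):],
         prompt + vllm_output[1])
# fixed == ([101, 2023, 2003, 7592, 2088], "hello world")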
@@ -56,8 +72,13 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
+    enable_prompt_embeds: bool,
 ) -> None:
 
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
@@ -78,14 +99,25 @@ def test_models(
 
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+        if enable_prompt_embeds:
+            with torch.no_grad():
+                prompt_embeds = hf_model.get_prompt_embeddings(
+                    example_prompts)
 
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
                     enforce_eager=enforce_eager,
+                    enable_prompt_embeds=enable_prompt_embeds,
                     gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                  max_tokens)
+        if enable_prompt_embeds:
+            vllm_outputs = vllm_model.generate_greedy(
+                prompt_embeds, max_tokens)
+            vllm_outputs = _fix_prompt_embed_outputs(
+                vllm_outputs, hf_model, example_prompts)
+        else:
+            vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
 
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@ def test_models(
         ("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
         ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
     ])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
 def test_models_distributed(
     monkeypatch: pytest.MonkeyPatch,
     hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
     distributed_executor_backend: str,
     attention_backend: str,
     test_suite: str,
+    enable_prompt_embeds: bool,
 ) -> None:
 
+    if enable_prompt_embeds and envs.is_set(
+            "VLLM_USE_V1") and envs.VLLM_USE_V1:
+        pytest.skip("enable_prompt_embeds is not supported in v1.")
+
     if test_suite != TARGET_TEST_SUITE:
         pytest.skip(f"Skip test for {test_suite}")
 
     with monkeypatch.context() as monkeypatch_context:
         if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4":  # noqa
-            # test Ray Compiled Graph
+            if enable_prompt_embeds:
+                pytest.skip(
+                    "enable_prompt_embeds does not work with ray compiled dag."
+                )
             monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
 
@@ -147,12 +188,26 @@ def test_models_distributed(
                 dtype=dtype,
                 tensor_parallel_size=2,
                 distributed_executor_backend=distributed_executor_backend,
+                enable_prompt_embeds=enable_prompt_embeds,
+                gpu_memory_utilization=0.7,
         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts,
-                                                      max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+            if enable_prompt_embeds:
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    with torch.no_grad():
+                        prompt_embeds = hf_model.get_prompt_embeddings(
+                            example_prompts)
+                    vllm_outputs = vllm_model.generate_greedy(
+                        prompt_embeds, max_tokens)
+                    vllm_outputs = _fix_prompt_embed_outputs(
+                        vllm_outputs, hf_model, example_prompts)
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
+            else:
+                vllm_outputs = vllm_model.generate_greedy(
+                    example_prompts, max_tokens)
+                with hf_runner(model, dtype=dtype) as hf_model:
+                    hf_outputs = hf_model.generate_greedy(
+                        example_prompts, max_tokens)
 
         check_outputs_equal(
             outputs_0_lst=hf_outputs,
tests/conftest.py
@@ -430,6 +430,15 @@ class HfRunner:
 
         return all_inputs
 
+    def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
+        all_inputs = self.get_inputs(prompts)
+        embeddings = []
+        for inputs in all_inputs:
+            input_ids = self.wrap_device(inputs)["input_ids"]
+            embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
+            embeddings.append(embedding)
+        return embeddings
+
     def classify(self, prompts: list[str]) -> list[str]:
         # output is final logits
         all_inputs = self.get_inputs(prompts)
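For readers outside the test harness, the new HfRunner.get_prompt_embeddings boils down to running the prompt ids through the model's input-embedding layer. A rough standalone equivalent with plain transformers (the device handling and variable names here are illustrative assumptions, not part of the diff; the model name is one of the models already used by the tests):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    # Same idea as HfRunner.get_prompt_embeddings: look up the token
    # embeddings directly, yielding a [seq_len, hidden_size] tensor.
    prompt_embeds = model.get_input_embeddings()(input_ids).squeeze(0)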
vllm/sequence.py
@@ -112,12 +112,12 @@ class RequestMetrics:
                             will include model forward, block/sync across
                             workers, cpu-gpu sync time and sampling time.
         spec_token_acceptance_counts: number of accepted speculative tokens at
                                       each position; the first token is from
                                       the target model and is always accepted;
                                       e.g., when it's [10, 8, 4, 2] for a req,
                                       it means there were 10 forward passes in
                                       total, and there were 8, 4, 2 accepted
                                       tokens at 1st, 2nd, 3rd speculation step.
     """
     arrival_time: float
     last_token_time: float
@@ -714,9 +714,9 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
         draft_size: The number of speculative tokens plus one from the target
                     model; equal to max number of tokens a step can generate
                     for single-draft speculative decoding but larger than
                     that for multi-draft SD (currently not supported).
     """
 
@@ -1123,7 +1123,7 @@ class SequenceOutput(
             self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
-                f"output_embed.shape={output_embed_shape}"
+                f"output_embed.shape={output_embed_shape}, "
                 f"logprobs={self.logprobs})")
 
     def __eq__(self, other: object) -> bool:
vllm/worker/model_runner.py
@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.distributed import get_pp_group
+from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.distributed.kv_transfer import get_kv_transfer_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              graph_capture)
@@ -872,7 +872,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         """
         # Combine and flatten intermediate data.
         input_tokens = list[int]()
-        inputs_embeds_lst = list[torch.Tensor]()
+        inputs_embeds_list = list[torch.Tensor]()
         token_types = list[int]()
         for inter_data in self.inter_data_list:
             for cur_input_tokens in inter_data.input_tokens:
@@ -880,15 +880,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             for cur_token_types in inter_data.token_types:
                 token_types.extend(cur_token_types)
             if inter_data.inputs_embeds is not None:
-                inputs_embeds_lst.append(
+                inputs_embeds_list.append(
                     inter_data.inputs_embeds.to(
                         dtype=self.runner.model_config.dtype,
                         device=self.runner.device))
         inputs_embeds: Optional[torch.Tensor]
-        if len(inputs_embeds_lst) == 0:
+        if len(inputs_embeds_list) == 0:
             inputs_embeds = None
         else:
-            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
+            inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
                 dtype=self.runner.model_config.dtype,
                 device=self.runner.device)
         assert len(inputs_embeds) == len(input_tokens)
@@ -1893,51 +1893,61 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)
 
+        if self.is_driver_worker:
+            if model_input.async_callback is not None:
+                model_input.async_callback()
+
+            # Sample the next token.
+            assert isinstance(self.sampler, Sampler)
+            orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
+            if model_input.inputs_embeds is not None:
+                self.sampler.include_gpu_probs_tensor = True
+
+            output: SamplerOutput = self.sampler(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time
+                    and output is not None):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                # If there are multiple workers, we are still tracking the
+                # latency from the start time of the driver worker to the end
+                # time of the driver worker. The model forward time will then
+                # end up covering the communication time as well.
+                output.model_forward_time = (orig_model_forward_time +
+                                             model_forward_time)
+
+        if model_input.inputs_embeds is not None:
+            if self.is_driver_worker:
+                sampled = broadcast_tensor_dict(
+                    {"token_ids": output.sampled_token_ids})
+            else:
+                sampled = broadcast_tensor_dict()
+            if sampled["token_ids"] is not None:
+                sampled_token_embeds = self.model.get_input_embeddings(
+                    sampled["token_ids"].squeeze(1))
+                if self.is_driver_worker:
+                    self.sampler.include_gpu_probs_tensor = \
+                        orig_include_gpu_probs
+
+                    output.sampled_token_embeds = sampled_token_embeds
+
+                    for token_embed, sequence_group_output in zip(
+                            output.sampled_token_embeds, output.outputs):
+                        assert len(sequence_group_output.samples) == 1
+                        sequence_group_output.samples[
+                            0].output_embed = token_embed
+
         if not self.is_driver_worker:
             return []
 
-        if model_input.async_callback is not None:
-            model_input.async_callback()
-
-        # Sample the next token.
-        assert isinstance(self.sampler, Sampler)
-        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = True
-
-        output: SamplerOutput = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_end.synchronize()
-            model_forward_time = model_forward_start.elapsed_time(
-                model_forward_end)
-            orig_model_forward_time = 0.0
-            if intermediate_tensors is not None:
-                orig_model_forward_time = intermediate_tensors.tensors.get(
-                    "model_forward_time", torch.tensor(0.0)).item()
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = (orig_model_forward_time +
-                                         model_forward_time)
-
-        if model_input.inputs_embeds is not None:
-            self.sampler.include_gpu_probs_tensor = \
-                orig_include_gpu_probs_tensor
-            if output.sampled_token_ids is not None:
-                output.sampled_token_embeds = self.model.get_input_embeddings(
-                    output.sampled_token_ids.squeeze(1))
-
-                for token_embed, sequence_group_output in zip(
-                        output.sampled_token_embeds, output.outputs):
-                    assert len(sequence_group_output.samples) == 1
-                    sequence_group_output.samples[0].output_embed = token_embed
-
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
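The last hunk is the heart of the fix: sampling still happens only on the driver worker, but when a request runs on prompt embeddings, every tensor-parallel rank needs the embeddings of the freshly sampled tokens so the next decode step can run with inputs_embeds everywhere. The driver therefore broadcasts the sampled token ids with vLLM's broadcast_tensor_dict helper and each rank embeds them locally. The sketch below shows the same pattern with plain torch.distributed; the function name, the num_seqs argument, and the assumption of an already-initialized process group are illustrative, not vLLM APIs:

import torch
import torch.distributed as dist


def share_sampled_token_embeds(embed_layer: torch.nn.Module,
                               num_seqs: int,
                               device: torch.device,
                               sampled_token_ids: torch.Tensor = None):
    """Driver (rank 0) holds the sampled ids; every rank returns the embeds."""
    token_ids = torch.empty(num_seqs, dtype=torch.long, device=device)
    if dist.get_rank() == 0:
        # Only the driver sampled, so only it has real token ids.
        token_ids.copy_(sampled_token_ids.view(-1))
    # All ranks receive the same ids ...
    dist.broadcast(token_ids, src=0)
    # ... and embed them locally, so the next step can consume
    # inputs_embeds on every tensor-parallel rank.
    with torch.no_grad():
        return embed_layer(token_ids)

In the actual change, broadcast_tensor_dict also carries the tensor metadata, which is why the non-driver ranks in the hunk above can call it with no arguments.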