[torch.compile] rework test plans (#9866)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 37a4947dcd
commit 566cd27797
@@ -1,3 +1,4 @@
+import dataclasses
 from typing import Dict, List, Optional

 import pytest
@@ -8,33 +9,109 @@ from vllm.utils import cuda_device_count_stateless
 from ..utils import compare_all_settings


+@dataclasses.dataclass
+class TestSetting:
+    model: str
+    model_args: List[str]
+    pp_size: int
+    tp_size: int
+    attn_backend: str
+    method: str
+    fullgraph: bool
+
+
+# representative settings for testing
+test_settings = [
+    # basic llama model
+    TestSetting(
+        model="meta-llama/Llama-3.2-1B",
+        model_args=[],
+        pp_size=2,
+        tp_size=2,
+        attn_backend="FLASHINFER",
+        method="generate",
+        fullgraph=True,
+    ),
+    # llama model with quantization
+    TestSetting(
+        model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+        model_args=["--quantization", "gptq"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="FLASH_ATTN",
+        method="generate",
+        fullgraph=True,
+    ),
+    # MoE model
+    TestSetting(
+        model="ibm/PowerMoE-3b",
+        model_args=[],
+        pp_size=1,
+        tp_size=2,
+        attn_backend="FLASH_ATTN",
+        method="generate",
+        fullgraph=True,
+    ),
+    # embedding model
+    TestSetting(
+        model="BAAI/bge-multilingual-gemma2",
+        model_args=["--task", "embedding"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="FLASHINFER",
+        method="encode",
+        fullgraph=True,
+    ),
+    # vision language model
+    TestSetting(
+        model="microsoft/Phi-3.5-vision-instruct",
+        model_args=["--trust-remote-code", "--max-model-len", "2048"],
+        pp_size=2,
+        tp_size=1,
+        attn_backend="FLASH_ATTN",
+        method="generate_with_image",
+        fullgraph=False,
+    ),
+]
+
+
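The dataclass-based parametrization introduced above replaces the old positional tuples, so each scenario is self-describing and pytest still renders one test id per instance. A minimal, self-contained sketch of the idiom (the `Case` dataclass and `test_add` are hypothetical, not part of this commit):

    import dataclasses

    import pytest


    @dataclasses.dataclass
    class Case:
        a: int
        b: int
        expected: int


    cases = [Case(a=1, b=2, expected=3), Case(a=-1, b=1, expected=0)]


    @pytest.mark.parametrize("case", cases)
    def test_add(case: Case):
        # each parametrized test receives one fully-described scenario
        assert case.a + case.b == case.expected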
 # we cannot afford testing the full Cartesian product
 # of all models and all levels
-@pytest.mark.parametrize(
-    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
-    [
-        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
-        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
-         ["--quantization", "compressed-tensors"
-          ], 1, 1, "FLASH_ATTN", "generate", True),
-        ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
-        # TODO: add multi-modality test for llava
-        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
-    ])
-def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
-                             method, fullgraph):
+@pytest.mark.parametrize("test_setting", test_settings)
+def test_compile_correctness(test_setting: TestSetting):
     # this test is run under multiple suites, with different GPUs.
     # make sure we only run the test with correct CUDA devices.
     # don't use "<", as it will duplicate the tests.
+    model = test_setting.model
+    model_args = test_setting.model_args
+    pp_size = test_setting.pp_size
+    tp_size = test_setting.tp_size
+    attn_backend = test_setting.attn_backend
+    method = test_setting.method
+    fullgraph = test_setting.fullgraph
     if cuda_device_count_stateless() != pp_size * tp_size:
         pytest.skip("Not correct CUDA devices for the test.")
     import os
     os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
-    all_args = [["--enforce-eager"] + model_args + ["-pp", str(pp_size)] +
-                ["-tp", str(tp_size)]] * 3
-    # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
-    # inductor will change the output, so we cannot compare them.
+    final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
+        ["-tp", str(tp_size)]
     all_envs: List[Optional[Dict[str, str]]] = []
+
+    for level in [
+            CompilationLevel.NO_COMPILATION,
+            CompilationLevel.PIECEWISE,
+    ]:
+        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
+
+    # inductor will change the output, so we only compare if the output
+    # is close, not exactly the same.
+    compare_all_settings(
+        model, [final_args] * 2,
+        all_envs,
+        method=method if method != "generate" else "generate_close")
+    all_envs.clear()
+
     for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.DYNAMO_AS_IS,
@@ -46,4 +123,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
         all_envs[-1][
             "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"  # type: ignore

-    compare_all_settings(model, all_args, all_envs, method=method)
+    compare_all_settings(model, [final_args] * 3, all_envs, method=method)
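The reworked plan thus runs two passes: the PIECEWISE (inductor) sweep is compared to the uncompiled baseline only for closeness, while the DYNAMO_AS_IS sweep above keeps the exact comparison. A rough sketch of the sweep-and-compare pattern, with a stand-in `generate_under_env` instead of a real vLLM server launch (hypothetical, for illustration; levels shown as 0 and 3):

    import os
    from typing import Dict, List, Optional


    def generate_under_env(env: Optional[Dict[str, str]]) -> List[float]:
        # stand-in for launching a server under env overrides and
        # collecting per-token logprobs from it
        old = dict(os.environ)
        os.environ.update(env or {})
        try:
            return [0.1234, 0.5678]  # placeholder "model output"
        finally:
            os.environ.clear()
            os.environ.update(old)


    envs = [{"VLLM_TORCH_COMPILE_LEVEL": str(level)} for level in (0, 3)]
    results = [generate_under_env(e) for e in envs]

    # "close" comparison: round before checking equality, since inductor
    # may perturb numerics slightly
    ref = [round(x, 2) for x in results[0]]
    for other in results[1:]:
        assert [round(x, 2) for x in other] == ref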
tests/utils.py (124 changed lines)
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import functools
 import os
 import signal
@@ -8,13 +9,14 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union

 import openai
 import pytest
 import requests
+import torch
 from openai.types.completion import Completion
-from typing_extensions import ParamSpec, assert_never
+from typing_extensions import ParamSpec

 import vllm.envs as envs
 from tests.models.utils import TextTextLogprobs
@@ -272,6 +274,31 @@ def _test_completion(
     return results


+def _test_completion_close(
+    client: openai.OpenAI,
+    model: str,
+    prompt: str,
+):
+    results = []
+
+    # test with text prompt
+    completion = client.completions.create(model=model,
+                                           prompt=prompt,
+                                           max_tokens=1,
+                                           logprobs=5,
+                                           temperature=0.0)
+
+    logprobs = completion.choices[0].logprobs.top_logprobs[0]
+    logprobs = {k: round(v, 2) for k, v in logprobs.items()}
+
+    results.append({
+        "test": "completion_close",
+        "logprobs": logprobs,
+    })
+
+    return results
+
+
 def _test_embeddings(
     client: openai.OpenAI,
     model: str,
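`_test_completion_close` rounds each top logprob to two decimals before storing it, so the later dict-equality check tolerates the small numeric drift inductor introduces. A tiny self-contained sketch of that comparison, with hypothetical logprob dicts:

    def round2(d: dict) -> dict:
        # round values to 2 decimals so tiny backend drift compares equal
        return {k: round(v, 2) for k, v in d.items()}


    ref = {"Paris": -0.0123, "London": -4.5601}  # hypothetical values
    out = {"Paris": -0.0125, "London": -4.5598}  # slightly perturbed

    assert round2(ref) == round2(out)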
@@ -295,13 +322,81 @@ def _test_embeddings(
     return results


+def _test_image_text(
+    client: openai.OpenAI,
+    model_name: str,
+    image_url: str,
+):
+    results = []
+
+    # test pure text input
+    messages = [{
+        "role": "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "How do you feel today?"
+            },
+        ],
+    }]
+
+    chat_completion = client.chat.completions.create(model=model_name,
+                                                     messages=messages,
+                                                     temperature=0.0,
+                                                     max_tokens=1,
+                                                     logprobs=True,
+                                                     top_logprobs=5)
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
+
+    for x in top_logprobs:
+        x.logprob = round(x.logprob, 2)
+
+    results.append({
+        "test": "pure_text",
+        "logprobs": top_logprobs,
+    })
+
+    messages = [{
+        "role": "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    chat_completion = client.chat.completions.create(model=model_name,
+                                                     messages=messages,
+                                                     temperature=0.0,
+                                                     max_tokens=1,
+                                                     logprobs=True,
+                                                     top_logprobs=5)
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
+
+    results.append({
+        "test": "text_image",
+        "logprobs": top_logprobs,
+    })
+
+    return results
+
+
 def compare_two_settings(model: str,
                          arg1: List[str],
                          arg2: List[str],
                          env1: Optional[Dict[str, str]] = None,
                          env2: Optional[Dict[str, str]] = None,
                          *,
-                         method: Literal["generate", "encode"] = "generate",
+                         method: str = "generate",
                          max_wait_seconds: Optional[float] = None) -> None:
     """
     Launch API server with two different sets of arguments/environments
@@ -328,7 +423,7 @@ def compare_all_settings(model: str,
                          all_args: List[List[str]],
                          all_envs: List[Optional[Dict[str, str]]],
                          *,
-                         method: Literal["generate", "encode"] = "generate",
+                         method: str = "generate",
                          max_wait_seconds: Optional[float] = None) -> None:
     """
     Launch API server with several different sets of arguments/environments
@@ -397,10 +492,17 @@ def compare_all_settings(model: str,

             if method == "generate":
                 results += _test_completion(client, model, prompt, token_ids)
+            elif method == "generate_close":
+                results += _test_completion_close(client, model, prompt)
+            elif method == "generate_with_image":
+                results += _test_image_text(
+                    client, model,
+                    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
+                )
             elif method == "encode":
                 results += _test_embeddings(client, model, prompt)
             else:
-                assert_never(method)
+                raise ValueError(f"Unknown method: {method}")

             if i > 0:
                 # if any setting fails, raise an error early
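With `method` widened from a `Literal` to a plain `str`, callers select the new comparison modes by name. A hedged sketch of the call shape (model, args, and envs are illustrative placeholders, not the exact values used in CI):

    from tests.utils import compare_all_settings

    compare_all_settings(
        "meta-llama/Llama-3.2-1B",
        [["--enforce-eager"]] * 2,           # one arg list per server launch
        [{"VLLM_TORCH_COMPILE_LEVEL": "0"},  # one env dict per server launch
         {"VLLM_TORCH_COMPILE_LEVEL": "3"}],
        method="generate_close",             # rounded-logprob comparison
    )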
@@ -410,6 +512,18 @@ def compare_all_settings(model: str,
             compare_envs = all_envs[i]
             for ref_result, compare_result in zip(ref_results,
                                                   compare_results):
+                ref_result = copy.deepcopy(ref_result)
+                compare_result = copy.deepcopy(compare_result)
+                if "embedding" in ref_result and method == "encode":
+                    ref_embedding = torch.tensor(ref_result["embedding"])
+                    compare_embedding = torch.tensor(
+                        compare_result["embedding"])
+                    mse = ((ref_embedding - compare_embedding)**2).mean()
+                    assert mse < 1e-6, (
+                        f"Embeddings for {model=} are not the same.\n"
+                        f"mse={mse}\n")
+                    del ref_result["embedding"]
+                    del compare_result["embedding"]
                 assert ref_result == compare_result, (
                     f"Results for {model=} are not the same.\n"
                     f"{ref_args=} {ref_envs=}\n"
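For embeddings, exact equality across compilation levels is too strict, so the new branch above compares mean squared error against a 1e-6 tolerance and only then checks the remaining fields exactly. A self-contained sketch with made-up vectors:

    import torch

    ref = torch.tensor([0.1000, 0.2000, 0.3000])  # hypothetical embedding
    out = torch.tensor([0.1001, 0.1999, 0.3000])  # tiny numeric drift

    mse = ((ref - out)**2).mean()
    assert mse < 1e-6, f"embeddings diverged: mse={mse}"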
vllm/model_executor/models/llava.py

@@ -493,13 +493,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         :class:`LlavaImageInputs`
         """
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
             image_input = self._parse_and_validate_image_input(**kwargs)

             if image_input is not None:
                 vision_embeddings = self._process_image_input(image_input)
                 inputs_embeds = self.language_model.model.get_input_embeddings(
@@ -511,7 +507,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             else:
                 inputs_embeds = self.language_model.model.get_input_embeddings(
                     input_ids)
-                input_ids = None
+
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            # for `torch.compile` integration
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
vllm/model_executor/models/phi3v.py

@@ -679,7 +679,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object):
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -690,9 +689,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, vision_embeddings,
                     self.image_token_id)
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.language_model.model.embed_tokens(
+                    input_ids)
+
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            # for `torch.compile` integration
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
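Both model changes serve the same goal: every path (profiling run, text-only, multimodal) now feeds the language model through `inputs_embeds` with `input_ids=None`, so `torch.compile` traces one consistent graph instead of specializing on which input is present. A toy sketch of the idiom (the `ToyLM` module is illustrative, not vLLM's code):

    import torch
    import torch.nn as nn


    class ToyLM(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.embed_tokens = nn.Embedding(100, 16)
            self.layer = nn.Linear(16, 16)

        def forward(self, input_ids, inputs_embeds):
            # after the rework this branch is dead for compiled callers:
            # they always pass inputs_embeds and input_ids=None
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            return self.layer(inputs_embeds)


    lm = ToyLM()
    compiled = torch.compile(lm)

    input_ids = torch.tensor([3, 14, 15])
    # caller-side normalization, mirroring the commit: embed first, then
    # hand the compiled graph a single consistent input type
    inputs_embeds = lm.embed_tokens(input_ids)
    out = compiled(None, inputs_embeds)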