[Bugfix][TPU][V1] Fix recompilation (#15553)
Signed-off-by: NickLucche <nlucches@redhat.com>
commit 4098b72210 (parent 46450b8d33)
@@ -32,7 +32,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_5 \
     && python3 /workspace/vllm/examples/offline_inference/tpu.py \
     && echo TEST_6 \
-    && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \
+    && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
+    && echo TEST_7 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
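The new TEST_7 stage simply runs the V1 TPU sampler test inside the same container. A rough local equivalent, invoking pytest programmatically with the path used in the script above (illustrative only, not part of the commit):

import pytest

# Run the new V1 TPU sampler test directly; the path matches the CI script above.
exit_code = pytest.main(["-s", "-v", "/workspace/vllm/tests/v1/tpu/test_sampler.py"])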
@@ -1,7 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import tempfile
-from time import time
-
 import pytest
 
 from vllm import LLM, envs
@@ -15,60 +12,6 @@ if not envs.VLLM_USE_V1:
     )
 
 
-@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
-@pytest.mark.skipif(not current_platform.is_tpu(),
-                    reason="This test needs a TPU")
-def test_sampler_compilation(model_name: str, monkeypatch):
-    """
-    Check that no recompilation happens despite changing sampling parameters.
-    We can't read XLA metrics from the engine process, hence we measure time.
-    """
-    with tempfile.TemporaryDirectory() as temp_dir:
-        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
-        # Compiling model init may still take some time, enforce_eager to skip.
-        llm = LLM(model_name,
-                  enforce_eager=True,
-                  max_num_seqs=16,
-                  max_model_len=1024,
-                  gpu_memory_utilization=0.5)
-        prompts = [
-            "A robot may not injure a human being",
-            "It is only with the heart that one can see rightly;",
-        ]
-        # First inference should be slow
-        sampling_params = SamplingParams(
-            temperature=0.7,
-            # top_p=0.6, # TODO too slow!
-            top_k=10,
-            min_p=0.2,
-            max_tokens=16)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run1 = time() - s
-
-        # Second request with different params, but for which we
-        # compiled for in previous eager iteration.
-        sampling_params = SamplingParams(temperature=0.1,
-                                         top_k=12,
-                                         min_p=0.8,
-                                         max_tokens=24)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run2 = time() - s
-        # Much faster after compiling
-        assert run1 * 0.1 > run2
-        print("TIMES", run1, run2)
-
-        # Third request with min_p set to "None". It will not trigger
-        # recompilation as a default 0 value will be used.
-        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run3 = time() - s
-        assert run1 * 0.1 > run3
-        print("TIMES", run1, run3)
-
-
 @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
 @pytest.mark.skipif(not current_platform.is_tpu(),
                     reason="This test needs a TPU")
@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
     Test significantly different sampling params to assert the model produces
     different results.
     """
-    llm = LLM(
-        model_name,
-        enforce_eager=True,
-        max_num_seqs=1,
-        max_model_len=64,
-        # TODO: setting to 0.5 or it will go OOM
-        gpu_memory_utilization=0.5)
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=512,
+              max_num_batched_tokens=512)
     prompts = [
         "Write a short story about a robot that dreams for the first time."
     ]
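For context, the deleted test_sampler_compilation asserted the absence of recompilation indirectly: XLA metrics are not readable from the engine process, so it compared wall-clock times of consecutive generate calls that reuse the same padded tensor shapes while changing sampling values. A condensed, slightly adapted sketch of that heuristic (model name, parameters, and the 0.1 threshold come from the diff above; the sketch itself is illustrative, not a definitive reproduction):

from time import perf_counter

from vllm import LLM, SamplingParams

# enforce_eager skips model compilation so the timing isolates the sampler graphs.
llm = LLM("Qwen/Qwen2.5-1.5B-Instruct",
          enforce_eager=True,
          max_num_seqs=16,
          max_model_len=1024)
prompts = ["A robot may not injure a human being"]

def timed_generate(params: SamplingParams) -> float:
    start = perf_counter()
    llm.generate(prompts, params)
    return perf_counter() - start

# First call pays the sampler compilation cost.
cold = timed_generate(
    SamplingParams(temperature=0.7, top_k=10, min_p=0.2, max_tokens=16))
# Different sampling values, identical padded shapes: should hit the cached graph.
warm = timed_generate(
    SamplingParams(temperature=0.1, top_k=12, min_p=0.8, max_tokens=24))
assert warm < cold * 0.1, "second request likely triggered a recompilation"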
@@ -88,6 +88,7 @@ class TPUSupportedSamplingMetadata:
             # Copy slice from CPU to corresponding TPU pre-allocated tensor.
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
             tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
 
         # NOTE NickLucche The sync CPU-TPU graph we produce here must be
@@ -101,13 +102,6 @@ class TPUSupportedSamplingMetadata:
         copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
                    DEFAULT_SAMPLING_PARAMS["min_p"])
 
-        # copy_slice(input_batch.frequency_penalties_cpu_tensor,
-        #            input_batch.frequency_penalties)
-        # copy_slice(input_batch.presence_penalties_cpu_tensor,
-        #            input_batch.presence_penalties)
-        # copy_slice(input_batch.repetition_penalties_cpu_tensor,
-        #            input_batch.repetition_penalties)
-
         xm.mark_step()
         xm.wait_device_ops()
 
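The comment added above captures the invariant that keeps these copies compilation-friendly: the CPU slice is always padded up to `padded_num_reqs` with the parameter's default value, and a fixed-length slice is written into a pre-allocated device tensor that is at least that long, so XLA never traces a shape that depends on the live request count. A standalone sketch of the pattern (hypothetical names, plain CPU tensors standing in for the TPU ones):

import torch

def copy_slice_padded(cpu_tensor: torch.Tensor, device_tensor: torch.Tensor,
                      num_reqs: int, padded_num_reqs: int,
                      default_val: float) -> None:
    # Slots past the live requests get the parameter's default value...
    cpu_tensor[num_reqs:padded_num_reqs] = default_val
    # ...and only a fixed-size, padded slice is copied; this requires
    # len(device_tensor) >= padded_num_reqs, as the added comment notes.
    device_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]

# 3 live requests padded up to a bucket of 8 (e.g. per-request min_p values).
min_p_cpu = torch.full((16,), 0.2)
min_p_dev = torch.zeros(16)  # stand-in for the pre-allocated TPU tensor
copy_slice_padded(min_p_cpu, min_p_dev, num_reqs=3, padded_num_reqs=8,
                  default_val=0.0)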
@@ -88,6 +88,8 @@ class TPUModelRunner:
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        # InputBatch needs to work with sampling tensors greater than padding
+        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
 
         # Model-related.
@@ -788,6 +790,7 @@ class TPUModelRunner:
         dummy_hidden = torch.randn((num_tokens, hsize),
                                    device=device,
                                    dtype=torch.bfloat16)
+        # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
         while True:
             indices = torch.zeros(
                 num_reqs_to_sample,
@@ -804,7 +807,9 @@ class TPUModelRunner:
             out = out.cpu()
             if num_reqs_to_sample >= self.max_num_reqs:
                 break
-            num_reqs_to_sample *= 2
+            # Make sure to compile the `max_num_reqs` upper-limit case
+            num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
+                num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
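The fix replaces the plain doubling of `num_reqs_to_sample` with a jump to the next padded bucket, so the warm-up loop is guaranteed to also compile the capped `max_num_reqs` shape instead of overshooting it. `_get_padded_num_reqs_with_upper_limit` itself is not shown in this diff; based on the `[8, 16, .., 128,.., self.max_num_reqs]` comment and the `MIN_NUM_SEQS` floor above, a plausible sketch is:

MIN_NUM_SEQS = 8  # assumed to match the smallest padding bucket

def padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    # Round up to the next power-of-two bucket, with MIN_NUM_SEQS as the floor,
    # then cap at the runner's max_num_reqs.
    bucket = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(bucket, upper_limit)

# With max_num_reqs = 96 the warm-up loop now visits 8, 16, 32, 64, 96.
# The old `num_reqs_to_sample *= 2` loop would have compiled 128 and skipped
# the capped 96-request shape actually used at runtime, causing a recompilation.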
@@ -897,7 +902,6 @@ class ModelWrapperV1(nn.Module):
 
         return hidden_states
 
-    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,