mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:25:01 +08:00
Fix vllm:prompt_tokens_total metric calculation (#2869)
This commit is contained in:
parent
86fd8bb0ac
commit
e433c115bc
@ -13,12 +13,10 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
|
||||
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
|
||||
|
||||
|
||||
def _read_prompts(filename: str) -> str:
|
||||
prompts = []
|
||||
def _read_prompts(filename: str) -> List[str]:
|
||||
with open(filename, "r") as f:
|
||||
prompt = f.readline()
|
||||
prompts.append(prompt)
|
||||
return prompts
|
||||
prompts = f.readlines()
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -165,6 +163,7 @@ class VllmRunner:
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
disable_log_stats: bool = True,
|
||||
tensor_parallel_size: int = 1,
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
@ -173,6 +172,7 @@ class VllmRunner:
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=0,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
|
||||
|
||||
33
tests/metrics/test_metrics.py
Normal file
33
tests/metrics/test_metrics.py
Normal file
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
import vllm.engine.metrics
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_metrics(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
|
||||
# This test needs at least 2 prompts in a batch of different lengths to verify their token count is correct despite padding.
|
||||
assert len(example_prompts) > 1, "at least 2 prompts are required"
|
||||
assert prompt_token_counts[0] != prompt_token_counts[1], (
|
||||
"prompts of different lengths are required")
|
||||
vllm_prompt_token_count = sum(prompt_token_counts)
|
||||
|
||||
_ = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
|
||||
|
||||
assert vllm_prompt_token_count == metric_count, (
|
||||
f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
|
||||
)
|
||||
@ -867,7 +867,9 @@ class LLMEngine:
|
||||
|
||||
# Number of Tokens.
|
||||
if prompt_run:
|
||||
num_prompt_tokens = scheduler_outputs.num_batched_tokens
|
||||
num_prompt_tokens = sum(
|
||||
len(seq_group.prompt_token_ids)
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups)
|
||||
else:
|
||||
num_generation_tokens = scheduler_outputs.num_batched_tokens
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user