mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 10:31:18 +08:00
Export NaNs in logits to scheduler_stats if output is corrupted (#18777)
Signed-off-by: Vlad Mihailescu <vtmihailescu@gmail.com>
This commit is contained in:
parent
7e8977fcd4
commit
2e3e3c86dc
@ -4,6 +4,7 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_nans_in_logits(model_runner):
|
||||||
|
req_ids = ("req_0", "req_1")
|
||||||
|
|
||||||
|
scheduler_output = _schedule_new_request(*req_ids)
|
||||||
|
model_runner._update_states(scheduler_output)
|
||||||
|
|
||||||
|
logits = torch.tensor([
|
||||||
|
[1.0, 2.0, 3.0],
|
||||||
|
[3.0, 2.0, 1.0],
|
||||||
|
], device=DEVICE)
|
||||||
|
result = model_runner._get_nans_in_logits(logits)
|
||||||
|
assert result == {"req_0": 0, "req_1": 0}
|
||||||
|
|
||||||
|
logits = torch.tensor([
|
||||||
|
[1.0, float('nan'), 3.0],
|
||||||
|
[4.0, float('nan'), float('nan')],
|
||||||
|
],
|
||||||
|
device=DEVICE)
|
||||||
|
result = model_runner._get_nans_in_logits(logits)
|
||||||
|
assert result == {"req_0": 1, "req_1": 2}
|
||||||
|
|
||||||
|
logits = torch.tensor([
|
||||||
|
[1.0, 2.0, 3.0],
|
||||||
|
[4.0, float('nan'), float('nan')],
|
||||||
|
],
|
||||||
|
device=DEVICE)
|
||||||
|
result = model_runner._get_nans_in_logits(logits)
|
||||||
|
assert result == {"req_0": 0, "req_1": 2}
|
||||||
|
|
||||||
|
result = model_runner._get_nans_in_logits(logits=None)
|
||||||
|
assert result == {"req_0": 0, "req_1": 0}
|
||||||
|
|
||||||
|
logits = torch.tensor([
|
||||||
|
[1.0, float('nan'), 3.0],
|
||||||
|
], device=DEVICE)
|
||||||
|
result = model_runner._get_nans_in_logits(logits)
|
||||||
|
assert result == {'req_0': 1, 'req_1': 0}
|
||||||
|
|
||||||
|
logits = torch.tensor([
|
||||||
|
[float('nan'), float('nan'), 2.0],
|
||||||
|
[1.0, 2.0, 3.0],
|
||||||
|
[float('nan'), 2.0, 3.0],
|
||||||
|
],
|
||||||
|
device=DEVICE)
|
||||||
|
result = model_runner._get_nans_in_logits(logits)
|
||||||
|
assert result == {'req_0': 2, 'req_1': 0}
|
||||||
|
|
||||||
|
|
||||||
def test_update_states_no_changes(model_runner):
|
def test_update_states_no_changes(model_runner):
|
||||||
req_id = "req_0"
|
req_id = "req_0"
|
||||||
|
|
||||||
|
|||||||
@ -130,6 +130,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_SLEEP_WHEN_IDLE: bool = False
|
VLLM_SLEEP_WHEN_IDLE: bool = False
|
||||||
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
|
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
|
||||||
VLLM_KV_CACHE_LAYOUT: Optional[str] = None
|
VLLM_KV_CACHE_LAYOUT: Optional[str] = None
|
||||||
|
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -897,7 +898,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# leave the layout choice to the backend. Mind that backends may only
|
# leave the layout choice to the backend. Mind that backends may only
|
||||||
# implement and support a subset of all possible layouts.
|
# implement and support a subset of all possible layouts.
|
||||||
"VLLM_KV_CACHE_LAYOUT":
|
"VLLM_KV_CACHE_LAYOUT":
|
||||||
lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
|
lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None),
|
||||||
|
|
||||||
|
# Enable checking whether the generated logits contain NaNs,
|
||||||
|
# indicating corrupted output. Useful for debugging low level bugs
|
||||||
|
# or bad hardware but it may add compute overhead.
|
||||||
|
"VLLM_COMPUTE_NANS_IN_LOGITS":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
|
|||||||
@ -717,6 +717,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
pooler_outputs = model_runner_output.pooler_output
|
pooler_outputs = model_runner_output.pooler_output
|
||||||
|
num_nans_in_logits = model_runner_output.num_nans_in_logits
|
||||||
|
|
||||||
new_running: list[Request] = []
|
new_running: list[Request] = []
|
||||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||||
@ -810,6 +811,10 @@ class Scheduler(SchedulerInterface):
|
|||||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||||
req_id, new_token_ids)
|
req_id, new_token_ids)
|
||||||
|
|
||||||
|
# spec_token_ids comes from the model runner output
|
||||||
|
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||||
|
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||||
|
|
||||||
# Add newly generated spec token ids to the request.
|
# Add newly generated spec token ids to the request.
|
||||||
if spec_token_ids is not None:
|
if spec_token_ids is not None:
|
||||||
if self.structured_output_manager.should_advance(request):
|
if self.structured_output_manager.should_advance(request):
|
||||||
@ -972,6 +977,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
kv_cache_usage=self.kv_cache_manager.usage,
|
kv_cache_usage=self.kv_cache_manager.usage,
|
||||||
prefix_cache_stats=prefix_cache_stats,
|
prefix_cache_stats=prefix_cache_stats,
|
||||||
spec_decoding_stats=spec_decoding_stats,
|
spec_decoding_stats=spec_decoding_stats,
|
||||||
|
num_corrupted_reqs=sum(req.is_output_corrupted
|
||||||
|
for req in self.running),
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_spec_decoding_stats(
|
def make_spec_decoding_stats(
|
||||||
|
|||||||
@ -40,6 +40,8 @@ class SchedulerStats:
|
|||||||
|
|
||||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
||||||
|
|
||||||
|
num_corrupted_reqs: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoRAStats:
|
class LoRAStats:
|
||||||
|
|||||||
@ -108,6 +108,9 @@ class ModelRunnerOutput:
|
|||||||
finished_sending: Optional[set[str]] = None
|
finished_sending: Optional[set[str]] = None
|
||||||
finished_recving: Optional[set[str]] = None
|
finished_recving: Optional[set[str]] = None
|
||||||
|
|
||||||
|
# req_id -> num_nans_in_logits
|
||||||
|
num_nans_in_logits: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
|
|
||||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
||||||
req_id_to_index={},
|
req_id_to_index={},
|
||||||
@ -117,4 +120,5 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
|||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
finished_sending=None,
|
finished_sending=None,
|
||||||
finished_recving=None)
|
finished_recving=None,
|
||||||
|
num_nans_in_logits=None)
|
||||||
|
|||||||
@ -97,6 +97,10 @@ class Request:
|
|||||||
# The number of tokens with prefix cache hits.
|
# The number of tokens with prefix cache hits.
|
||||||
self.num_cached_tokens = -1
|
self.num_cached_tokens = -1
|
||||||
|
|
||||||
|
# The number of NaNs in logits. A value greater than 0
|
||||||
|
# indicates that the output is corrupted
|
||||||
|
self.num_nans_in_logits = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||||
if request.mm_inputs is not None:
|
if request.mm_inputs is not None:
|
||||||
@ -132,6 +136,10 @@ class Request:
|
|||||||
self._output_token_ids.extend(token_ids)
|
self._output_token_ids.extend(token_ids)
|
||||||
self._all_token_ids.extend(token_ids)
|
self._all_token_ids.extend(token_ids)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_output_corrupted(self) -> bool:
|
||||||
|
return self.num_nans_in_logits > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_tokens(self) -> int:
|
def num_tokens(self) -> int:
|
||||||
return len(self._all_token_ids)
|
return len(self._all_token_ids)
|
||||||
|
|||||||
@ -1431,6 +1431,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
sampler_output.sampled_token_ids = output_token_ids
|
sampler_output.sampled_token_ids = output_token_ids
|
||||||
|
|
||||||
|
num_nans_in_logits = {}
|
||||||
|
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
|
||||||
|
num_nans_in_logits = self._get_nans_in_logits(logits)
|
||||||
|
|
||||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||||
# the requests one by one. Optimize.
|
# the requests one by one. Optimize.
|
||||||
discard_sampled_tokens_req_indices = []
|
discard_sampled_tokens_req_indices = []
|
||||||
@ -1601,6 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
finished_sending=finished_sending,
|
finished_sending=finished_sending,
|
||||||
finished_recving=finished_recving,
|
finished_recving=finished_recving,
|
||||||
|
num_nans_in_logits=num_nans_in_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
def kv_connector_no_forward(
|
def kv_connector_no_forward(
|
||||||
@ -1826,6 +1831,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
return prompt_logprobs_dict
|
return prompt_logprobs_dict
|
||||||
|
|
||||||
|
def _get_nans_in_logits(
|
||||||
|
self,
|
||||||
|
logits: Optional[torch.Tensor],
|
||||||
|
) -> dict[str, int]:
|
||||||
|
try:
|
||||||
|
if logits is None:
|
||||||
|
return {req_id: 0 for req_id in self.input_batch.req_ids}
|
||||||
|
|
||||||
|
num_nans_in_logits = {}
|
||||||
|
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
|
||||||
|
for req_id in self.input_batch.req_ids:
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
num_nans_in_logits[req_id] = (
|
||||||
|
int(num_nans_for_index[req_index])
|
||||||
|
if num_nans_for_index is not None
|
||||||
|
and req_index < logits.shape[0] else 0)
|
||||||
|
return num_nans_in_logits
|
||||||
|
except IndexError:
|
||||||
|
return {}
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
|
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user