diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index abf14a8fb6250..583a88d8e6ec6 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -4,6 +4,7 @@ import random import pytest +import torch from vllm.attention import Attention from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, @@ -277,6 +278,54 @@ def test_update_states_request_resumed(model_runner): assert _is_req_state_block_table_match(model_runner, req_id) +def test_get_nans_in_logits(model_runner): + req_ids = ("req_0", "req_1") + + scheduler_output = _schedule_new_request(*req_ids) + model_runner._update_states(scheduler_output) + + logits = torch.tensor([ + [1.0, 2.0, 3.0], + [3.0, 2.0, 1.0], + ], device=DEVICE) + result = model_runner._get_nans_in_logits(logits) + assert result == {"req_0": 0, "req_1": 0} + + logits = torch.tensor([ + [1.0, float('nan'), 3.0], + [4.0, float('nan'), float('nan')], + ], + device=DEVICE) + result = model_runner._get_nans_in_logits(logits) + assert result == {"req_0": 1, "req_1": 2} + + logits = torch.tensor([ + [1.0, 2.0, 3.0], + [4.0, float('nan'), float('nan')], + ], + device=DEVICE) + result = model_runner._get_nans_in_logits(logits) + assert result == {"req_0": 0, "req_1": 2} + + result = model_runner._get_nans_in_logits(logits=None) + assert result == {"req_0": 0, "req_1": 0} + + logits = torch.tensor([ + [1.0, float('nan'), 3.0], + ], device=DEVICE) + result = model_runner._get_nans_in_logits(logits) + assert result == {'req_0': 1, 'req_1': 0} + + logits = torch.tensor([ + [float('nan'), float('nan'), 2.0], + [1.0, 2.0, 3.0], + [float('nan'), 2.0, 3.0], + ], + device=DEVICE) + result = model_runner._get_nans_in_logits(logits) + assert result == {'req_0': 2, 'req_1': 0} + + def test_update_states_no_changes(model_runner): req_id = "req_0" diff --git a/vllm/envs.py b/vllm/envs.py index c7604d6dfeb84..b1030997f25ad 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -130,6 +130,7 @@ if TYPE_CHECKING: VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_KV_CACHE_LAYOUT: Optional[str] = None + VLLM_COMPUTE_NANS_IN_LOGITS: bool = False def get_default_cache_root(): @@ -897,7 +898,13 @@ environment_variables: dict[str, Callable[[], Any]] = { # leave the layout choice to the backend. Mind that backends may only # implement and support a subset of all possible layouts. "VLLM_KV_CACHE_LAYOUT": - lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None) + lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None), + + # Enable checking whether the generated logits contain NaNs, + # indicating corrupted output. Useful for debugging low level bugs + # or bad hardware but it may add compute overhead. + "VLLM_COMPUTE_NANS_IN_LOGITS": + lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 16e76defdf721..0958366e0aca7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -717,6 +717,7 @@ class Scheduler(SchedulerInterface): prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) @@ -810,6 +811,10 @@ class Scheduler(SchedulerInterface): request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] req_id, new_token_ids) + # spec_token_ids comes from the model runner output + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + # Add newly generated spec token ids to the request. if spec_token_ids is not None: if self.structured_output_manager.should_advance(request): @@ -972,6 +977,8 @@ class Scheduler(SchedulerInterface): kv_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted + for req in self.running), ) def make_spec_decoding_stats( diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 716f40fffb282..1eb10ccb6c493 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -40,6 +40,8 @@ class SchedulerStats: spec_decoding_stats: Optional[SpecDecodingStats] = None + num_corrupted_reqs: int = 0 + @dataclass class LoRAStats: diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2234843293cc6..f78623f571b2d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -108,6 +108,9 @@ class ModelRunnerOutput: finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None + # req_id -> num_nans_in_logits + num_nans_in_logits: Optional[dict[str, int]] = None + EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}, @@ -117,4 +120,5 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], prompt_logprobs_dict={}, pooler_output=[], finished_sending=None, - finished_recving=None) + finished_recving=None, + num_nans_in_logits=None) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index e3f3a418755c3..4632884419aed 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -97,6 +97,10 @@ class Request: # The number of tokens with prefix cache hits. self.num_cached_tokens = -1 + # The number of NaNs in logits. A value greater than 0 + # indicates that the output is corrupted + self.num_nans_in_logits = 0 + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": if request.mm_inputs is not None: @@ -132,6 +136,10 @@ class Request: self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) + @property + def is_output_corrupted(self) -> bool: + return self.num_nans_in_logits > 0 + @property def num_tokens(self) -> int: return len(self._all_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f96fb64342c9f..3303660061183 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1431,6 +1431,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) sampler_output.sampled_token_ids = output_token_ids + num_nans_in_logits = {} + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits(logits) + # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. discard_sampled_tokens_req_indices = [] @@ -1601,6 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): pooler_output=[], finished_sending=finished_sending, finished_recving=finished_recving, + num_nans_in_logits=num_nans_in_logits, ) def kv_connector_no_forward( @@ -1826,6 +1831,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): return prompt_logprobs_dict + def _get_nans_in_logits( + self, + logits: Optional[torch.Tensor], + ) -> dict[str, int]: + try: + if logits is None: + return {req_id: 0 for req_id in self.input_batch.req_ids} + + num_nans_in_logits = {} + num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() + for req_id in self.input_batch.req_ids: + req_index = self.input_batch.req_id_to_index[req_id] + num_nans_in_logits[req_id] = ( + int(num_nans_for_index[req_index]) + if num_nans_for_index is not None + and req_index < logits.shape[0] else 0) + return num_nans_in_logits + except IndexError: + return {} + @contextmanager def maybe_randomize_inputs(self, input_ids: torch.Tensor): """