From d471b2aff09028f9c62e861f760a74fd8f99081d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 9 Dec 2025 10:00:49 -0800 Subject: [PATCH] [Model Runner V2] Support num NaNs in logits (#30187) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/async_utils.py | 41 +++++++++++++------------ vllm/v1/worker/gpu/metrics/__init__.py | 0 vllm/v1/worker/gpu/metrics/logits.py | 42 ++++++++++++++++++++++++++ vllm/v1/worker/gpu/model_runner.py | 2 +- vllm/v1/worker/gpu/sample/min_p.py | 4 +-- vllm/v1/worker/gpu/sample/output.py | 14 +++++++++ vllm/v1/worker/gpu/sample/sampler.py | 12 ++++++-- 7 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 vllm/v1/worker/gpu/metrics/__init__.py create mode 100644 vllm/v1/worker/gpu/metrics/logits.py create mode 100644 vllm/v1/worker/gpu/sample/output.py diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py index f6bc607c1ae67..a2e3decad0486 100644 --- a/vllm/v1/worker/gpu/async_utils.py +++ b/vllm/v1/worker/gpu/async_utils.py @@ -2,14 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager +import numpy as np import torch from vllm.v1.outputs import ( AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput, - SamplerOutput, ) +from vllm.v1.worker.gpu.sample.output import SamplerOutput class AsyncOutput(AsyncModelRunnerOutput): @@ -34,29 +35,18 @@ class AsyncOutput(AsyncModelRunnerOutput): with torch.cuda.stream(self.copy_stream): self.copy_stream.wait_stream(default_stream) - # NOTE(woosuk): We must ensure that CPU tensors are not freed - # before the device-to-host copy is fully completed. For instance, - # operations like - # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy() - # are unsafe because the underlying CPU tensor can be prematurely freed and - # reused by other tensors before the asynchronous copy finishes, potentially - # causing race conditions. To prevent this, we delay freeing by holding - # references until the copy event signals completion. - # Likewise, we also need to keep the reference to the GPU tensors. - # This is done by keeping the reference to sampler_output and - # model_runner_output. - self.sampled_token_ids = sampler_output.sampled_token_ids.to( - "cpu", non_blocking=True - ) + self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids) if sampler_output.logprobs_tensors is not None: self.logprobs_tensors: LogprobsTensors | None = ( sampler_output.logprobs_tensors.to_cpu_nonblocking() ) else: self.logprobs_tensors = None - self.num_sampled_tokens_cpu = num_sampled_tokens.to( - "cpu", non_blocking=True - ) + if sampler_output.num_nans is not None: + self.num_nans = async_copy_to_np(sampler_output.num_nans) + else: + self.num_nans = None + self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens) self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} if self.model_runner_output.prompt_logprobs_dict: for k, v in self.model_runner_output.prompt_logprobs_dict.items(): @@ -68,7 +58,6 @@ class AsyncOutput(AsyncModelRunnerOutput): def get_output(self) -> ModelRunnerOutput: self.copy_event.synchronize() - num_sampled_tokens_np = self.num_sampled_tokens_cpu.numpy() # NOTE(woosuk): The following code is to ensure compatibility with # the existing model runner. @@ -76,10 +65,18 @@ class AsyncOutput(AsyncModelRunnerOutput): # rather than Python lists. sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist() num_reqs = len(sampled_token_ids) + num_sampled_tokens = self.num_sampled_tokens_np.tolist() for i in range(num_reqs): - del sampled_token_ids[i][num_sampled_tokens_np[i] :] + del sampled_token_ids[i][num_sampled_tokens[i] :] self.model_runner_output.sampled_token_ids = sampled_token_ids + if self.num_nans is not None: + num_nans = self.num_nans.tolist() + self.model_runner_output.num_nans_in_logits = { + req_id: num_nans[i] + for i, req_id in enumerate(self.model_runner_output.req_ids) + } + if self.logprobs_tensors is not None: self.model_runner_output.logprobs = self.logprobs_tensors.tolists() self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict @@ -95,3 +92,7 @@ def async_barrier(event: torch.cuda.Event | None): finally: if event is not None: event.record() + + +def async_copy_to_np(x: torch.Tensor) -> np.ndarray: + return x.to("cpu", non_blocking=True).numpy() diff --git a/vllm/v1/worker/gpu/metrics/__init__.py b/vllm/v1/worker/gpu/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/worker/gpu/metrics/logits.py b/vllm/v1/worker/gpu/metrics/logits.py new file mode 100644 index 0000000000000..fd7b30beaa1f8 --- /dev/null +++ b/vllm/v1/worker/gpu/metrics/logits.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch._inductor.runtime.triton_helpers import libdevice + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _num_nans_kernel( + logits_ptr, + logits_stride, + num_nans_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + num_nans = 0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, mask=mask, other=0 + ) + logits = logits.to(tl.float32) + is_nan = libdevice.isnan(logits).to(tl.int1) + num_nans += tl.sum(is_nan).to(tl.int32) + tl.store(num_nans_ptr + req_idx, num_nans) + + +def get_num_nans(logits: torch.Tensor) -> torch.Tensor: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 8192 + num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device) + _num_nans_kernel[(num_reqs,)]( + logits, + logits.stride(0), + num_nans, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + return num_nans diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 464f7b7bd3532..9f4c6edfb6aa9 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -25,7 +25,6 @@ from vllm.v1.outputs import ( LogprobsTensors, ModelRunnerOutput, ) -from vllm.v1.sample.sampler import SamplerOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier from vllm.v1.worker.gpu.attn_utils import ( build_attn_metadata, @@ -53,6 +52,7 @@ from vllm.v1.worker.gpu.sample.metadata import ( SamplingMetadata, expand_sampling_metadata, ) +from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.spec_decode import init_speculator from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample diff --git a/vllm/v1/worker/gpu/sample/min_p.py b/vllm/v1/worker/gpu/sample/min_p.py index 0638818006f50..c98a42cb2b1bb 100644 --- a/vllm/v1/worker/gpu/sample/min_p.py +++ b/vllm/v1/worker/gpu/sample/min_p.py @@ -39,9 +39,7 @@ def _min_p_kernel( tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask) -def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor | None) -> None: - if min_p is None: - return +def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None: num_reqs, vocab_size = logits.shape BLOCK_SIZE = 1024 _min_p_kernel[(num_reqs,)]( diff --git a/vllm/v1/worker/gpu/sample/output.py b/vllm/v1/worker/gpu/sample/output.py new file mode 100644 index 0000000000000..13e8cf1d6c1ec --- /dev/null +++ b/vllm/v1/worker/gpu/sample/output.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch + +from vllm.v1.outputs import LogprobsTensors + + +@dataclass +class SamplerOutput: + sampled_token_ids: torch.Tensor + logprobs_tensors: LogprobsTensors | None + num_nans: torch.Tensor | None diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 9a4224d8fddef..84a3e18671b2c 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -3,13 +3,15 @@ import torch +import vllm.envs as envs from vllm.config.model import LogprobsMode -from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.sample.min_p import apply_min_p +from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature @@ -21,12 +23,16 @@ class Sampler: if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") self.logprobs_mode = logprobs_mode + self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default. def __call__( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear + # that num_nans is computed before applying penalties and temperature. + num_nans = get_num_nans(logits) if self.compute_nans else None sampled, processed_logits = self.sample(logits, sampling_metadata) if sampling_metadata.max_num_logprobs is not None: logits = ( @@ -49,6 +55,7 @@ class Sampler: # token per request. sampled_token_ids=sampled.view(-1, 1), logprobs_tensors=logprobs_tensors, + num_nans=num_nans, ) return sampler_output @@ -63,7 +70,8 @@ class Sampler: # Apply penalties and temperature in place. apply_penalties_and_temperature(logits, sampling_metadata) # Apply min_p in place. - apply_min_p(logits, sampling_metadata.min_p) + if sampling_metadata.min_p is not None: + apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. This might return a new tensor. logits = apply_top_k_top_p( logits, sampling_metadata.top_k, sampling_metadata.top_p