From 6afc0ffaf61b49599dec25c30b28596a390ad2d9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Nov 2025 00:41:01 -0800 Subject: [PATCH] [Model Runner V2] Add sample/ directory and reorganize files (#29719) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/model_runner.py | 15 +- vllm/v1/worker/gpu/sample/__init__.py | 0 vllm/v1/worker/gpu/sample/gumbel.py | 100 ++++++ vllm/v1/worker/gpu/sample/logprob.py | 167 ++++++++++ vllm/v1/worker/gpu/sample/metadata.py | 179 ++++++++++ vllm/v1/worker/gpu/{ => sample}/penalties.py | 48 ++- vllm/v1/worker/gpu/sample/sampler.py | 79 +++++ vllm/v1/worker/gpu/sampler.py | 333 ------------------- vllm/v1/worker/gpu/spec_decode/eagle.py | 4 +- vllm/v1/worker/gpu/states.py | 232 +------------ 10 files changed, 587 insertions(+), 570 deletions(-) create mode 100644 vllm/v1/worker/gpu/sample/__init__.py create mode 100644 vllm/v1/worker/gpu/sample/gumbel.py create mode 100644 vllm/v1/worker/gpu/sample/logprob.py create mode 100644 vllm/v1/worker/gpu/sample/metadata.py rename vllm/v1/worker/gpu/{ => sample}/penalties.py (66%) create mode 100644 vllm/v1/worker/gpu/sample/sampler.py delete mode 100644 vllm/v1/worker/gpu/sampler.py diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 8414ca53b8748..fdb930c4dcd79 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -47,13 +47,18 @@ from vllm.v1.worker.gpu.input_batch import ( prepare_pos_seq_lens, prepare_prefill_inputs, ) -from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs +from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs +from vllm.v1.worker.gpu.sample.metadata import ( + SamplingMetadata, + expand_sampling_metadata, +) +from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.spec_decode import init_speculator from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( get_num_rejected, rejection_sample, ) -from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata +from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -890,8 +895,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch.idx_mapping, input_batch.idx_mapping_np, pos ) if input_batch.num_draft_tokens > 0: - sampling_metadata = self.req_states.expand_sampling_metadata( - sampling_metadata, input_batch.cu_num_logits + sampling_metadata = expand_sampling_metadata( + sampling_metadata, + input_batch.cu_num_logits, + max_expand_len=self.num_speculative_steps + 1, ) if self.lora_config: diff --git a/vllm/v1/worker/gpu/sample/__init__.py b/vllm/v1/worker/gpu/sample/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py new file mode 100644 index 0000000000000..3e0d72e56939b --- /dev/null +++ b/vllm/v1/worker/gpu/sample/gumbel.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _gumbel_sample_kernel( + local_argmax_ptr, + local_argmax_stride, + local_max_ptr, + local_max_stride, + logits_ptr, + logits_stride, + seeds_ptr, + pos_ptr, + temp_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, 
+ APPLY_TEMPERATURE: tl.constexpr, +): + req_idx = tl.program_id(0) + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, + mask=mask, + other=float("-inf"), + ) + logits = logits.to(tl.float32) + + temp = tl.load(temp_ptr + req_idx).to(tl.float32) + if temp != 0.0: + # Calculate the seed for gumbel noise. + seed = tl.load(seeds_ptr + req_idx) + pos = tl.load(pos_ptr + req_idx) + gumbel_seed = tl.randint(seed, pos) + + # Generate gumbel noise. + r = tl.rand(gumbel_seed, block).to(tl.float64) + gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) + gumbel_noise = gumbel_noise.to(tl.float32) + + # Apply temperature. + if APPLY_TEMPERATURE: + # NOTE(woosuk): Use div_rn to match the behavior of torch. + logits = tl.div_rn(logits, temp) + + # Apply gumbel noise. + logits = tl.where(mask, logits + gumbel_noise, float("-inf")) + + idx = tl.argmax(logits, axis=0) + token_id = block_idx * BLOCK_SIZE + idx + value = tl.max(logits, axis=0) + tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id) + tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value) + + +def gumbel_sample( + logits: torch.Tensor, # [num_reqs, vocab_size] + temperature: torch.Tensor, # [num_reqs] + seed: torch.Tensor, # [num_reqs] + pos: torch.Tensor, # [num_reqs] + apply_temperature: bool, +) -> torch.Tensor: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + local_argmax = torch.empty( + num_reqs, + num_blocks, + dtype=torch.int64, + device=logits.device, + ) + local_max = torch.empty( + num_reqs, + num_blocks, + dtype=torch.float32, + device=logits.device, + ) + _gumbel_sample_kernel[(num_reqs, num_blocks)]( + local_argmax, + local_argmax.stride(0), + local_max, + local_max.stride(0), + logits, + logits.stride(0), + seed, + pos, + temperature, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + APPLY_TEMPERATURE=apply_temperature, + ) + # NOTE(woosuk): Use int64 for later indexing. + max_block_idx = local_max.argmax(dim=-1, keepdim=True) + sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1) + return sampled diff --git a/vllm/v1/worker/gpu/sample/logprob.py b/vllm/v1/worker/gpu/sample/logprob.py new file mode 100644 index 0000000000000..25448b387b310 --- /dev/null +++ b/vllm/v1/worker/gpu/sample/logprob.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.outputs import LogprobsTensors + + +@triton.jit +def _topk_log_softmax_kernel( + output_ptr, + logits_ptr, + logits_stride, + topk_ids_ptr, + topk, + vocab_size, + BLOCK_SIZE: tl.constexpr, + PADDED_TOPK: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = logits_ptr + req_idx * logits_stride + + max_val = float("-inf") + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) + max_val = tl.max(tl.maximum(logits, max_val)) + max_val = max_val.to(tl.float32) # type: ignore + + se = 0.0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) + # NOTE(woosuk): Make sure that logits and all following operations use FP32. 
+ logits = logits.to(tl.float32) + e = tl.exp(logits - max_val) + e = tl.where(block < vocab_size, e, 0.0) + se += tl.sum(e) + lse = tl.log(se) + + k_offset = tl.arange(0, PADDED_TOPK) + k_mask = k_offset < topk + topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0) + + logits = tl.load(row_ptr + topk_ids, mask=k_mask) + logits = logits.to(tl.float32) + o = logits - max_val - lse + tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) + + +@triton.jit +def _ranks_kernel( + output_ptr, + logits_ptr, + logits_stride, + token_ids_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = logits_ptr + req_idx * logits_stride + + token_id = tl.load(token_ids_ptr + req_idx) + x = tl.load(row_ptr + token_id) + + n = 0 + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) + n += tl.sum((logits > x).to(tl.int32)) + tl.store(output_ptr + req_idx, n) + + +def compute_token_logprobs( + logits: torch.Tensor, + token_ids: torch.Tensor, +) -> torch.Tensor: + batch_size = logits.shape[0] + vocab_size = logits.shape[1] + token_ids = token_ids.to(torch.int64) + num_logprobs = token_ids.shape[1] + logprobs = torch.empty( + batch_size, + num_logprobs, + dtype=torch.float32, + device=logits.device, + ) + _topk_log_softmax_kernel[(batch_size,)]( + logprobs, + logits, + logits.stride(0), + token_ids, + num_logprobs, + vocab_size, + BLOCK_SIZE=1024, # type: ignore + PADDED_TOPK=triton.next_power_of_2(num_logprobs), + ) + return logprobs + + +def compute_topk_logprobs( + logits: torch.Tensor, + num_logprobs: int, + sampled_token_ids: torch.Tensor, +) -> LogprobsTensors: + assert num_logprobs >= 0 + batch_size, vocab_size = logits.shape + if num_logprobs == 0: + logprob_token_ids = sampled_token_ids.unsqueeze(-1) + else: + topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices + logprob_token_ids = torch.cat( + (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1 + ) + + # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full + # logprobs tensor. Instead, we only compute and return the logprobs of + # the topk + 1 tokens. + logprobs = compute_token_logprobs(logits, logprob_token_ids) + token_ranks = torch.empty( + batch_size, + dtype=torch.int64, + device=logits.device, + ) + _ranks_kernel[(batch_size,)]( + token_ranks, + logits, + logits.stride(0), + sampled_token_ids, + vocab_size, + BLOCK_SIZE=8192, # type: ignore + ) + return LogprobsTensors( + logprob_token_ids=logprob_token_ids, + logprobs=logprobs, + selected_token_ranks=token_ranks, + ) + + +def compute_prompt_logprobs( + prompt_token_ids: torch.Tensor, + prompt_hidden_states: torch.Tensor, + logits_fn: Callable[[torch.Tensor], torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor]: + # Since materializing the full prompt logits can take too much memory, + # we compute it in chunks. + CHUNK_SIZE = 1024 + logprobs = [] + ranks = [] + prompt_token_ids = prompt_token_ids.to(torch.int64) + for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE): + end_idx = start_idx + CHUNK_SIZE + # NOTE(woosuk): logits_fn can be slow because it involves all-gather. 
+ prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx]) + prompt_logprobs = compute_topk_logprobs( + prompt_logits, + 0, # num_logprobs + prompt_token_ids[start_idx:end_idx], + ) + logprobs.append(prompt_logprobs.logprobs) + ranks.append(prompt_logprobs.selected_token_ranks) + + logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0] + ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0] + return logprobs, ranks diff --git a/vllm/v1/worker/gpu/sample/metadata.py b/vllm/v1/worker/gpu/sample/metadata.py new file mode 100644 index 0000000000000..666649fd0eebc --- /dev/null +++ b/vllm/v1/worker/gpu/sample/metadata.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch + +from vllm.triton_utils import tl, triton + + +@dataclass +class SamplingMetadata: + temperature: torch.Tensor + + top_p: torch.Tensor | None + top_k: torch.Tensor | None + + repetition_penalty: torch.Tensor + frequency_penalty: torch.Tensor + presence_penalty: torch.Tensor + + seeds: torch.Tensor + pos: torch.Tensor + + # None means no logprobs, 0 means sampled token logprobs only + max_num_logprobs: int | None + + # For penalties + idx_mapping: torch.Tensor + prompt_bin_counts: torch.Tensor + output_bin_counts: torch.Tensor + + @classmethod + def make_dummy( + cls, + num_reqs: int, + device: torch.device, + ) -> "SamplingMetadata": + assert num_reqs > 0 + temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) + temperature[0] = 0.5 + # TODO(woosuk): Use top-p and top-k for dummy sampler. + # Currently, they are disabled because of memory usage. + # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device) + # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device) + top_p = None + top_k = None + # NOTE(woosuk): We must set penalties to their default values to make sure + # the penalties kernel does not touch the placeholder bin_counts tensors. + repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) + frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) + presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) + seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) + pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) + max_num_logprobs = 20 + + idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) + # NOTE(woosuk): These are placeholder tensors to avoid None checks in the + # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton + # specialization and re-compilation at runtime. + prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) + output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) + + return cls( + temperature=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + seeds=seeds, + pos=pos, + max_num_logprobs=max_num_logprobs, + idx_mapping=idx_mapping, + prompt_bin_counts=prompt_bin_counts, + output_bin_counts=output_bin_counts, + ) + + +# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None. 
+@triton.jit +def _expand_sampling_metadata_kernel( + temp_ptr, + expanded_temp_ptr, + top_p_ptr, + expanded_top_p_ptr, + top_k_ptr, + expanded_top_k_ptr, + rep_penalty_ptr, + expanded_rep_penalty_ptr, + freq_penalty_ptr, + expanded_freq_penalty_ptr, + pres_penalty_ptr, + expanded_pres_penalty_ptr, + seeds_ptr, + expanded_seeds_ptr, + cu_num_logits_ptr, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + start_idx = tl.load(cu_num_logits_ptr + req_idx) + end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + num_tokens = end_idx - start_idx + + block = tl.arange(0, BLOCK_SIZE) + mask = block < num_tokens + + temp = tl.load(temp_ptr + req_idx) + tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask) + + if top_p_ptr is not None: + top_p = tl.load(top_p_ptr + req_idx) + tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask) + + if top_k_ptr is not None: + top_k = tl.load(top_k_ptr + req_idx) + tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) + + rep_penalty = tl.load(rep_penalty_ptr + req_idx) + tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask) + + freq_penalty = tl.load(freq_penalty_ptr + req_idx) + tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask) + + pres_penalty = tl.load(pres_penalty_ptr + req_idx) + tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask) + + seed = tl.load(seeds_ptr + req_idx) + tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask) + + +def expand_sampling_metadata( + sampling_metadata: SamplingMetadata, + cu_num_logits: torch.Tensor, + max_expand_len: int, +) -> SamplingMetadata: + total_num_logits = sampling_metadata.pos.shape[0] + create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None + expanded_temp = create_empty(sampling_metadata.temperature) + expanded_top_p = create_empty(sampling_metadata.top_p) + expanded_top_k = create_empty(sampling_metadata.top_k) + expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty) + expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty) + expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty) + expanded_seeds = create_empty(sampling_metadata.seeds) + + num_reqs = cu_num_logits.shape[0] - 1 + _expand_sampling_metadata_kernel[(num_reqs,)]( + sampling_metadata.temperature, + expanded_temp, + sampling_metadata.top_p, + expanded_top_p, + sampling_metadata.top_k, + expanded_top_k, + sampling_metadata.repetition_penalty, + expanded_repetition_penalty, + sampling_metadata.frequency_penalty, + expanded_frequency_penalty, + sampling_metadata.presence_penalty, + expanded_presence_penalty, + sampling_metadata.seeds, + expanded_seeds, + cu_num_logits, + BLOCK_SIZE=triton.next_power_of_2(max_expand_len), + ) + return SamplingMetadata( + temperature=expanded_temp, + top_p=expanded_top_p, + top_k=expanded_top_k, + seeds=expanded_seeds, + repetition_penalty=expanded_repetition_penalty, + frequency_penalty=expanded_frequency_penalty, + presence_penalty=expanded_presence_penalty, + pos=sampling_metadata.pos, + max_num_logprobs=sampling_metadata.max_num_logprobs, + # TODO(woosuk): Support penalties with spec decoding. 
+ idx_mapping=sampling_metadata.idx_mapping, + prompt_bin_counts=sampling_metadata.prompt_bin_counts, + output_bin_counts=sampling_metadata.output_bin_counts, + ) diff --git a/vllm/v1/worker/gpu/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py similarity index 66% rename from vllm/v1/worker/gpu/penalties.py rename to vllm/v1/worker/gpu/sample/penalties.py index f87ee01718cd6..1607a75fd56ba 100644 --- a/vllm/v1/worker/gpu/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -3,7 +3,7 @@ import torch from vllm.triton_utils import tl, triton -from vllm.v1.worker.gpu.states import SamplingMetadata +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata @triton.jit @@ -83,3 +83,49 @@ def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) - vocab_size, BLOCK_SIZE=BLOCK_SIZE, ) + + +@triton.jit(do_not_specialize=["prefill_len", "prompt_len"]) +def _bincount_kernel( + prefill_token_ids_ptr, + prefill_len, + prompt_len, + prompt_bin_counts_ptr, + output_bin_counts_ptr, + BLOCK_SIZE: tl.constexpr, +): + block_idx = tl.program_id(0) + if block_idx * BLOCK_SIZE >= prefill_len: + return + + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if block_idx * BLOCK_SIZE < prompt_len: + mask = block < prompt_len + prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) + tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask) + if (block_idx + 1) * BLOCK_SIZE >= prompt_len: + mask = block < prefill_len + mask &= block >= prompt_len + prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) + tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask) + + +def bincount( + prefill_token_ids: torch.Tensor, + prefill_len: int, + prompt_len: int, + prompt_bin_counts: torch.Tensor, + output_bin_counts: torch.Tensor, +) -> None: + prompt_bin_counts.zero_() + output_bin_counts.zero_() + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE) + _bincount_kernel[(num_blocks,)]( + prefill_token_ids, + prefill_len, + prompt_len, + prompt_bin_counts, + output_bin_counts, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py new file mode 100644 index 0000000000000..4e7c85a021cfb --- /dev/null +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.config.model import LogprobsMode +from vllm.v1.outputs import SamplerOutput +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.sample.penalties import apply_penalties + + +class Sampler: + def __init__( + self, + logprobs_mode: LogprobsMode = "raw_logprobs", + ): + if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: + raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") + self.logprobs_mode = logprobs_mode + + def __call__( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == "processed_logprobs": + sampled, logits = self.sample( + logits, sampling_metadata, return_logits=True + ) + else: + assert self.logprobs_mode == "raw_logprobs" + sampled, _ = self.sample(logits, sampling_metadata, 
return_logits=False) + + logprobs_tensors = compute_topk_logprobs( + logits, + sampling_metadata.max_num_logprobs, + sampled, + ) + else: + sampled, _ = self.sample(logits, sampling_metadata, return_logits=False) + logprobs_tensors = None + + # These are GPU tensors. + sampler_output = SamplerOutput( + # The sampled tokens are expanded to 2D tensor with shape + # [num_requests, 1], where each row represents one generated + # token per request. + sampled_token_ids=sampled.view(-1, 1), + logprobs_tensors=logprobs_tensors, + ) + return sampler_output + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + return_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + is_greedy = sampling_metadata.temperature == 0 + temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits = logits / temp.view(-1, 1) + logits = apply_top_k_top_p( + logits, sampling_metadata.top_k, sampling_metadata.top_p + ) + # Apply penalties in place. + apply_penalties(logits, sampling_metadata) + + sampled = gumbel_sample( + logits, + sampling_metadata.temperature, + sampling_metadata.seeds, + sampling_metadata.pos, + apply_temperature=False, + ) + return sampled, logits if return_logits else None diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py deleted file mode 100644 index 6e0d6150a9669..0000000000000 --- a/vllm/v1/worker/gpu/sampler.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - -import torch - -from vllm.config.model import LogprobsMode -from vllm.triton_utils import tl, triton -from vllm.v1.outputs import LogprobsTensors, SamplerOutput -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p -from vllm.v1.worker.gpu.penalties import apply_penalties -from vllm.v1.worker.gpu.states import SamplingMetadata - - -class Sampler: - def __init__( - self, - logprobs_mode: LogprobsMode = "raw_logprobs", - ): - if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: - raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") - self.logprobs_mode = logprobs_mode - - def __call__( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - if sampling_metadata.max_num_logprobs is not None: - if self.logprobs_mode == "processed_logprobs": - sampled, logits = self.sample( - logits, sampling_metadata, return_logits=True - ) - else: - assert self.logprobs_mode == "raw_logprobs" - sampled, _ = self.sample(logits, sampling_metadata, return_logits=False) - - logprobs_tensors = compute_topk_logprobs( - logits, - sampling_metadata.max_num_logprobs, - sampled, - ) - else: - sampled, _ = self.sample(logits, sampling_metadata, return_logits=False) - logprobs_tensors = None - - # These are GPU tensors. - sampler_output = SamplerOutput( - # The sampled tokens are expanded to 2D tensor with shape - # [num_requests, 1], where each row represents one generated - # token per request. 
- sampled_token_ids=sampled.view(-1, 1), - logprobs_tensors=logprobs_tensors, - ) - return sampler_output - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - return_logits: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - is_greedy = sampling_metadata.temperature == 0 - temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature) - logits = logits / temp.view(-1, 1) - logits = apply_top_k_top_p( - logits, sampling_metadata.top_k, sampling_metadata.top_p - ) - # Apply penalties in place. - apply_penalties(logits, sampling_metadata) - - sampled = gumbel_sample( - logits, - sampling_metadata.temperature, - sampling_metadata.seeds, - sampling_metadata.pos, - apply_temperature=False, - ) - return sampled, logits if return_logits else None - - -@triton.jit -def _gumbel_sample_kernel( - local_argmax_ptr, - local_argmax_stride, - local_max_ptr, - local_max_stride, - logits_ptr, - logits_stride, - seeds_ptr, - pos_ptr, - temp_ptr, - vocab_size, - BLOCK_SIZE: tl.constexpr, - APPLY_TEMPERATURE: tl.constexpr, -): - req_idx = tl.program_id(0) - block_idx = tl.program_id(1) - block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = block < vocab_size - logits = tl.load( - logits_ptr + req_idx * logits_stride + block, - mask=mask, - other=float("-inf"), - ) - logits = logits.to(tl.float32) - - temp = tl.load(temp_ptr + req_idx).to(tl.float32) - if temp != 0.0: - # Calculate the seed for gumbel noise. - seed = tl.load(seeds_ptr + req_idx) - pos = tl.load(pos_ptr + req_idx) - gumbel_seed = tl.randint(seed, pos) - - # Generate gumbel noise. - r = tl.rand(gumbel_seed, block).to(tl.float64) - gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) - gumbel_noise = gumbel_noise.to(tl.float32) - - # Apply temperature. - if APPLY_TEMPERATURE: - # NOTE(woosuk): Use div_rn to match the behavior of torch. - logits = tl.div_rn(logits, temp) - - # Apply gumbel noise. - logits = tl.where(mask, logits + gumbel_noise, float("-inf")) - - idx = tl.argmax(logits, axis=0) - token_id = block_idx * BLOCK_SIZE + idx - value = tl.max(logits, axis=0) - tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id) - tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value) - - -def gumbel_sample( - logits: torch.Tensor, # [num_reqs, vocab_size] - temperature: torch.Tensor, # [num_reqs] - seed: torch.Tensor, # [num_reqs] - pos: torch.Tensor, # [num_reqs] - apply_temperature: bool, -) -> torch.Tensor: - num_reqs, vocab_size = logits.shape - BLOCK_SIZE = 1024 - num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) - local_argmax = torch.empty( - num_reqs, - num_blocks, - dtype=torch.int64, - device=logits.device, - ) - local_max = torch.empty( - num_reqs, - num_blocks, - dtype=torch.float32, - device=logits.device, - ) - _gumbel_sample_kernel[(num_reqs, num_blocks)]( - local_argmax, - local_argmax.stride(0), - local_max, - local_max.stride(0), - logits, - logits.stride(0), - seed, - pos, - temperature, - vocab_size, - BLOCK_SIZE=BLOCK_SIZE, - APPLY_TEMPERATURE=apply_temperature, - ) - # NOTE(woosuk): Use int64 for later indexing. 
- max_block_idx = local_max.argmax(dim=-1, keepdim=True) - sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1) - return sampled - - -@triton.jit -def _topk_log_softmax_kernel( - output_ptr, - logits_ptr, - logits_stride, - topk_ids_ptr, - topk, - vocab_size, - BLOCK_SIZE: tl.constexpr, - PADDED_TOPK: tl.constexpr, -): - req_idx = tl.program_id(0) - row_ptr = logits_ptr + req_idx * logits_stride - - max_val = float("-inf") - for i in range(0, vocab_size, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) - max_val = tl.max(tl.maximum(logits, max_val)) - max_val = max_val.to(tl.float32) # type: ignore - - se = 0.0 - for i in range(0, vocab_size, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) - # NOTE(woosuk): Make sure that logits and all following operations use FP32. - logits = logits.to(tl.float32) - e = tl.exp(logits - max_val) - e = tl.where(block < vocab_size, e, 0.0) - se += tl.sum(e) - lse = tl.log(se) - - k_offset = tl.arange(0, PADDED_TOPK) - k_mask = k_offset < topk - topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0) - - logits = tl.load(row_ptr + topk_ids, mask=k_mask) - logits = logits.to(tl.float32) - o = logits - max_val - lse - tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) - - -@triton.jit -def _ranks_kernel( - output_ptr, - logits_ptr, - logits_stride, - token_ids_ptr, - vocab_size, - BLOCK_SIZE: tl.constexpr, -): - req_idx = tl.program_id(0) - row_ptr = logits_ptr + req_idx * logits_stride - - token_id = tl.load(token_ids_ptr + req_idx) - x = tl.load(row_ptr + token_id) - - n = 0 - for i in range(0, vocab_size, BLOCK_SIZE): - block = i + tl.arange(0, BLOCK_SIZE) - logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) - n += tl.sum((logits > x).to(tl.int32)) - tl.store(output_ptr + req_idx, n) - - -def compute_token_logprobs( - logits: torch.Tensor, - token_ids: torch.Tensor, -) -> torch.Tensor: - batch_size = logits.shape[0] - vocab_size = logits.shape[1] - token_ids = token_ids.to(torch.int64) - num_logprobs = token_ids.shape[1] - logprobs = torch.empty( - batch_size, - num_logprobs, - dtype=torch.float32, - device=logits.device, - ) - _topk_log_softmax_kernel[(batch_size,)]( - logprobs, - logits, - logits.stride(0), - token_ids, - num_logprobs, - vocab_size, - BLOCK_SIZE=1024, # type: ignore - PADDED_TOPK=triton.next_power_of_2(num_logprobs), - ) - return logprobs - - -def compute_topk_logprobs( - logits: torch.Tensor, - num_logprobs: int, - sampled_token_ids: torch.Tensor, -) -> LogprobsTensors: - assert num_logprobs >= 0 - batch_size, vocab_size = logits.shape - if num_logprobs == 0: - logprob_token_ids = sampled_token_ids.unsqueeze(-1) - else: - topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices - logprob_token_ids = torch.cat( - (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1 - ) - - # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full - # logprobs tensor. Instead, we only compute and return the logprobs of - # the topk + 1 tokens. 
- logprobs = compute_token_logprobs(logits, logprob_token_ids) - token_ranks = torch.empty( - batch_size, - dtype=torch.int64, - device=logits.device, - ) - _ranks_kernel[(batch_size,)]( - token_ranks, - logits, - logits.stride(0), - sampled_token_ids, - vocab_size, - BLOCK_SIZE=8192, # type: ignore - ) - return LogprobsTensors( - logprob_token_ids=logprob_token_ids, - logprobs=logprobs, - selected_token_ranks=token_ranks, - ) - - -def compute_prompt_logprobs( - prompt_token_ids: torch.Tensor, - prompt_hidden_states: torch.Tensor, - logits_fn: Callable[[torch.Tensor], torch.Tensor], -) -> tuple[torch.Tensor, torch.Tensor]: - # Since materializing the full prompt logits can take too much memory, - # we compute it in chunks. - CHUNK_SIZE = 1024 - logprobs = [] - ranks = [] - prompt_token_ids = prompt_token_ids.to(torch.int64) - for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE): - end_idx = start_idx + CHUNK_SIZE - # NOTE(woosuk): logits_fn can be slow because it involves all-gather. - prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx]) - prompt_logprobs = compute_topk_logprobs( - prompt_logits, - 0, # num_logprobs - prompt_token_ids[start_idx:end_idx], - ) - logprobs.append(prompt_logprobs.logprobs) - ranks.append(prompt_logprobs.selected_token_ranks) - - logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0] - ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0] - return logprobs, ranks diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index 580d67246dfa1..a2d0550326f3b 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -18,9 +18,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers -from vllm.v1.worker.gpu.sampler import gumbel_sample +from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager -from vllm.v1.worker.gpu.states import SamplingMetadata logger = init_logger(__name__) diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 44b076fa4c2ae..c3428faab0a31 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -7,86 +7,18 @@ import torch from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams -from vllm.triton_utils import tl, triton from vllm.utils.platform_utils import is_uva_available from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor from vllm.v1.outputs import LogprobsTensors from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.sample.penalties import bincount _NP_INT64_MIN = np.iinfo(np.int64).min _NP_INT64_MAX = np.iinfo(np.int64).max NO_LORA_ID = 0 -@dataclass -class SamplingMetadata: - temperature: torch.Tensor - - top_p: torch.Tensor | None - top_k: torch.Tensor | None - - repetition_penalty: torch.Tensor - frequency_penalty: torch.Tensor - presence_penalty: torch.Tensor - - seeds: torch.Tensor - pos: torch.Tensor - - # None means no logprobs, 0 means sampled token logprobs only - max_num_logprobs: int | None - - # For penalties - idx_mapping: torch.Tensor - prompt_bin_counts: torch.Tensor - output_bin_counts: 
torch.Tensor - - @classmethod - def make_dummy( - cls, - num_reqs: int, - device: torch.device, - ) -> "SamplingMetadata": - assert num_reqs > 0 - temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device) - temperature[0] = 0.5 - # TODO(woosuk): Use top-p and top-k for dummy sampler. - # Currently, they are disabled because of memory usage. - # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device) - # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device) - top_p = None - top_k = None - # NOTE(woosuk): We must set penalties to their default values to make sure - # the penalties kernel does not touch the placeholder bin_counts tensors. - repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) - frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) - presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) - seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) - pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) - max_num_logprobs = 20 - - idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) - # NOTE(woosuk): These are placeholder tensors to avoid None checks in the - # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton - # specialization and re-compilation at runtime. - prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) - output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) - - return cls( - temperature=temperature, - top_p=top_p, - top_k=top_k, - repetition_penalty=repetition_penalty, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - seeds=seeds, - pos=pos, - max_num_logprobs=max_num_logprobs, - idx_mapping=idx_mapping, - prompt_bin_counts=prompt_bin_counts, - output_bin_counts=output_bin_counts, - ) - - class RequestState: def __init__( self, @@ -311,17 +243,6 @@ class RequestState: output_bin_counts=self.output_bin_counts, ) - def expand_sampling_metadata( - self, - sampling_metadata: SamplingMetadata, - cu_num_logits: torch.Tensor, - ) -> SamplingMetadata: - # For draft tokens, we need to expand the sampling param tensors as - # each request samples multiple tokens in each step. - return expand_sampling_metadata( - sampling_metadata, cu_num_logits, self.num_speculative_steps - ) - def make_lora_inputs( self, req_ids: list[str], @@ -376,158 +297,9 @@ class UvaBuffer: self.gpu = get_cuda_view_from_cpu_tensor(self.cpu) -# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None. 
-@triton.jit -def _expand_sampling_metadata_kernel( - temp_ptr, - expanded_temp_ptr, - top_p_ptr, - expanded_top_p_ptr, - top_k_ptr, - expanded_top_k_ptr, - rep_penalty_ptr, - expanded_rep_penalty_ptr, - freq_penalty_ptr, - expanded_freq_penalty_ptr, - pres_penalty_ptr, - expanded_pres_penalty_ptr, - seeds_ptr, - expanded_seeds_ptr, - cu_num_logits_ptr, - BLOCK_SIZE: tl.constexpr, -): - req_idx = tl.program_id(0) - start_idx = tl.load(cu_num_logits_ptr + req_idx) - end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) - num_tokens = end_idx - start_idx - - block = tl.arange(0, BLOCK_SIZE) - mask = block < num_tokens - - temp = tl.load(temp_ptr + req_idx) - tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask) - - if top_p_ptr is not None: - top_p = tl.load(top_p_ptr + req_idx) - tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask) - - if top_k_ptr is not None: - top_k = tl.load(top_k_ptr + req_idx) - tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) - - rep_penalty = tl.load(rep_penalty_ptr + req_idx) - tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask) - - freq_penalty = tl.load(freq_penalty_ptr + req_idx) - tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask) - - pres_penalty = tl.load(pres_penalty_ptr + req_idx) - tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask) - - seed = tl.load(seeds_ptr + req_idx) - tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask) - - -def expand_sampling_metadata( - sampling_metadata: SamplingMetadata, - cu_num_logits: torch.Tensor, - num_speculative_steps: int, -) -> SamplingMetadata: - total_num_logits = sampling_metadata.pos.shape[0] - create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None - expanded_temp = create_empty(sampling_metadata.temperature) - expanded_top_p = create_empty(sampling_metadata.top_p) - expanded_top_k = create_empty(sampling_metadata.top_k) - expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty) - expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty) - expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty) - expanded_seeds = create_empty(sampling_metadata.seeds) - - num_reqs = cu_num_logits.shape[0] - 1 - _expand_sampling_metadata_kernel[(num_reqs,)]( - sampling_metadata.temperature, - expanded_temp, - sampling_metadata.top_p, - expanded_top_p, - sampling_metadata.top_k, - expanded_top_k, - sampling_metadata.repetition_penalty, - expanded_repetition_penalty, - sampling_metadata.frequency_penalty, - expanded_frequency_penalty, - sampling_metadata.presence_penalty, - expanded_presence_penalty, - sampling_metadata.seeds, - expanded_seeds, - cu_num_logits, - BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1), - ) - return SamplingMetadata( - temperature=expanded_temp, - top_p=expanded_top_p, - top_k=expanded_top_k, - seeds=expanded_seeds, - repetition_penalty=expanded_repetition_penalty, - frequency_penalty=expanded_frequency_penalty, - presence_penalty=expanded_presence_penalty, - pos=sampling_metadata.pos, - max_num_logprobs=sampling_metadata.max_num_logprobs, - # TODO(woosuk): Support penalties with spec decoding. 
- idx_mapping=sampling_metadata.idx_mapping, - prompt_bin_counts=sampling_metadata.prompt_bin_counts, - output_bin_counts=sampling_metadata.output_bin_counts, - ) - - def use_penalty(sampling_params: SamplingParams) -> bool: return ( sampling_params.repetition_penalty != 1.0 or sampling_params.frequency_penalty != 0.0 or sampling_params.presence_penalty != 0.0 ) - - -@triton.jit(do_not_specialize=["prefill_len", "prompt_len"]) -def _bincount_kernel( - prefill_token_ids_ptr, - prefill_len, - prompt_len, - prompt_bin_counts_ptr, - output_bin_counts_ptr, - BLOCK_SIZE: tl.constexpr, -): - block_idx = tl.program_id(0) - if block_idx * BLOCK_SIZE >= prefill_len: - return - - block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - if block_idx * BLOCK_SIZE < prompt_len: - mask = block < prompt_len - prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) - tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask) - if (block_idx + 1) * BLOCK_SIZE >= prompt_len: - mask = block < prefill_len - mask &= block >= prompt_len - prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) - tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask) - - -def bincount( - prefill_token_ids: torch.Tensor, - prefill_len: int, - prompt_len: int, - prompt_bin_counts: torch.Tensor, - output_bin_counts: torch.Tensor, -) -> None: - prompt_bin_counts.zero_() - output_bin_counts.zero_() - BLOCK_SIZE = 1024 - num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE) - _bincount_kernel[(num_blocks,)]( - prefill_token_ids, - prefill_len, - prompt_len, - prompt_bin_counts, - output_bin_counts, - BLOCK_SIZE=BLOCK_SIZE, - )
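
For reviewers unfamiliar with the sampling scheme that now lives in sample/gumbel.py: the Triton kernel implements the Gumbel-max trick, where argmax(logits + Gumbel noise) is distributed like a draw from softmax(logits), and it derives the noise deterministically from each request's (seed, position) pair so that sampling is reproducible. Below is a minimal PyTorch sketch of the underlying idea, for reference only; the function name is illustrative and not part of this patch.

import torch

def gumbel_max_sample(logits: torch.Tensor) -> torch.Tensor:
    # Sample one token per row by adding Gumbel noise and taking argmax.
    # The 1e-20 terms guard against log(0), matching the constants in the
    # Triton kernel; greedy rows (temperature == 0) skip the noise there
    # and reduce to a plain argmax.
    u = torch.rand_like(logits, dtype=torch.float32)
    gumbel_noise = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    return (logits.float() + gumbel_noise).argmax(dim=-1)

The kernel additionally splits the vocabulary into BLOCK_SIZE-wide blocks and records a per-block argmax and max; gumbel_sample then reduces across blocks with a torch argmax over local_max followed by a gather from local_argmax.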
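
Likewise, expand_sampling_metadata in sample/metadata.py (previously a RequestState method in states.py) broadcasts per-request sampling parameters to per-token rows when draft tokens are present, so every speculative position of a request sees the same temperature, top-p/top-k, penalties, and seed. Conceptually it fuses, into a single kernel launch sized by max_expand_len = num_speculative_steps + 1, the per-tensor expansion sketched below; the helper name is illustrative only.

import torch

def expand_per_request(values: torch.Tensor, cu_num_logits: torch.Tensor) -> torch.Tensor:
    # values: [num_reqs]; cu_num_logits: [num_reqs + 1], cumulative logit
    # counts per request (as passed in from input_batch.cu_num_logits).
    num_logits_per_req = cu_num_logits[1:] - cu_num_logits[:-1]
    return values.repeat_interleave(num_logits_per_req)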
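
Finally, the bincount helper moved into sample/penalties.py builds the token-count tensors consumed by apply_penalties, counting prompt tokens and already-generated tokens separately. A rough torch equivalent is sketched below; the names are illustrative, and the real kernel writes into the preallocated prompt_bin_counts / output_bin_counts buffers in place.

import torch

def split_bincount(
    prefill_token_ids: torch.Tensor,  # [prefill_len]
    prompt_len: int,
    vocab_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Tokens before prompt_len belong to the prompt; the rest are outputs.
    prompt_counts = torch.bincount(prefill_token_ids[:prompt_len], minlength=vocab_size)
    output_counts = torch.bincount(prefill_token_ids[prompt_len:], minlength=vocab_size)
    return prompt_counts, output_counts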