[Model Runner V2] Use packed mask for prompt bin counts (#29756)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-11-30 14:15:42 -08:00 committed by GitHub
parent 21c2627934
commit ec38a7368d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 25 deletions

View File

@ -26,7 +26,7 @@ class SamplingMetadata:
# For penalties # For penalties
idx_mapping: torch.Tensor idx_mapping: torch.Tensor
prompt_bin_counts: torch.Tensor prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor output_bin_counts: torch.Tensor
@classmethod @classmethod
@ -57,7 +57,7 @@ class SamplingMetadata:
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the # NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime. # specialization and re-compilation at runtime.
prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
return cls( return cls(
@ -71,7 +71,7 @@ class SamplingMetadata:
pos=pos, pos=pos,
max_num_logprobs=max_num_logprobs, max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
prompt_bin_counts=prompt_bin_counts, prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts, output_bin_counts=output_bin_counts,
) )
@ -174,6 +174,6 @@ def expand_sampling_metadata(
max_num_logprobs=sampling_metadata.max_num_logprobs, max_num_logprobs=sampling_metadata.max_num_logprobs,
# TODO(woosuk): Support penalties with spec decoding. # TODO(woosuk): Support penalties with spec decoding.
idx_mapping=sampling_metadata.idx_mapping, idx_mapping=sampling_metadata.idx_mapping,
prompt_bin_counts=sampling_metadata.prompt_bin_counts, prompt_bin_mask=sampling_metadata.prompt_bin_mask,
output_bin_counts=sampling_metadata.output_bin_counts, output_bin_counts=sampling_metadata.output_bin_counts,
) )

View File

@ -15,8 +15,8 @@ def _penalties_and_temperature_kernel(
presence_penalty_ptr, presence_penalty_ptr,
temperature_ptr, temperature_ptr,
idx_mapping_ptr, idx_mapping_ptr,
prompt_bin_counts_ptr, prompt_bin_mask_ptr,
prompt_bin_counts_stride, prompt_bin_mask_stride,
output_bin_counts_ptr, output_bin_counts_ptr,
output_bin_counts_stride, output_bin_counts_stride,
vocab_size, vocab_size,
@ -54,13 +54,16 @@ def _penalties_and_temperature_kernel(
# Apply repetition penalties. # Apply repetition penalties.
if use_rep_penalty: if use_rep_penalty:
prompt_bin_counts = tl.load( packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
prompt_bin_counts_ptr packed_mask = tl.load(
+ req_state_idx * prompt_bin_counts_stride prompt_bin_mask_ptr
+ block, + req_state_idx * prompt_bin_mask_stride
mask=mask, + packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
) )
prompt_bin_mask = prompt_bin_counts > 0 prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op. # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0) scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty. # If logits are positive, divide by penalty, otherwise multiply by penalty.
@ -93,8 +96,8 @@ def apply_penalties_and_temperature(
sampling_metadata.presence_penalty, sampling_metadata.presence_penalty,
sampling_metadata.temperature, sampling_metadata.temperature,
sampling_metadata.idx_mapping, sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_counts, sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_counts.stride(0), sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts, sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0), sampling_metadata.output_bin_counts.stride(0),
vocab_size, vocab_size,
@ -107,7 +110,7 @@ def _bincount_kernel(
prefill_token_ids_ptr, prefill_token_ids_ptr,
prefill_len, prefill_len,
prompt_len, prompt_len,
prompt_bin_counts_ptr, prompt_bin_mask_ptr,
output_bin_counts_ptr, output_bin_counts_ptr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
@ -119,7 +122,10 @@ def _bincount_kernel(
if block_idx * BLOCK_SIZE < prompt_len: if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len mask = block < prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask) idx = prefill_tokens // 32
bit_idx = prefill_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len: if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len mask = block < prefill_len
mask &= block >= prompt_len mask &= block >= prompt_len
@ -131,10 +137,10 @@ def bincount(
prefill_token_ids: torch.Tensor, prefill_token_ids: torch.Tensor,
prefill_len: int, prefill_len: int,
prompt_len: int, prompt_len: int,
prompt_bin_counts: torch.Tensor, prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor, output_bin_counts: torch.Tensor,
) -> None: ) -> None:
prompt_bin_counts.zero_() prompt_bin_mask.zero_()
output_bin_counts.zero_() output_bin_counts.zero_()
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE) num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
@ -142,7 +148,7 @@ def bincount(
prefill_token_ids, prefill_token_ids,
prefill_len, prefill_len,
prompt_len, prompt_len,
prompt_bin_counts, prompt_bin_mask,
output_bin_counts, output_bin_counts,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
) )

View File

@ -7,6 +7,7 @@ import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
@ -97,11 +98,14 @@ class RequestState:
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# Statistics for penalties. # Statistics for penalties.
# TODO(woosuk): These tensors are rarely used but can be extremely large. self.prompt_bin_mask = torch.zeros(
# Optimize the memory usage. self.max_num_reqs,
self.prompt_bin_counts = torch.zeros( cdiv(self.vocab_size, 32),
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device dtype=torch.int32,
device=self.device,
) )
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self.output_bin_counts = torch.zeros( self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
) )
@ -167,7 +171,7 @@ class RequestState:
self.prefill_token_ids.gpu[req_idx], self.prefill_token_ids.gpu[req_idx],
prefill_len, prefill_len,
prompt_len, prompt_len,
self.prompt_bin_counts[req_idx], self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx], self.output_bin_counts[req_idx],
) )
@ -239,7 +243,7 @@ class RequestState:
pos=pos, pos=pos,
max_num_logprobs=max_num_logprobs, max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
prompt_bin_counts=self.prompt_bin_counts, prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts, output_bin_counts=self.output_bin_counts,
) )