[Model Runner V2] Support penalties using bin counts (#29703)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent ea3370b428
commit 1dcafb3dea
@@ -341,6 +341,8 @@ def _post_update_kernel(
     idx_mapping_ptr,
     num_computed_tokens_ptr,
     last_sampled_tokens_ptr,
+    output_bin_counts_ptr,
+    output_bin_counts_stride,
     sampled_tokens_ptr,
     sampled_tokens_stride,
     num_sampled_ptr,
@@ -357,6 +359,15 @@ def _post_update_kernel(
     )
     tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)

+    for i in range(num_sampled):
+        token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
+        token_ptr = (
+            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
+        )
+        count = tl.load(token_ptr)
+        count += 1
+        tl.store(token_ptr, count)
+
     query_start = tl.load(query_start_loc_ptr + req_id)
     query_end = tl.load(query_start_loc_ptr + req_id + 1)
     query_len = query_end - query_start
@@ -374,6 +385,8 @@ def post_update(
     num_computed_tokens: torch.Tensor,
     # [max_num_reqs]
     last_sampled_tokens: torch.Tensor,
+    # [max_num_reqs, vocab_size]
+    output_bin_counts: torch.Tensor,
    # [num_reqs, num_speculative_steps + 1]
    sampled_tokens: torch.Tensor,
    # [num_reqs]
@@ -388,6 +401,8 @@ def post_update(
         idx_mapping,
         num_computed_tokens,
         last_sampled_tokens,
+        output_bin_counts,
+        output_bin_counts.stride(0),
         sampled_tokens,
         sampled_tokens.stride(0),
         num_sampled,
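The loop added to _post_update_kernel keeps a per-request histogram of every token sampled so far. As a readability aid, here is a minimal PyTorch sketch of the same update for a single request (function and variable names are illustrative, not part of the commit):

import torch


def update_output_bin_counts(
    output_bin_counts: torch.Tensor,   # [max_num_reqs, vocab_size], torch.int32
    req_state_idx: int,
    sampled_token_ids: torch.Tensor,   # [num_sampled], torch.int64 token ids
) -> None:
    # Equivalent of the Triton loop above: increment the count of each
    # newly sampled token in this request's row of the histogram.
    row = output_bin_counts[req_state_idx]
    row.scatter_add_(
        0, sampled_token_ids, torch.ones_like(sampled_token_ids, dtype=row.dtype)
    )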
@@ -512,7 +512,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             idx_mapping_np,
             num_scheduled_tokens,
             query_start_loc_np,
-            self.req_states.prefill_token_ids,
+            self.req_states.prefill_token_ids.np,
             self.req_states.num_computed_prefill_tokens,
             self.input_buffers.input_ids.np,
         )
@@ -681,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Handle chunked prompts.
         pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
         is_prompt_chunked = pos_after_step < prompt_lens
-        prefill_token_ids = self.req_states.prefill_token_ids
+        prefill_token_ids = self.req_states.prefill_token_ids.np
         query_start_loc = self.input_buffers.query_start_loc.np
         for i, req_id in enumerate(input_batch.req_ids):
             if not needs_prompt_logprobs[i]:
@@ -756,6 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_batch.idx_mapping,
             self.req_states.num_computed_tokens,
             self.req_states.last_sampled_tokens,
+            self.req_states.output_bin_counts,
             sampled_tokens,
             num_sampled,
             num_rejected,
@@ -785,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         idx_mapping_np = input_batch.idx_mapping_np
         with async_barrier(self.spec_decode_event):
             self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
-                self.req_states.prefill_token_ids[
+                self.req_states.prefill_token_ids.np[
                     idx_mapping_np,
                     self.req_states.num_computed_prefill_tokens[idx_mapping_np],
                 ]
@@ -896,7 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # barrier to avoid race conditions.
         pos = input_batch.positions[input_batch.logits_indices]
         sampling_metadata = self.req_states.make_sampling_metadata(
-            input_batch.idx_mapping_np, pos
+            input_batch.idx_mapping, input_batch.idx_mapping_np, pos
         )
         if input_batch.num_draft_tokens > 0:
             sampling_metadata = self.req_states.expand_sampling_metadata(
vllm/v1/worker/gpu/penalties.py (new file)
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.states import SamplingMetadata
+
+
+@triton.jit
+def _penalties_kernel(
+    logits_ptr,
+    logits_stride,
+    repetition_penalty_ptr,
+    frequency_penalty_ptr,
+    presence_penalty_ptr,
+    idx_mapping_ptr,
+    prompt_bin_counts_ptr,
+    prompt_bin_counts_stride,
+    output_bin_counts_ptr,
+    output_bin_counts_stride,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
+    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
+    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
+
+    use_rep_penalty = rep_penalty != 1.0
+    use_freq_penalty = freq_penalty != 0.0
+    use_pres_penalty = pres_penalty != 0.0
+    if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
+        # No penalties to apply. Early return.
+        return
+
+    block_idx = tl.program_id(1)
+    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = block < vocab_size
+    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
+    logits = logits.to(tl.float32)
+
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    output_bin_counts = tl.load(
+        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
+        mask=mask,
+    )
+
+    # Apply repetition penalties.
+    if use_rep_penalty:
+        prompt_bin_counts = tl.load(
+            prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
+            mask=mask,
+        )
+        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+        scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
+        # If logits are positive, divide by penalty, otherwise multiply by penalty.
+        scale = tl.where(logits > 0, 1.0 / scale, scale)
+        logits *= scale
+
+    # Apply frequency penalties.
+    logits -= freq_penalty * output_bin_counts
+    # Apply presence penalties.
+    logits -= pres_penalty * (output_bin_counts > 0)
+    # Store back to logits.
+    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
+
+
+def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
+    num_reqs, vocab_size = logits.shape
+    BLOCK_SIZE = 8192
+    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+    _penalties_kernel[(num_reqs, num_blocks)](
+        logits,
+        logits.stride(0),
+        sampling_metadata.repetition_penalty,
+        sampling_metadata.frequency_penalty,
+        sampling_metadata.presence_penalty,
+        sampling_metadata.idx_mapping,
+        sampling_metadata.prompt_bin_counts,
+        sampling_metadata.prompt_bin_counts.stride(0),
+        sampling_metadata.output_bin_counts,
+        sampling_metadata.output_bin_counts.stride(0),
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
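For reference, the penalty arithmetic in _penalties_kernel translates to roughly the following PyTorch (a sketch, assuming the bin-count rows have already been gathered per request via idx_mapping; names are illustrative, not part of the commit):

import torch


def apply_penalties_reference(
    logits: torch.Tensor,               # [num_reqs, vocab_size], float
    repetition_penalty: torch.Tensor,   # [num_reqs]
    frequency_penalty: torch.Tensor,    # [num_reqs]
    presence_penalty: torch.Tensor,     # [num_reqs]
    prompt_bin_counts: torch.Tensor,    # [num_reqs, vocab_size], int
    output_bin_counts: torch.Tensor,    # [num_reqs, vocab_size], int
) -> torch.Tensor:
    logits = logits.float()
    seen = (prompt_bin_counts + output_bin_counts) > 0
    # Repetition penalty: only tokens seen in the prompt or output are scaled;
    # positive logits are divided by the penalty, negative ones multiplied.
    rep = repetition_penalty[:, None]
    scale = torch.where(seen, rep, torch.ones_like(rep))
    scale = torch.where(logits > 0, 1.0 / scale, scale)
    logits = logits * scale
    # Frequency penalty grows with how often a token was generated.
    logits = logits - frequency_penalty[:, None] * output_bin_counts
    # Presence penalty is a flat cost for any token generated at least once.
    logits = logits - presence_penalty[:, None] * (output_bin_counts > 0)
    return logits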
@@ -8,6 +8,7 @@ from vllm.config.model import LogprobsMode
 from vllm.triton_utils import tl, triton
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.penalties import apply_penalties
 from vllm.v1.worker.gpu.states import SamplingMetadata


@@ -65,6 +66,8 @@ class Sampler:
         logits = apply_top_k_top_p(
             logits, sampling_metadata.top_k, sampling_metadata.top_p
         )
+        # Apply penalties in place.
+        apply_penalties(logits, sampling_metadata)

         sampled = gumbel_sample(
             logits,
@@ -8,6 +8,8 @@ import torch
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.triton_utils import tl, triton
+from vllm.utils.platform_utils import is_uva_available
+from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.utils import CpuGpuBuffer

@@ -23,12 +25,21 @@ class SamplingMetadata:
     top_p: torch.Tensor | None
     top_k: torch.Tensor | None

+    repetition_penalty: torch.Tensor
+    frequency_penalty: torch.Tensor
+    presence_penalty: torch.Tensor
+
     seeds: torch.Tensor
     pos: torch.Tensor

     # None means no logprobs, 0 means sampled token logprobs only
     max_num_logprobs: int | None

+    # For penalties
+    idx_mapping: torch.Tensor
+    prompt_bin_counts: torch.Tensor
+    output_bin_counts: torch.Tensor
+
     @classmethod
     def make_dummy(
         cls,
@@ -44,17 +55,35 @@ class SamplingMetadata:
         # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
         top_p = None
         top_k = None
+        # NOTE(woosuk): We must set penalties to their default values to make sure
+        # the penalties kernel does not touch the placeholder bin_counts tensors.
+        repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
+        frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
+        presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
         seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
         pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
         max_num_logprobs = 20

+        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
+        # NOTE(woosuk): These are placeholder tensors to avoid None checks in the
+        # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
+        # specialization and re-compilation at runtime.
+        prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
+        output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
+
         return cls(
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             seeds=seeds,
             pos=pos,
             max_num_logprobs=max_num_logprobs,
+            idx_mapping=idx_mapping,
+            prompt_bin_counts=prompt_bin_counts,
+            output_bin_counts=output_bin_counts,
         )


@@ -83,9 +112,10 @@ class RequestState:
         self.extra_data: dict[str, ExtraData] = {}

         self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
-        self.prefill_token_ids = np.zeros(
-            (self.max_num_reqs, self.max_model_len),
-            dtype=np.int32,
+        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
+        # depending on the configured max_num_reqs and max_model_len.
+        self.prefill_token_ids = UvaBuffer(
+            self.max_num_reqs, self.max_model_len, dtype=torch.int32
         )
         self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)

@@ -119,6 +149,9 @@ class RequestState:
         self.temperature = self._make_param(self.max_num_reqs, torch.float32)
         self.top_p = self._make_param(self.max_num_reqs, torch.float32)
         self.top_k = self._make_param(self.max_num_reqs, torch.int32)
+        self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32)
+        self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32)
+        self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32)
         self.seeds = self._make_param(self.max_num_reqs, torch.int64)

         self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
@@ -126,6 +159,16 @@ class RequestState:
         self.num_logprobs.fill(-1)
         self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)

+        # Statistics for penalties.
+        # TODO(woosuk): These tensors are rarely used but can be extremely large.
+        # Optimize the memory usage.
+        self.prompt_bin_counts = torch.zeros(
+            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
+        )
+        self.output_bin_counts = torch.zeros(
+            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
+        )
+
     def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
         return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)

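To make the TODO above concrete, a rough sizing example (the numbers are hypothetical, not taken from the commit): each bin-count tensor is [max_num_reqs, vocab_size] in torch.int32, so

# Hypothetical sizing, not from the commit.
max_num_reqs = 1024
vocab_size = 131072                                # e.g. a ~128K-token vocabulary
bytes_per_tensor = max_num_reqs * vocab_size * 4   # 4 bytes per int32 element
print(bytes_per_tensor / 2**20)                    # 512.0 MiB each; prompt + output counts together ~1 GiB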
@@ -159,7 +202,7 @@ class RequestState:
                 f"prefill_len {prefill_len} < prompt_len {prompt_len}"
             )
         self.prefill_len.np[req_idx] = prefill_len
-        self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
+        self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids

         self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
         # FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
@@ -178,6 +221,18 @@ class RequestState:
         else:
             top_k = self.vocab_size
         self.top_k.np[req_idx] = top_k
+        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
+        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
+        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
+
+        if use_penalty(sampling_params):
+            bincount(
+                self.prefill_token_ids.gpu[req_idx],
+                prefill_len,
+                prompt_len,
+                self.prompt_bin_counts[req_idx],
+                self.output_bin_counts[req_idx],
+            )

         if sampling_params.seed is not None:
             seed = sampling_params.seed
@@ -206,24 +261,32 @@ class RequestState:

     def make_sampling_metadata(
         self,
-        idx_mapping: np.ndarray,
+        idx_mapping: torch.Tensor,
+        idx_mapping_np: np.ndarray,
         pos: torch.Tensor,
     ) -> SamplingMetadata:
-        temperature = self.temperature.np[idx_mapping]
+        temperature = self.temperature.np[idx_mapping_np]
         temperature = self.temperature.copy_np_to_gpu(temperature)

-        top_p = self.top_p.np[idx_mapping]
+        top_p = self.top_p.np[idx_mapping_np]
         no_top_p = np.all(top_p == 1.0)
         top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None

-        top_k = self.top_k.np[idx_mapping]
+        top_k = self.top_k.np[idx_mapping_np]
         no_top_k = np.all(top_k == self.vocab_size)
         top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None

-        seeds = self.seeds.np[idx_mapping]
+        rep_penalty = self.repetition_penalty.np[idx_mapping_np]
+        rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty)
+        freq_penalty = self.frequency_penalty.np[idx_mapping_np]
+        freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
+        pres_penalty = self.presence_penalty.np[idx_mapping_np]
+        pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
+
+        seeds = self.seeds.np[idx_mapping_np]
         seeds = self.seeds.copy_np_to_gpu(seeds)

-        num_logprobs = self.num_logprobs[idx_mapping]
+        num_logprobs = self.num_logprobs[idx_mapping_np]
         max_num_logprobs: int | None = int(np.max(num_logprobs))
         if max_num_logprobs == -1:
             max_num_logprobs = None
@@ -232,9 +295,15 @@ class RequestState:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            repetition_penalty=rep_penalty,
+            frequency_penalty=freq_penalty,
+            presence_penalty=pres_penalty,
             seeds=seeds,
             pos=pos,
             max_num_logprobs=max_num_logprobs,
+            idx_mapping=idx_mapping,
+            prompt_bin_counts=self.prompt_bin_counts,
+            output_bin_counts=self.output_bin_counts,
         )

     def expand_sampling_metadata(
@@ -294,6 +363,14 @@ class ExtraData:
     in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)


+class UvaBuffer:
+    def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
+        assert is_uva_available()
+        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
+        self.np = self.cpu.numpy()
+        self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
+
+
 # NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
 @triton.jit
 def _expand_sampling_metadata_kernel(
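A minimal usage sketch for the UvaBuffer class added above (assuming a CUDA device where is_uva_available() is true; the example values are illustrative). Its .cpu, .np, and .gpu attributes all alias the same pinned host allocation, so host-side writes become visible to GPU reads without an explicit copy, subject to the usual stream-synchronization rules:

import torch

buf = UvaBuffer(4, 8, dtype=torch.int32)   # e.g. 4 requests x 8 token slots
buf.np[0, :3] = [7, 11, 13]                # fill row 0 on the CPU via the numpy view
tokens = buf.gpu[0, :3]                    # zero-copy CUDA view over the same memory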
@@ -304,6 +381,12 @@ def _expand_sampling_metadata_kernel(
    top_k_ptr,
    expanded_top_k_ptr,
    seeds_ptr,
+    rep_penalty_ptr,
+    expanded_rep_penalty_ptr,
+    freq_penalty_ptr,
+    expanded_freq_penalty_ptr,
+    pres_penalty_ptr,
+    expanded_pres_penalty_ptr,
    expanded_seeds_ptr,
    cu_num_logits_ptr,
    BLOCK_SIZE: tl.constexpr,
@@ -327,6 +410,15 @@ def _expand_sampling_metadata_kernel(
     top_k = tl.load(top_k_ptr + req_idx)
     tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)

+    rep_penalty = tl.load(rep_penalty_ptr + req_idx)
+    tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
+
+    freq_penalty = tl.load(freq_penalty_ptr + req_idx)
+    tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
+
+    pres_penalty = tl.load(pres_penalty_ptr + req_idx)
+    tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
+
     seed = tl.load(seeds_ptr + req_idx)
     tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)

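Conceptually, these additions mirror the existing top_k/seed handling: each per-request penalty value is broadcast to every logits row that request owns during speculative decoding. A rough PyTorch equivalent for a single parameter (names are illustrative, not part of the commit):

import torch


def expand_per_request(values: torch.Tensor, cu_num_logits: torch.Tensor) -> torch.Tensor:
    # values: [num_reqs]; cu_num_logits: [num_reqs + 1] cumulative logits-row counts.
    num_logits_per_req = cu_num_logits[1:] - cu_num_logits[:-1]
    # Repeat each request's value once per logits row it owns.
    return torch.repeat_interleave(values, num_logits_per_req)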
@@ -341,6 +433,9 @@ def expand_sampling_metadata(
    expanded_temp = create_empty(sampling_metadata.temperature)
    expanded_top_p = create_empty(sampling_metadata.top_p)
    expanded_top_k = create_empty(sampling_metadata.top_k)
+    expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
+    expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
+    expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
    expanded_seeds = create_empty(sampling_metadata.seeds)

    num_reqs = cu_num_logits.shape[0] - 1
@@ -351,6 +446,12 @@ def expand_sampling_metadata(
        expanded_top_p,
        sampling_metadata.top_k,
        expanded_top_k,
+        sampling_metadata.repetition_penalty,
+        expanded_repetition_penalty,
+        sampling_metadata.frequency_penalty,
+        expanded_frequency_penalty,
+        sampling_metadata.presence_penalty,
+        expanded_presence_penalty,
        sampling_metadata.seeds,
        expanded_seeds,
        cu_num_logits,
@@ -361,6 +462,67 @@ def expand_sampling_metadata(
        top_p=expanded_top_p,
        top_k=expanded_top_k,
        seeds=expanded_seeds,
+        repetition_penalty=expanded_repetition_penalty,
+        frequency_penalty=expanded_frequency_penalty,
+        presence_penalty=expanded_presence_penalty,
        pos=sampling_metadata.pos,
        max_num_logprobs=sampling_metadata.max_num_logprobs,
+        # TODO(woosuk): Support penalties with spec decoding.
+        idx_mapping=sampling_metadata.idx_mapping,
+        prompt_bin_counts=sampling_metadata.prompt_bin_counts,
+        output_bin_counts=sampling_metadata.output_bin_counts,
    )
+
+
+def use_penalty(sampling_params: SamplingParams) -> bool:
+    return (
+        sampling_params.repetition_penalty != 1.0
+        or sampling_params.frequency_penalty != 0.0
+        or sampling_params.presence_penalty != 0.0
+    )
+
+
+@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
+def _bincount_kernel(
+    prefill_token_ids_ptr,
+    prefill_len,
+    prompt_len,
+    prompt_bin_counts_ptr,
+    output_bin_counts_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    block_idx = tl.program_id(0)
+    if block_idx * BLOCK_SIZE >= prefill_len:
+        return
+
+    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    if block_idx * BLOCK_SIZE < prompt_len:
+        mask = block < prompt_len
+        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
+        tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
+    if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
+        mask = block < prefill_len
+        mask &= block >= prompt_len
+        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
+        tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
+
+
+def bincount(
+    prefill_token_ids: torch.Tensor,
+    prefill_len: int,
+    prompt_len: int,
+    prompt_bin_counts: torch.Tensor,
+    output_bin_counts: torch.Tensor,
+) -> None:
+    prompt_bin_counts.zero_()
+    output_bin_counts.zero_()
+    BLOCK_SIZE = 1024
+    num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
+    _bincount_kernel[(num_blocks,)](
+        prefill_token_ids,
+        prefill_len,
+        prompt_len,
+        prompt_bin_counts,
+        output_bin_counts,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
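The bincount helper splits the prefill tokens at prompt_len: tokens before the boundary go into the prompt histogram, the remainder (e.g. output tokens carried over when a preempted request is resumed) into the output histogram. A hedged PyTorch equivalent, assuming the token row has already been trimmed to prefill_len (names are illustrative, not part of the commit):

import torch


def bincount_reference(
    prefill_token_ids: torch.Tensor,   # [prefill_len], int64 token ids
    prompt_len: int,
    vocab_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Count prompt tokens and previously generated tokens separately.
    prompt_counts = torch.bincount(prefill_token_ids[:prompt_len], minlength=vocab_size)
    output_counts = torch.bincount(prefill_token_ids[prompt_len:], minlength=vocab_size)
    return prompt_counts, output_counts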