mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 16:14:37 +08:00
[Model Runner V2] Add sample/ directory and reorganize files (#29719)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
39e63dec7c
commit
6afc0ffaf6
@ -47,13 +47,18 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.sample.metadata import (
|
||||
SamplingMetadata,
|
||||
expand_sampling_metadata,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
|
||||
get_num_rejected,
|
||||
rejection_sample,
|
||||
)
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
@ -890,8 +895,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
|
||||
)
|
||||
if input_batch.num_draft_tokens > 0:
|
||||
sampling_metadata = self.req_states.expand_sampling_metadata(
|
||||
sampling_metadata, input_batch.cu_num_logits
|
||||
sampling_metadata = expand_sampling_metadata(
|
||||
sampling_metadata,
|
||||
input_batch.cu_num_logits,
|
||||
max_expand_len=self.num_speculative_steps + 1,
|
||||
)
|
||||
|
||||
if self.lora_config:
|
||||
|
||||
0
vllm/v1/worker/gpu/sample/__init__.py
Normal file
0
vllm/v1/worker/gpu/sample/__init__.py
Normal file
100
vllm/v1/worker/gpu/sample/gumbel.py
Normal file
100
vllm/v1/worker/gpu/sample/gumbel.py
Normal file
@ -0,0 +1,100 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gumbel_sample_kernel(
|
||||
local_argmax_ptr,
|
||||
local_argmax_stride,
|
||||
local_max_ptr,
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
pos = tl.load(pos_ptr + req_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise.
|
||||
r = tl.rand(gumbel_seed, block).to(tl.float64)
|
||||
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
||||
gumbel_noise = gumbel_noise.to(tl.float32)
|
||||
|
||||
# Apply temperature.
|
||||
if APPLY_TEMPERATURE:
|
||||
# NOTE(woosuk): Use div_rn to match the behavior of torch.
|
||||
logits = tl.div_rn(logits, temp)
|
||||
|
||||
# Apply gumbel noise.
|
||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||
|
||||
idx = tl.argmax(logits, axis=0)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
value = tl.max(logits, axis=0)
|
||||
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
temperature: torch.Tensor, # [num_reqs]
|
||||
seed: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
apply_temperature: bool,
|
||||
) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
local_max,
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
APPLY_TEMPERATURE=apply_temperature,
|
||||
)
|
||||
# NOTE(woosuk): Use int64 for later indexing.
|
||||
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
||||
return sampled
|
||||
167
vllm/v1/worker/gpu/sample/logprob.py
Normal file
167
vllm/v1/worker/gpu/sample/logprob.py
Normal file
@ -0,0 +1,167 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _topk_log_softmax_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
topk_ids_ptr,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PADDED_TOPK: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
|
||||
se = 0.0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
|
||||
logits = logits.to(tl.float32)
|
||||
e = tl.exp(logits - max_val)
|
||||
e = tl.where(block < vocab_size, e, 0.0)
|
||||
se += tl.sum(e)
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
k_mask = k_offset < topk
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
|
||||
|
||||
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
logits = logits.to(tl.float32)
|
||||
o = logits - max_val - lse
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _ranks_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
token_ids_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
token_id = tl.load(token_ids_ptr + req_idx)
|
||||
x = tl.load(row_ptr + token_id)
|
||||
|
||||
n = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
n += tl.sum((logits > x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
def compute_token_logprobs(
|
||||
logits: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size = logits.shape[0]
|
||||
vocab_size = logits.shape[1]
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
num_logprobs = token_ids.shape[1]
|
||||
logprobs = torch.empty(
|
||||
batch_size,
|
||||
num_logprobs,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_topk_log_softmax_kernel[(batch_size,)](
|
||||
logprobs,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
token_ids,
|
||||
num_logprobs,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
def compute_topk_logprobs(
|
||||
logits: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
assert num_logprobs >= 0
|
||||
batch_size, vocab_size = logits.shape
|
||||
if num_logprobs == 0:
|
||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||
else:
|
||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||
logprob_token_ids = torch.cat(
|
||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||
token_ranks = torch.empty(
|
||||
batch_size,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
_ranks_kernel[(batch_size,)](
|
||||
token_ranks,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampled_token_ids,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=8192, # type: ignore
|
||||
)
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=token_ranks,
|
||||
)
|
||||
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
prompt_token_ids: torch.Tensor,
|
||||
prompt_hidden_states: torch.Tensor,
|
||||
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Since materializing the full prompt logits can take too much memory,
|
||||
# we compute it in chunks.
|
||||
CHUNK_SIZE = 1024
|
||||
logprobs = []
|
||||
ranks = []
|
||||
prompt_token_ids = prompt_token_ids.to(torch.int64)
|
||||
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
|
||||
end_idx = start_idx + CHUNK_SIZE
|
||||
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
|
||||
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
|
||||
prompt_logprobs = compute_topk_logprobs(
|
||||
prompt_logits,
|
||||
0, # num_logprobs
|
||||
prompt_token_ids[start_idx:end_idx],
|
||||
)
|
||||
logprobs.append(prompt_logprobs.logprobs)
|
||||
ranks.append(prompt_logprobs.selected_token_ranks)
|
||||
|
||||
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
|
||||
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
|
||||
return logprobs, ranks
|
||||
179
vllm/v1/worker/gpu/sample/metadata.py
Normal file
179
vllm/v1/worker/gpu/sample/metadata.py
Normal file
@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
temperature: torch.Tensor
|
||||
|
||||
top_p: torch.Tensor | None
|
||||
top_k: torch.Tensor | None
|
||||
|
||||
repetition_penalty: torch.Tensor
|
||||
frequency_penalty: torch.Tensor
|
||||
presence_penalty: torch.Tensor
|
||||
|
||||
seeds: torch.Tensor
|
||||
pos: torch.Tensor
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: int | None
|
||||
|
||||
# For penalties
|
||||
idx_mapping: torch.Tensor
|
||||
prompt_bin_counts: torch.Tensor
|
||||
output_bin_counts: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
device: torch.device,
|
||||
) -> "SamplingMetadata":
|
||||
assert num_reqs > 0
|
||||
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
temperature[0] = 0.5
|
||||
# TODO(woosuk): Use top-p and top-k for dummy sampler.
|
||||
# Currently, they are disabled because of memory usage.
|
||||
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
|
||||
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
|
||||
top_p = None
|
||||
top_k = None
|
||||
# NOTE(woosuk): We must set penalties to their default values to make sure
|
||||
# the penalties kernel does not touch the placeholder bin_counts tensors.
|
||||
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
||||
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
max_num_logprobs = 20
|
||||
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
|
||||
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
|
||||
# specialization and re-compilation at runtime.
|
||||
prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||
|
||||
return cls(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
repetition_penalty=repetition_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
seeds=seeds,
|
||||
pos=pos,
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
idx_mapping=idx_mapping,
|
||||
prompt_bin_counts=prompt_bin_counts,
|
||||
output_bin_counts=output_bin_counts,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
|
||||
@triton.jit
|
||||
def _expand_sampling_metadata_kernel(
|
||||
temp_ptr,
|
||||
expanded_temp_ptr,
|
||||
top_p_ptr,
|
||||
expanded_top_p_ptr,
|
||||
top_k_ptr,
|
||||
expanded_top_k_ptr,
|
||||
rep_penalty_ptr,
|
||||
expanded_rep_penalty_ptr,
|
||||
freq_penalty_ptr,
|
||||
expanded_freq_penalty_ptr,
|
||||
pres_penalty_ptr,
|
||||
expanded_pres_penalty_ptr,
|
||||
seeds_ptr,
|
||||
expanded_seeds_ptr,
|
||||
cu_num_logits_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_tokens
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
|
||||
|
||||
if top_p_ptr is not None:
|
||||
top_p = tl.load(top_p_ptr + req_idx)
|
||||
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
|
||||
|
||||
if top_k_ptr is not None:
|
||||
top_k = tl.load(top_k_ptr + req_idx)
|
||||
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
|
||||
|
||||
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
|
||||
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
|
||||
|
||||
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
|
||||
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
|
||||
|
||||
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
|
||||
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
|
||||
|
||||
|
||||
def expand_sampling_metadata(
|
||||
sampling_metadata: SamplingMetadata,
|
||||
cu_num_logits: torch.Tensor,
|
||||
max_expand_len: int,
|
||||
) -> SamplingMetadata:
|
||||
total_num_logits = sampling_metadata.pos.shape[0]
|
||||
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
|
||||
expanded_temp = create_empty(sampling_metadata.temperature)
|
||||
expanded_top_p = create_empty(sampling_metadata.top_p)
|
||||
expanded_top_k = create_empty(sampling_metadata.top_k)
|
||||
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
|
||||
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
|
||||
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
|
||||
expanded_seeds = create_empty(sampling_metadata.seeds)
|
||||
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
_expand_sampling_metadata_kernel[(num_reqs,)](
|
||||
sampling_metadata.temperature,
|
||||
expanded_temp,
|
||||
sampling_metadata.top_p,
|
||||
expanded_top_p,
|
||||
sampling_metadata.top_k,
|
||||
expanded_top_k,
|
||||
sampling_metadata.repetition_penalty,
|
||||
expanded_repetition_penalty,
|
||||
sampling_metadata.frequency_penalty,
|
||||
expanded_frequency_penalty,
|
||||
sampling_metadata.presence_penalty,
|
||||
expanded_presence_penalty,
|
||||
sampling_metadata.seeds,
|
||||
expanded_seeds,
|
||||
cu_num_logits,
|
||||
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
|
||||
)
|
||||
return SamplingMetadata(
|
||||
temperature=expanded_temp,
|
||||
top_p=expanded_top_p,
|
||||
top_k=expanded_top_k,
|
||||
seeds=expanded_seeds,
|
||||
repetition_penalty=expanded_repetition_penalty,
|
||||
frequency_penalty=expanded_frequency_penalty,
|
||||
presence_penalty=expanded_presence_penalty,
|
||||
pos=sampling_metadata.pos,
|
||||
max_num_logprobs=sampling_metadata.max_num_logprobs,
|
||||
# TODO(woosuk): Support penalties with spec decoding.
|
||||
idx_mapping=sampling_metadata.idx_mapping,
|
||||
prompt_bin_counts=sampling_metadata.prompt_bin_counts,
|
||||
output_bin_counts=sampling_metadata.output_bin_counts,
|
||||
)
|
||||
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -83,3 +83,49 @@ def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
|
||||
def _bincount_kernel(
|
||||
prefill_token_ids_ptr,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
prompt_bin_counts_ptr,
|
||||
output_bin_counts_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
return
|
||||
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
if block_idx * BLOCK_SIZE < prompt_len:
|
||||
mask = block < prompt_len
|
||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
||||
mask = block < prefill_len
|
||||
mask &= block >= prompt_len
|
||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||
|
||||
|
||||
def bincount(
|
||||
prefill_token_ids: torch.Tensor,
|
||||
prefill_len: int,
|
||||
prompt_len: int,
|
||||
prompt_bin_counts: torch.Tensor,
|
||||
output_bin_counts: torch.Tensor,
|
||||
) -> None:
|
||||
prompt_bin_counts.zero_()
|
||||
output_bin_counts.zero_()
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_blocks,)](
|
||||
prefill_token_ids,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
prompt_bin_counts,
|
||||
output_bin_counts,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
79
vllm/v1/worker/gpu/sample/sampler.py
Normal file
79
vllm/v1/worker/gpu/sample/sampler.py
Normal file
@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.sample.penalties import apply_penalties
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
):
|
||||
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
sampled, logits = self.sample(
|
||||
logits, sampling_metadata, return_logits=True
|
||||
)
|
||||
else:
|
||||
assert self.logprobs_mode == "raw_logprobs"
|
||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
||||
|
||||
logprobs_tensors = compute_topk_logprobs(
|
||||
logits,
|
||||
sampling_metadata.max_num_logprobs,
|
||||
sampled,
|
||||
)
|
||||
else:
|
||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
||||
logprobs_tensors = None
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.view(-1, 1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
return_logits: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
is_greedy = sampling_metadata.temperature == 0
|
||||
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||
logits = logits / temp.view(-1, 1)
|
||||
logits = apply_top_k_top_p(
|
||||
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
||||
)
|
||||
# Apply penalties in place.
|
||||
apply_penalties(logits, sampling_metadata)
|
||||
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
apply_temperature=False,
|
||||
)
|
||||
return sampled, logits if return_logits else None
|
||||
@ -1,333 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.worker.gpu.penalties import apply_penalties
|
||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(
|
||||
self,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
):
|
||||
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
if sampling_metadata.max_num_logprobs is not None:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
sampled, logits = self.sample(
|
||||
logits, sampling_metadata, return_logits=True
|
||||
)
|
||||
else:
|
||||
assert self.logprobs_mode == "raw_logprobs"
|
||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
||||
|
||||
logprobs_tensors = compute_topk_logprobs(
|
||||
logits,
|
||||
sampling_metadata.max_num_logprobs,
|
||||
sampled,
|
||||
)
|
||||
else:
|
||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
||||
logprobs_tensors = None
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.view(-1, 1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
return_logits: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
is_greedy = sampling_metadata.temperature == 0
|
||||
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||
logits = logits / temp.view(-1, 1)
|
||||
logits = apply_top_k_top_p(
|
||||
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
||||
)
|
||||
# Apply penalties in place.
|
||||
apply_penalties(logits, sampling_metadata)
|
||||
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
apply_temperature=False,
|
||||
)
|
||||
return sampled, logits if return_logits else None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gumbel_sample_kernel(
|
||||
local_argmax_ptr,
|
||||
local_argmax_stride,
|
||||
local_max_ptr,
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
pos = tl.load(pos_ptr + req_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise.
|
||||
r = tl.rand(gumbel_seed, block).to(tl.float64)
|
||||
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
||||
gumbel_noise = gumbel_noise.to(tl.float32)
|
||||
|
||||
# Apply temperature.
|
||||
if APPLY_TEMPERATURE:
|
||||
# NOTE(woosuk): Use div_rn to match the behavior of torch.
|
||||
logits = tl.div_rn(logits, temp)
|
||||
|
||||
# Apply gumbel noise.
|
||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||
|
||||
idx = tl.argmax(logits, axis=0)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
value = tl.max(logits, axis=0)
|
||||
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
|
||||
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
temperature: torch.Tensor, # [num_reqs]
|
||||
seed: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
apply_temperature: bool,
|
||||
) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
local_max,
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
APPLY_TEMPERATURE=apply_temperature,
|
||||
)
|
||||
# NOTE(woosuk): Use int64 for later indexing.
|
||||
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
||||
return sampled
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _topk_log_softmax_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
topk_ids_ptr,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PADDED_TOPK: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
max_val = float("-inf")
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
max_val = tl.max(tl.maximum(logits, max_val))
|
||||
max_val = max_val.to(tl.float32) # type: ignore
|
||||
|
||||
se = 0.0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
|
||||
logits = logits.to(tl.float32)
|
||||
e = tl.exp(logits - max_val)
|
||||
e = tl.where(block < vocab_size, e, 0.0)
|
||||
se += tl.sum(e)
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
k_mask = k_offset < topk
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
|
||||
|
||||
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
logits = logits.to(tl.float32)
|
||||
o = logits - max_val - lse
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _ranks_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
token_ids_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
token_id = tl.load(token_ids_ptr + req_idx)
|
||||
x = tl.load(row_ptr + token_id)
|
||||
|
||||
n = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||
n += tl.sum((logits > x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
def compute_token_logprobs(
|
||||
logits: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size = logits.shape[0]
|
||||
vocab_size = logits.shape[1]
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
num_logprobs = token_ids.shape[1]
|
||||
logprobs = torch.empty(
|
||||
batch_size,
|
||||
num_logprobs,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
_topk_log_softmax_kernel[(batch_size,)](
|
||||
logprobs,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
token_ids,
|
||||
num_logprobs,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
def compute_topk_logprobs(
|
||||
logits: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
assert num_logprobs >= 0
|
||||
batch_size, vocab_size = logits.shape
|
||||
if num_logprobs == 0:
|
||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||
else:
|
||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||
logprob_token_ids = torch.cat(
|
||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||
token_ranks = torch.empty(
|
||||
batch_size,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
_ranks_kernel[(batch_size,)](
|
||||
token_ranks,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampled_token_ids,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=8192, # type: ignore
|
||||
)
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=token_ranks,
|
||||
)
|
||||
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
prompt_token_ids: torch.Tensor,
|
||||
prompt_hidden_states: torch.Tensor,
|
||||
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Since materializing the full prompt logits can take too much memory,
|
||||
# we compute it in chunks.
|
||||
CHUNK_SIZE = 1024
|
||||
logprobs = []
|
||||
ranks = []
|
||||
prompt_token_ids = prompt_token_ids.to(torch.int64)
|
||||
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
|
||||
end_idx = start_idx + CHUNK_SIZE
|
||||
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
|
||||
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
|
||||
prompt_logprobs = compute_topk_logprobs(
|
||||
prompt_logits,
|
||||
0, # num_logprobs
|
||||
prompt_token_ids[start_idx:end_idx],
|
||||
)
|
||||
logprobs.append(prompt_logprobs.logprobs)
|
||||
ranks.append(prompt_logprobs.selected_token_ranks)
|
||||
|
||||
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
|
||||
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
|
||||
return logprobs, ranks
|
||||
@ -18,9 +18,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.sampler import gumbel_sample
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -7,86 +7,18 @@ import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.platform_utils import is_uva_available
|
||||
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.sample.penalties import bincount
|
||||
|
||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||
NO_LORA_ID = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
temperature: torch.Tensor
|
||||
|
||||
top_p: torch.Tensor | None
|
||||
top_k: torch.Tensor | None
|
||||
|
||||
repetition_penalty: torch.Tensor
|
||||
frequency_penalty: torch.Tensor
|
||||
presence_penalty: torch.Tensor
|
||||
|
||||
seeds: torch.Tensor
|
||||
pos: torch.Tensor
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: int | None
|
||||
|
||||
# For penalties
|
||||
idx_mapping: torch.Tensor
|
||||
prompt_bin_counts: torch.Tensor
|
||||
output_bin_counts: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
num_reqs: int,
|
||||
device: torch.device,
|
||||
) -> "SamplingMetadata":
|
||||
assert num_reqs > 0
|
||||
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
temperature[0] = 0.5
|
||||
# TODO(woosuk): Use top-p and top-k for dummy sampler.
|
||||
# Currently, they are disabled because of memory usage.
|
||||
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
|
||||
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
|
||||
top_p = None
|
||||
top_k = None
|
||||
# NOTE(woosuk): We must set penalties to their default values to make sure
|
||||
# the penalties kernel does not touch the placeholder bin_counts tensors.
|
||||
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
||||
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||
max_num_logprobs = 20
|
||||
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
|
||||
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
|
||||
# specialization and re-compilation at runtime.
|
||||
prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||
|
||||
return cls(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
repetition_penalty=repetition_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
seeds=seeds,
|
||||
pos=pos,
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
idx_mapping=idx_mapping,
|
||||
prompt_bin_counts=prompt_bin_counts,
|
||||
output_bin_counts=output_bin_counts,
|
||||
)
|
||||
|
||||
|
||||
class RequestState:
|
||||
def __init__(
|
||||
self,
|
||||
@ -311,17 +243,6 @@ class RequestState:
|
||||
output_bin_counts=self.output_bin_counts,
|
||||
)
|
||||
|
||||
def expand_sampling_metadata(
|
||||
self,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
cu_num_logits: torch.Tensor,
|
||||
) -> SamplingMetadata:
|
||||
# For draft tokens, we need to expand the sampling param tensors as
|
||||
# each request samples multiple tokens in each step.
|
||||
return expand_sampling_metadata(
|
||||
sampling_metadata, cu_num_logits, self.num_speculative_steps
|
||||
)
|
||||
|
||||
def make_lora_inputs(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
@ -376,158 +297,9 @@ class UvaBuffer:
|
||||
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
|
||||
|
||||
|
||||
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
|
||||
@triton.jit
|
||||
def _expand_sampling_metadata_kernel(
|
||||
temp_ptr,
|
||||
expanded_temp_ptr,
|
||||
top_p_ptr,
|
||||
expanded_top_p_ptr,
|
||||
top_k_ptr,
|
||||
expanded_top_k_ptr,
|
||||
rep_penalty_ptr,
|
||||
expanded_rep_penalty_ptr,
|
||||
freq_penalty_ptr,
|
||||
expanded_freq_penalty_ptr,
|
||||
pres_penalty_ptr,
|
||||
expanded_pres_penalty_ptr,
|
||||
seeds_ptr,
|
||||
expanded_seeds_ptr,
|
||||
cu_num_logits_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_tokens
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
|
||||
|
||||
if top_p_ptr is not None:
|
||||
top_p = tl.load(top_p_ptr + req_idx)
|
||||
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
|
||||
|
||||
if top_k_ptr is not None:
|
||||
top_k = tl.load(top_k_ptr + req_idx)
|
||||
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
|
||||
|
||||
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
|
||||
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
|
||||
|
||||
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
|
||||
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
|
||||
|
||||
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
|
||||
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
|
||||
|
||||
|
||||
def expand_sampling_metadata(
|
||||
sampling_metadata: SamplingMetadata,
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
) -> SamplingMetadata:
|
||||
total_num_logits = sampling_metadata.pos.shape[0]
|
||||
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
|
||||
expanded_temp = create_empty(sampling_metadata.temperature)
|
||||
expanded_top_p = create_empty(sampling_metadata.top_p)
|
||||
expanded_top_k = create_empty(sampling_metadata.top_k)
|
||||
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
|
||||
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
|
||||
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
|
||||
expanded_seeds = create_empty(sampling_metadata.seeds)
|
||||
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
_expand_sampling_metadata_kernel[(num_reqs,)](
|
||||
sampling_metadata.temperature,
|
||||
expanded_temp,
|
||||
sampling_metadata.top_p,
|
||||
expanded_top_p,
|
||||
sampling_metadata.top_k,
|
||||
expanded_top_k,
|
||||
sampling_metadata.repetition_penalty,
|
||||
expanded_repetition_penalty,
|
||||
sampling_metadata.frequency_penalty,
|
||||
expanded_frequency_penalty,
|
||||
sampling_metadata.presence_penalty,
|
||||
expanded_presence_penalty,
|
||||
sampling_metadata.seeds,
|
||||
expanded_seeds,
|
||||
cu_num_logits,
|
||||
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
|
||||
)
|
||||
return SamplingMetadata(
|
||||
temperature=expanded_temp,
|
||||
top_p=expanded_top_p,
|
||||
top_k=expanded_top_k,
|
||||
seeds=expanded_seeds,
|
||||
repetition_penalty=expanded_repetition_penalty,
|
||||
frequency_penalty=expanded_frequency_penalty,
|
||||
presence_penalty=expanded_presence_penalty,
|
||||
pos=sampling_metadata.pos,
|
||||
max_num_logprobs=sampling_metadata.max_num_logprobs,
|
||||
# TODO(woosuk): Support penalties with spec decoding.
|
||||
idx_mapping=sampling_metadata.idx_mapping,
|
||||
prompt_bin_counts=sampling_metadata.prompt_bin_counts,
|
||||
output_bin_counts=sampling_metadata.output_bin_counts,
|
||||
)
|
||||
|
||||
|
||||
def use_penalty(sampling_params: SamplingParams) -> bool:
|
||||
return (
|
||||
sampling_params.repetition_penalty != 1.0
|
||||
or sampling_params.frequency_penalty != 0.0
|
||||
or sampling_params.presence_penalty != 0.0
|
||||
)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
|
||||
def _bincount_kernel(
|
||||
prefill_token_ids_ptr,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
prompt_bin_counts_ptr,
|
||||
output_bin_counts_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||
return
|
||||
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
if block_idx * BLOCK_SIZE < prompt_len:
|
||||
mask = block < prompt_len
|
||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
||||
mask = block < prefill_len
|
||||
mask &= block >= prompt_len
|
||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||
|
||||
|
||||
def bincount(
|
||||
prefill_token_ids: torch.Tensor,
|
||||
prefill_len: int,
|
||||
prompt_len: int,
|
||||
prompt_bin_counts: torch.Tensor,
|
||||
output_bin_counts: torch.Tensor,
|
||||
) -> None:
|
||||
prompt_bin_counts.zero_()
|
||||
output_bin_counts.zero_()
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
|
||||
_bincount_kernel[(num_blocks,)](
|
||||
prefill_token_ids,
|
||||
prefill_len,
|
||||
prompt_len,
|
||||
prompt_bin_counts,
|
||||
output_bin_counts,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user