mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 13:47:06 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
bc6463ac97
commit
aabfaa08cf
@ -320,27 +320,15 @@ class GPUModelRunner:
|
||||
logits: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
) -> SamplerOutput:
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
input_batch.idx_mapping_np)
|
||||
input_batch.idx_mapping_np, pos)
|
||||
sampler_output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_batch: InputBatch,
|
||||
):
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[
|
||||
idx_mapping_np]
|
||||
if not np.any(needs_prompt_logprobs):
|
||||
# Common case.
|
||||
# No request in the batch needs prompt logprobs.
|
||||
return None
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
sampler_output: SamplerOutput,
|
||||
@ -387,9 +375,6 @@ class GPUModelRunner:
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
sampler_output = self.sample(logits, input_batch)
|
||||
prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
|
||||
input_batch)
|
||||
|
||||
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
||||
sampler_output, input_batch)
|
||||
req_id_to_index = {
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
@ -47,8 +44,8 @@ class Sampler(nn.Module):
|
||||
sampled = gumbel_sample(
|
||||
probs,
|
||||
sampling_metadata.temperature,
|
||||
None, # seeds
|
||||
None, # pos
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
)
|
||||
|
||||
logprobs_tensors = None
|
||||
@ -163,22 +160,27 @@ def _apply_gumbel_kernel(
|
||||
probs_ptr,
|
||||
probs_stride,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
EPSILON: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
|
||||
if temp < EPSILON:
|
||||
# Greedy sampling. Don't apply gumbel noise.
|
||||
return
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
pos = tl.load(pos_ptr + req_idx)
|
||||
gumbel_seed = seed ^ (pos * 0x9E3779B9)
|
||||
gumbel_seed = gumbel_seed & 0xFFFFFFFF
|
||||
|
||||
block_id = tl.program_id(1)
|
||||
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
q = tl.rand(seed, r_offset)
|
||||
q = tl.rand(gumbel_seed, r_offset)
|
||||
|
||||
# NOTE(woosuk): This logic makes sure q is not 0.
|
||||
RMAX = 0.9999999403953552
|
||||
@ -201,34 +203,20 @@ def gumbel_sample(
|
||||
# fp32[num_reqs]
|
||||
temperature: torch.Tensor,
|
||||
# int64[num_reqs]
|
||||
seeds: Optional[torch.Tensor],
|
||||
seeds: torch.Tensor,
|
||||
# int64[num_reqs]
|
||||
pos: Optional[torch.Tensor],
|
||||
pos: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = probs.shape[0]
|
||||
vocab_size = probs.shape[1]
|
||||
if seeds is not None:
|
||||
# Per-request seed.
|
||||
assert pos is not None
|
||||
gumbel_seeds = seeds ^ (pos * 0x9E3779B9)
|
||||
else:
|
||||
# Global seed.
|
||||
assert pos is None
|
||||
seed_dtype = torch.int64
|
||||
gumbel_seeds = torch.randint(
|
||||
torch.iinfo(seed_dtype).min,
|
||||
torch.iinfo(seed_dtype).max,
|
||||
(num_reqs, ),
|
||||
dtype=seed_dtype,
|
||||
device=probs.device,
|
||||
)
|
||||
|
||||
# Update the probs in-place.
|
||||
BLOCK_SIZE = 8192
|
||||
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
||||
probs,
|
||||
probs.stride(0),
|
||||
gumbel_seeds,
|
||||
seeds,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE,
|
||||
|
||||
@ -10,6 +10,9 @@ import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
@ -19,6 +22,9 @@ class SamplingMetadata:
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
|
||||
seeds: torch.Tensor
|
||||
pos: torch.Tensor
|
||||
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: Optional[int]
|
||||
|
||||
@ -69,6 +75,7 @@ class RequestState:
|
||||
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
|
||||
# -1 means no logprobs are requested.
|
||||
self.num_logprobs.fill(-1)
|
||||
self.seeds = np.zeros(self.max_num_reqs, dtype=np.int64)
|
||||
|
||||
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
||||
|
||||
@ -102,6 +109,12 @@ class RequestState:
|
||||
top_k = self.vocab_size
|
||||
self.top_k[req_idx] = top_k
|
||||
|
||||
if sampling_params.seed is not None:
|
||||
seed = sampling_params.seed
|
||||
else:
|
||||
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||
self.seeds[req_idx] = seed
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
num_logprobs = sampling_params.logprobs
|
||||
else:
|
||||
@ -133,6 +146,7 @@ class RequestState:
|
||||
def make_sampling_metadata(
|
||||
self,
|
||||
idx_mapping: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
) -> SamplingMetadata:
|
||||
temperature = self.temperature[idx_mapping]
|
||||
temperature = self._copy_np_to_gpu(temperature)
|
||||
@ -145,6 +159,9 @@ class RequestState:
|
||||
no_top_k = np.all(top_k == self.vocab_size)
|
||||
top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
|
||||
|
||||
seeds = self.seeds[idx_mapping]
|
||||
seeds = self._copy_np_to_gpu(seeds)
|
||||
|
||||
num_logprobs = self.num_logprobs[idx_mapping]
|
||||
max_num_logprobs = np.max(num_logprobs)
|
||||
if max_num_logprobs == -1:
|
||||
@ -154,6 +171,8 @@ class RequestState:
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
seeds=seeds,
|
||||
pos=pos,
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user