mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 18:09:09 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
bc6463ac97
commit
aabfaa08cf
@ -320,27 +320,15 @@ class GPUModelRunner:
|
|||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
input_batch: InputBatch,
|
input_batch: InputBatch,
|
||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
|
pos = input_batch.positions[input_batch.logits_indices]
|
||||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||||
input_batch.idx_mapping_np)
|
input_batch.idx_mapping_np, pos)
|
||||||
sampler_output = self.sampler(
|
sampler_output = self.sampler(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
def compute_prompt_logprobs(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_batch: InputBatch,
|
|
||||||
):
|
|
||||||
idx_mapping_np = input_batch.idx_mapping_np
|
|
||||||
needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[
|
|
||||||
idx_mapping_np]
|
|
||||||
if not np.any(needs_prompt_logprobs):
|
|
||||||
# Common case.
|
|
||||||
# No request in the batch needs prompt logprobs.
|
|
||||||
return None
|
|
||||||
|
|
||||||
def postprocess(
|
def postprocess(
|
||||||
self,
|
self,
|
||||||
sampler_output: SamplerOutput,
|
sampler_output: SamplerOutput,
|
||||||
@ -387,9 +375,6 @@ class GPUModelRunner:
|
|||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
|
|
||||||
sampler_output = self.sample(logits, input_batch)
|
sampler_output = self.sample(logits, input_batch)
|
||||||
prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
|
|
||||||
input_batch)
|
|
||||||
|
|
||||||
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
||||||
sampler_output, input_batch)
|
sampler_output, input_batch)
|
||||||
req_id_to_index = {
|
req_id_to_index = {
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import triton
|
import triton
|
||||||
@ -47,8 +44,8 @@ class Sampler(nn.Module):
|
|||||||
sampled = gumbel_sample(
|
sampled = gumbel_sample(
|
||||||
probs,
|
probs,
|
||||||
sampling_metadata.temperature,
|
sampling_metadata.temperature,
|
||||||
None, # seeds
|
sampling_metadata.seeds,
|
||||||
None, # pos
|
sampling_metadata.pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
logprobs_tensors = None
|
logprobs_tensors = None
|
||||||
@ -163,22 +160,27 @@ def _apply_gumbel_kernel(
|
|||||||
probs_ptr,
|
probs_ptr,
|
||||||
probs_stride,
|
probs_stride,
|
||||||
seeds_ptr,
|
seeds_ptr,
|
||||||
|
pos_ptr,
|
||||||
temp_ptr,
|
temp_ptr,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
EPSILON: tl.constexpr,
|
EPSILON: tl.constexpr,
|
||||||
):
|
):
|
||||||
req_idx = tl.program_id(0)
|
req_idx = tl.program_id(0)
|
||||||
seed = tl.load(seeds_ptr + req_idx)
|
|
||||||
temp = tl.load(temp_ptr + req_idx)
|
temp = tl.load(temp_ptr + req_idx)
|
||||||
|
|
||||||
if temp < EPSILON:
|
if temp < EPSILON:
|
||||||
# Greedy sampling. Don't apply gumbel noise.
|
# Greedy sampling. Don't apply gumbel noise.
|
||||||
return
|
return
|
||||||
|
|
||||||
|
seed = tl.load(seeds_ptr + req_idx)
|
||||||
|
pos = tl.load(pos_ptr + req_idx)
|
||||||
|
gumbel_seed = seed ^ (pos * 0x9E3779B9)
|
||||||
|
gumbel_seed = gumbel_seed & 0xFFFFFFFF
|
||||||
|
|
||||||
block_id = tl.program_id(1)
|
block_id = tl.program_id(1)
|
||||||
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
q = tl.rand(seed, r_offset)
|
q = tl.rand(gumbel_seed, r_offset)
|
||||||
|
|
||||||
# NOTE(woosuk): This logic makes sure q is not 0.
|
# NOTE(woosuk): This logic makes sure q is not 0.
|
||||||
RMAX = 0.9999999403953552
|
RMAX = 0.9999999403953552
|
||||||
@ -201,34 +203,20 @@ def gumbel_sample(
|
|||||||
# fp32[num_reqs]
|
# fp32[num_reqs]
|
||||||
temperature: torch.Tensor,
|
temperature: torch.Tensor,
|
||||||
# int64[num_reqs]
|
# int64[num_reqs]
|
||||||
seeds: Optional[torch.Tensor],
|
seeds: torch.Tensor,
|
||||||
# int64[num_reqs]
|
# int64[num_reqs]
|
||||||
pos: Optional[torch.Tensor],
|
pos: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_reqs = probs.shape[0]
|
num_reqs = probs.shape[0]
|
||||||
vocab_size = probs.shape[1]
|
vocab_size = probs.shape[1]
|
||||||
if seeds is not None:
|
|
||||||
# Per-request seed.
|
|
||||||
assert pos is not None
|
|
||||||
gumbel_seeds = seeds ^ (pos * 0x9E3779B9)
|
|
||||||
else:
|
|
||||||
# Global seed.
|
|
||||||
assert pos is None
|
|
||||||
seed_dtype = torch.int64
|
|
||||||
gumbel_seeds = torch.randint(
|
|
||||||
torch.iinfo(seed_dtype).min,
|
|
||||||
torch.iinfo(seed_dtype).max,
|
|
||||||
(num_reqs, ),
|
|
||||||
dtype=seed_dtype,
|
|
||||||
device=probs.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the probs in-place.
|
# Update the probs in-place.
|
||||||
BLOCK_SIZE = 8192
|
BLOCK_SIZE = 8192
|
||||||
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
_apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
|
||||||
probs,
|
probs,
|
||||||
probs.stride(0),
|
probs.stride(0),
|
||||||
gumbel_seeds,
|
seeds,
|
||||||
|
pos,
|
||||||
temperature,
|
temperature,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
|
|||||||
@ -10,6 +10,9 @@ import torch
|
|||||||
|
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||||
|
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SamplingMetadata:
|
class SamplingMetadata:
|
||||||
@ -19,6 +22,9 @@ class SamplingMetadata:
|
|||||||
top_p: Optional[torch.Tensor]
|
top_p: Optional[torch.Tensor]
|
||||||
top_k: Optional[torch.Tensor]
|
top_k: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
seeds: torch.Tensor
|
||||||
|
pos: torch.Tensor
|
||||||
|
|
||||||
# None means no logprobs, 0 means sampled token logprobs only
|
# None means no logprobs, 0 means sampled token logprobs only
|
||||||
max_num_logprobs: Optional[int]
|
max_num_logprobs: Optional[int]
|
||||||
|
|
||||||
@ -69,6 +75,7 @@ class RequestState:
|
|||||||
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
|
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
|
||||||
# -1 means no logprobs are requested.
|
# -1 means no logprobs are requested.
|
||||||
self.num_logprobs.fill(-1)
|
self.num_logprobs.fill(-1)
|
||||||
|
self.seeds = np.zeros(self.max_num_reqs, dtype=np.int64)
|
||||||
|
|
||||||
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
||||||
|
|
||||||
@ -102,6 +109,12 @@ class RequestState:
|
|||||||
top_k = self.vocab_size
|
top_k = self.vocab_size
|
||||||
self.top_k[req_idx] = top_k
|
self.top_k[req_idx] = top_k
|
||||||
|
|
||||||
|
if sampling_params.seed is not None:
|
||||||
|
seed = sampling_params.seed
|
||||||
|
else:
|
||||||
|
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||||
|
self.seeds[req_idx] = seed
|
||||||
|
|
||||||
if sampling_params.logprobs is not None:
|
if sampling_params.logprobs is not None:
|
||||||
num_logprobs = sampling_params.logprobs
|
num_logprobs = sampling_params.logprobs
|
||||||
else:
|
else:
|
||||||
@ -133,6 +146,7 @@ class RequestState:
|
|||||||
def make_sampling_metadata(
|
def make_sampling_metadata(
|
||||||
self,
|
self,
|
||||||
idx_mapping: np.ndarray,
|
idx_mapping: np.ndarray,
|
||||||
|
pos: torch.Tensor,
|
||||||
) -> SamplingMetadata:
|
) -> SamplingMetadata:
|
||||||
temperature = self.temperature[idx_mapping]
|
temperature = self.temperature[idx_mapping]
|
||||||
temperature = self._copy_np_to_gpu(temperature)
|
temperature = self._copy_np_to_gpu(temperature)
|
||||||
@ -145,6 +159,9 @@ class RequestState:
|
|||||||
no_top_k = np.all(top_k == self.vocab_size)
|
no_top_k = np.all(top_k == self.vocab_size)
|
||||||
top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
|
top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
|
||||||
|
|
||||||
|
seeds = self.seeds[idx_mapping]
|
||||||
|
seeds = self._copy_np_to_gpu(seeds)
|
||||||
|
|
||||||
num_logprobs = self.num_logprobs[idx_mapping]
|
num_logprobs = self.num_logprobs[idx_mapping]
|
||||||
max_num_logprobs = np.max(num_logprobs)
|
max_num_logprobs = np.max(num_logprobs)
|
||||||
if max_num_logprobs == -1:
|
if max_num_logprobs == -1:
|
||||||
@ -154,6 +171,8 @@ class RequestState:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
seeds=seeds,
|
||||||
|
pos=pos,
|
||||||
max_num_logprobs=max_num_logprobs,
|
max_num_logprobs=max_num_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user