Mirror of https://git.datalinker.icu/vllm-project/vllm.git
commit 330058f9b8 (parent: aabfaa08cf)

fix

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -69,16 +69,22 @@ class RequestState:
         )

         # Sampling parameters.
-        self.temperature = np.zeros(self.max_num_reqs, dtype=np.float32)
-        self.top_p = np.zeros(self.max_num_reqs, dtype=np.float32)
-        self.top_k = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.temperature = self._make_buffer(self.max_num_reqs, torch.float32)
+        self.top_p = self._make_buffer(self.max_num_reqs, torch.float32)
+        self.top_k = self._make_buffer(self.max_num_reqs, torch.int32)
+        self.seeds = self._make_buffer(self.max_num_reqs, torch.int64)

         self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
         # -1 means no logprobs are requested.
         self.num_logprobs.fill(-1)
-        self.seeds = np.zeros(self.max_num_reqs, dtype=np.int64)

         self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)

+    def _make_buffer(self, size, dtype: torch.dtype) -> "Buffer":
+        return Buffer(size,
+                      dtype=dtype,
+                      pin_memory=self.pin_memory,
+                      device=self.device)
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
@@ -101,19 +107,19 @@ class RequestState:
         self.token_ids[req_idx, :prompt_len] = prompt_token_ids
         self.num_computed_tokens[req_idx] = num_computed_tokens

-        self.temperature[req_idx] = sampling_params.temperature
-        self.top_p[req_idx] = sampling_params.top_p
+        self.temperature.np[req_idx] = sampling_params.temperature
+        self.top_p.np[req_idx] = sampling_params.top_p
         if 0 < sampling_params.top_k < self.vocab_size:
             top_k = sampling_params.top_k
         else:
             top_k = self.vocab_size
-        self.top_k[req_idx] = top_k
+        self.top_k.np[req_idx] = top_k

         if sampling_params.seed is not None:
             seed = sampling_params.seed
         else:
             seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
-        self.seeds[req_idx] = seed
+        self.seeds.np[req_idx] = seed

         if sampling_params.logprobs is not None:
             num_logprobs = sampling_params.logprobs
@@ -148,19 +154,19 @@ class RequestState:
         idx_mapping: np.ndarray,
         pos: torch.Tensor,
     ) -> SamplingMetadata:
-        temperature = self.temperature[idx_mapping]
-        temperature = self._copy_np_to_gpu(temperature)
+        temperature = self.temperature.np[idx_mapping]
+        temperature = self.temperature.copy_np_to_gpu(temperature)

-        top_p = self.top_p[idx_mapping]
+        top_p = self.top_p.np[idx_mapping]
         no_top_p = np.all(top_p == 1.0)
-        top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None
+        top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None

-        top_k = self.top_k[idx_mapping]
+        top_k = self.top_k.np[idx_mapping]
         no_top_k = np.all(top_k == self.vocab_size)
-        top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
+        top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None

-        seeds = self.seeds[idx_mapping]
-        seeds = self._copy_np_to_gpu(seeds)
+        seeds = self.seeds.np[idx_mapping]
+        seeds = self.seeds.copy_np_to_gpu(seeds)

         num_logprobs = self.num_logprobs[idx_mapping]
         max_num_logprobs = np.max(num_logprobs)
@@ -176,12 +182,6 @@ class RequestState:
             max_num_logprobs=max_num_logprobs,
         )

-    def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
-        cpu_tensor = torch.from_numpy(src)
-        if self.pin_memory:
-            cpu_tensor = cpu_tensor.pin_memory()
-        return cpu_tensor.to(device=self.device, non_blocking=True)
-
     def append_token_ids(
         self,
         req_indices: np.ndarray,
@@ -225,3 +225,29 @@ def _append_token_ids(
         end_idx = start_idx + n
         token_ids[req_idx, start_idx:end_idx] = sampled_ids[i, :n]
         num_tokens[req_idx] = end_idx
+
+
+class Buffer:
+
+    def __init__(
+        self,
+        *args,
+        dtype: torch.dtype,
+        pin_memory: bool,
+        device: torch.device,
+    ):
+        # NOTE(woosuk): Unlike CpuGpuBuffer, the Numpy array and CPU tensor
+        # in this class do not share the same storage.
+        self.np = np.zeros(*args, dtype=dtype)
+        # The staging tensor must be allocated on the CPU: pinning is only
+        # valid for CPU tensors, and the GPU copy is created from it below.
+        self.cpu = torch.zeros(
+            *args,
+            dtype=dtype,
+            pin_memory=pin_memory,
+            device="cpu",
+        )
+        self.gpu = self.cpu.to(device)
+
+    def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
+        n = x.shape[0]
+        self.cpu[:n] = x
+        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
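The heart of the change is the staged numpy → pinned-CPU → GPU copy that Buffer.copy_np_to_gpu performs: unlike the removed _copy_np_to_gpu, which wrapped a fresh array and called pin_memory() on every step, the persistent buffer pays the allocation and pinning cost once at startup. Below is a minimal standalone sketch of that pattern. It mirrors the Buffer class from the diff but is not the exact vLLM code: the explicit dtype conversion, the CPU fallback when CUDA is unavailable, and the buffer size are illustrative assumptions added so the sketch runs anywhere.

    import numpy as np
    import torch


    class Buffer:
        """Sketch of the commit's Buffer: a numpy array for cheap host-side
        bookkeeping, a pinned CPU staging tensor, and a GPU destination.
        The numpy array and the CPU tensor do not share storage."""

        def __init__(self, size: int, *, dtype: torch.dtype,
                     pin_memory: bool, device: torch.device):
            # Derive the matching numpy dtype from the torch dtype.
            np_dtype = torch.zeros((), dtype=dtype).numpy().dtype
            self.np = np.zeros(size, dtype=np_dtype)
            # Pinned (page-locked) memory is required for truly asynchronous
            # host-to-device copies; pinning is only valid on CPU tensors.
            self.cpu = torch.zeros(size, dtype=dtype, device="cpu",
                                   pin_memory=pin_memory)
            self.gpu = self.cpu.to(device)

        def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
            n = x.shape[0]
            self.cpu[:n] = torch.from_numpy(x)  # stage into pinned memory
            # With a pinned source, non_blocking=True overlaps the copy with
            # GPU work; with unpinned memory it degrades to a blocking copy.
            return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)


    # Usage, loosely mirroring how the diff gathers per-request temperatures
    # (sizes and values are made up for illustration):
    if torch.cuda.is_available():
        device, pin = torch.device("cuda"), True
    else:
        device, pin = torch.device("cpu"), False

    temperature = Buffer(1024, dtype=torch.float32, pin_memory=pin, device=device)
    temperature.np[:4] = [0.7, 1.0, 0.0, 0.9]          # cheap host-side writes
    idx_mapping = np.array([2, 0, 3], dtype=np.int64)  # batch reordering
    gathered = temperature.np[idx_mapping]             # numpy fancy indexing
    on_gpu = temperature.copy_np_to_gpu(gathered)      # single async H2D copy

Writing through the numpy side keeps per-request bookkeeping on the host, where scalar writes and fancy indexing are cheap; one slice assignment into pinned memory plus one copy_(non_blocking=True) is the only per-step transfer.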
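The prepare path also shows why keeping a numpy view pays off beyond cheap writes: host-side reductions can elide device transfers entirely. Below is a hedged sketch of the top-p gate from the diff, reusing the Buffer sketch above; the function name and signature are illustrative, not vLLM's API.

    import numpy as np
    import torch


    def gather_top_p(top_p_buf: Buffer,
                     idx_mapping: np.ndarray) -> "torch.Tensor | None":
        # Gather per-request values on the host: no GPU sync involved.
        top_p = top_p_buf.np[idx_mapping]
        # If no request in the batch restricts top-p, return None so the
        # sampler can skip both the H2D copy and the top-p kernel.
        if np.all(top_p == 1.0):
            return None
        return top_p_buf.copy_np_to_gpu(top_p)

The same gate applies to top-k in the diff, with vocab_size as the "disabled" sentinel instead of 1.0.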