mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 00:37:08 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
633f9f006d
commit
9a6fcca030
@ -8,11 +8,9 @@ import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import (MultiModalKwargsItem,
|
||||
MultiModalKwargsItems, PlaceholderRange)
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
@ -37,15 +35,6 @@ class RequestData:
|
||||
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargsItems]:
|
||||
return [
|
||||
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
|
||||
]
|
||||
|
||||
|
||||
class RequestState:
|
||||
|
||||
@ -81,26 +70,33 @@ class RequestState:
|
||||
self.req_data: dict[int, RequestData] = {}
|
||||
# TODO(woosuk): Because the token_ids tensor can be very big, we only
|
||||
# initialize it on CPU memory.
|
||||
self.token_ids = self._make_param(
|
||||
num_cols=self.max_model_len,
|
||||
dtype=torch.int32,
|
||||
cpu_only=True,
|
||||
self.token_ids = np.zeros(
|
||||
(self.max_num_cached_reqs, self.max_model_len),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs,
|
||||
dtype=np.int32)
|
||||
self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens = np.zeros(self.max_num_cached_reqs,
|
||||
dtype=np.int32)
|
||||
|
||||
# Last sampled token ids.
|
||||
self.last_sampled_token = torch.zeros(
|
||||
self.max_num_cached_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.greedy_req_indices: set[int] = set()
|
||||
self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.top_p_req_indices: set[int] = set()
|
||||
self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.top_k_req_indices: set[int] = set()
|
||||
|
||||
self.frequency_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.presence_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.repetition_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.penalty_req_indices: set[int] = set()
|
||||
self.frequency_penalties = np.zeros(self.max_num_cached_reqs,
|
||||
dtype=np.float32)
|
||||
self.presence_penalties = np.zeros(self.max_num_cached_reqs,
|
||||
dtype=np.float32)
|
||||
self.repetition_penalties = np.zeros(self.max_num_cached_reqs,
|
||||
dtype=np.float32)
|
||||
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
@ -179,9 +175,9 @@ class RequestState:
|
||||
frequency_penalties = self.frequency_penalties[idx_mapping]
|
||||
presence_penalties = self.presence_penalties[idx_mapping]
|
||||
repetition_penalties = self.repetition_penalties[idx_mapping]
|
||||
no_penalties = (np.all(frequency_penalties == 0.0) and
|
||||
np.all(presence_penalties == 0.0) and
|
||||
np.all(repetition_penalties == 1.0))
|
||||
no_penalties = (np.all(frequency_penalties == 0.0)
|
||||
and np.all(presence_penalties == 0.0)
|
||||
and np.all(repetition_penalties == 1.0))
|
||||
if no_penalties:
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
@ -194,8 +190,7 @@ class RequestState:
|
||||
if self.generators:
|
||||
generators = {
|
||||
req_idx: self.generators[req_idx]
|
||||
for req_idx in idx_mapping
|
||||
if req_idx in self.generators
|
||||
for req_idx in idx_mapping if req_idx in self.generators
|
||||
}
|
||||
else:
|
||||
generators = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user