Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-19 19:17:09 +08:00)
commit 7b4b72e551 (parent da9cd26c78)

fix

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -44,7 +44,7 @@ class RequestData:
     ]


-class RequestAttribute:
+class Param:

     def __init__(
         self,
@@ -67,11 +67,11 @@ class RequestAttribute:
             dtype=dtype,
             device=device)
         if is_scalar:
             assert num_cols == 1
             self.cpu.squeeze_(1)
             self.np = self.cpu.numpy()
-            self.gpu.squeeze_(1)

         # TODO(woosuk): Optimize this.
+        self.gpu_buffer = self.cpu.to(device)

     def mirror_to_gpu(self) -> torch.Tensor:
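For context, the hunk above follows a common staging pattern: keep a tensor in (optionally pinned) host memory, expose a zero-copy numpy view for cheap elementwise writes, and mirror the buffer to the device on demand. A minimal, self-contained sketch of that pattern, with illustrative names only (this is not vLLM's API):

import torch

max_rows, num_cols = 8, 4
device = "cuda" if torch.cuda.is_available() else "cpu"
pin = device == "cuda"

# Staging buffer in host memory; pinning enables asynchronous
# host-to-device copies.
cpu = torch.zeros(max_rows, num_cols, dtype=torch.int32, pin_memory=pin)
np_view = cpu.numpy()          # zero-copy view for fast scalar writes
gpu_buffer = cpu.to(device)    # device-side mirror

np_view[0, :2] = [101, 102]    # cheap writes happen on the host side

# Mirror the host buffer to the device (non_blocking is only truly
# asynchronous when the source is pinned).
gpu_buffer.copy_(cpu, non_blocking=True)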
@@ -116,32 +116,49 @@ class RequestState:
         self.req_data: dict[int, RequestData] = {}
         # TODO(woosuk): Because the token_ids tensor can be very big, we only
         # initialize it on CPU memory.
-        self._add_vector_attr("token_ids",
-                              self.max_model_len,
-                              torch.int32,
-                              cpu_only=True)
-        self._add_scalar_attr("num_prompt_tokens", torch.int32)
-        self._add_scalar_attr("num_tokens", torch.int32)
-        self._add_scalar_attr("num_computed_tokens", torch.int32)
+        self.token_ids = self._make_param(
+            num_cols=self.max_model_len,
+            dtype=torch.int32,
+            cpu_only=True,
+        )
+        self.num_prompt_tokens = self._make_param(torch.int32)
+        self.num_tokens = self._make_param(torch.int32)
+        self.num_computed_tokens = self._make_param(torch.int32)

         # Sampling-related.
-        self._add_scalar_attr("temperature", torch.float32)
+        self.temperature = self._make_param(torch.float32)
         self.greedy_reqs: set[str] = set()
         self.random_reqs: set[str] = set()
-        self._add_scalar_attr("top_p", torch.float32)
+        self.top_p = self._make_param(torch.float32)
         self.top_p_reqs: set[str] = set()
-        self._add_scalar_attr("top_k", torch.int32)
+        self.top_k = self._make_param(torch.int32)
         self.top_k_reqs: set[str] = set()
-        self._add_scalar_attr("frequency_penalties", torch.float32)
+        self.frequency_penalties = self._make_param(torch.float32)
         self.frequency_penalties_reqs: set[str] = set()
-        self._add_scalar_attr("presence_penalties", torch.float32)
+        self.presence_penalties = self._make_param(torch.float32)
         self.presence_penalties_reqs: set[str] = set()
-        self._add_scalar_attr("repetition_penalties", torch.float32)
+        self.repetition_penalties = self._make_param(torch.float32)
         self.repetition_penalties_reqs: set[str] = set()

         # req_idx -> generator
         self.generators: dict[int, torch.Generator] = {}

+    def _make_param(
+        self,
+        dtype: torch.dtype,
+        num_cols: int = 1,
+        cpu_only: bool = False,
+    ) -> Param:
+        return Param(
+            self.max_num_cached_reqs,
+            num_cols,
+            self.max_num_reqs if not cpu_only else 0,
+            dtype,
+            self.device,
+            self.pin_memory,
+            is_scalar=num_cols == 1,
+        )
+
     def add_request(
         self,
         req_id: str,
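As a reading aid: the factory above appears to treat a param as scalar exactly when num_cols == 1 (the column dimension is squeezed away) and to give cpu_only params zero GPU rows. A runnable toy under those assumptions; ToyParam and make_param are illustrative stand-ins, not vLLM's actual Param:

import torch

class ToyParam:
    """Toy stand-in for Param; arguments follow the order shown above."""
    def __init__(self, num_rows_cpu, num_cols, num_rows_gpu, dtype, device,
                 pin_memory, is_scalar=False):
        self.cpu = torch.zeros(num_rows_cpu, num_cols, dtype=dtype,
                               pin_memory=pin_memory)
        if is_scalar:
            assert num_cols == 1
            self.cpu.squeeze_(1)        # (rows, 1) -> (rows,)
        self.np = self.cpu.numpy()      # zero-copy host view
        # cpu_only callers pass num_rows_gpu == 0: no device allocation.
        self.gpu = (torch.zeros(num_rows_gpu, num_cols, dtype=dtype,
                                device=device) if num_rows_gpu > 0 else None)

def make_param(dtype, num_cols=1, cpu_only=False):
    # Mirrors _make_param: scalar iff num_cols == 1, zero GPU rows if cpu_only.
    return ToyParam(16, num_cols, 0 if cpu_only else 8, dtype, "cpu", False,
                    is_scalar=num_cols == 1)

temperature = make_param(torch.float32)                # scalar, shape (16,)
token_ids = make_param(torch.int32, num_cols=32,
                       cpu_only=True)                  # CPU-only, shape (16, 32)
print(temperature.cpu.shape, token_ids.gpu)            # torch.Size([16]) None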
@@ -266,7 +283,7 @@ class RequestState:
             no_penalties=no_penalties,
             # TODO
             generators={},
-            token_ids=self.token_ids.cpu[:batch_size],
+            token_ids=None,
             max_num_logprobs=None,
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
@@ -277,32 +294,6 @@ class RequestState:
     def num_cached_reqs(self) -> int:
         return len(self.req_id_to_index)

-    def _add_scalar_attr(self, name: str, dtype: torch.dtype):
-        attr = RequestAttribute(self.max_num_cached_reqs,
-                                1,
-                                self.max_num_reqs,
-                                dtype,
-                                self.device,
-                                self.pin_memory,
-                                is_scalar=True)
-        setattr(self, name, attr)
-
-    def _add_vector_attr(
-        self,
-        name: str,
-        max_len: int,
-        dtype: torch.dtype,
-        cpu_only: bool = False,
-    ):
-        if cpu_only:
-            num_rows_gpu = 0
-        else:
-            num_rows_gpu = self.max_num_reqs
-        attr = RequestAttribute(self.max_num_cached_reqs, max_len,
-                                num_rows_gpu, dtype, self.device,
-                                self.pin_memory)
-        setattr(self, name, attr)
-

 @triton.jit
 def _make_sampling_metadata_kernel(
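One plausible motivation for deleting these setattr-based helpers in favor of _make_param: dynamic registration via setattr hides attribute names and types from static analysis, while explicit assignment of a factory's return value keeps them visible. A minimal contrast with illustrative names:

import torch

class WithSetattr:
    def __init__(self):
        # Dynamic registration: type checkers cannot see self.temperature.
        setattr(self, "temperature", torch.zeros(8))

class WithFactory:
    def _make_param(self) -> torch.Tensor:
        return torch.zeros(8)

    def __init__(self):
        # Explicit assignment: the attribute and its type are discoverable.
        self.temperature = self._make_param()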