mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 01:42:16 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
da9cd26c78
commit
7b4b72e551
@ -44,7 +44,7 @@ class RequestData:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class RequestAttribute:
|
class Param:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -67,11 +67,11 @@ class RequestAttribute:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device)
|
device=device)
|
||||||
if is_scalar:
|
if is_scalar:
|
||||||
assert num_cols == 1
|
|
||||||
self.cpu.squeeze_(1)
|
self.cpu.squeeze_(1)
|
||||||
self.np = self.cpu.numpy()
|
self.np = self.cpu.numpy()
|
||||||
self.gpu.squeeze_(1)
|
self.gpu.squeeze_(1)
|
||||||
|
|
||||||
|
# TODO(woosuk): Optimize this.
|
||||||
self.gpu_buffer = self.cpu.to(device)
|
self.gpu_buffer = self.cpu.to(device)
|
||||||
|
|
||||||
def mirror_to_gpu(self) -> torch.Tensor:
|
def mirror_to_gpu(self) -> torch.Tensor:
|
||||||
@ -116,32 +116,49 @@ class RequestState:
|
|||||||
self.req_data: dict[int, RequestData] = {}
|
self.req_data: dict[int, RequestData] = {}
|
||||||
# TODO(woosuk): Because the token_ids tensor can be very big, we only
|
# TODO(woosuk): Because the token_ids tensor can be very big, we only
|
||||||
# initialize it on CPU memory.
|
# initialize it on CPU memory.
|
||||||
self._add_vector_attr("token_ids",
|
self.token_ids = self._make_param(
|
||||||
self.max_model_len,
|
num_cols=self.max_model_len,
|
||||||
torch.int32,
|
dtype=torch.int32,
|
||||||
cpu_only=True)
|
cpu_only=True,
|
||||||
self._add_scalar_attr("num_prompt_tokens", torch.int32)
|
)
|
||||||
self._add_scalar_attr("num_tokens", torch.int32)
|
self.num_prompt_tokens = self._make_param(torch.int32)
|
||||||
self._add_scalar_attr("num_computed_tokens", torch.int32)
|
self.num_tokens = self._make_param(torch.int32)
|
||||||
|
self.num_computed_tokens = self._make_param(torch.int32)
|
||||||
|
|
||||||
# Sampling-related.
|
# Sampling-related.
|
||||||
self._add_scalar_attr("temperature", torch.float32)
|
self.temperature = self._make_param(torch.float32)
|
||||||
self.greedy_reqs: set[str] = set()
|
self.greedy_reqs: set[str] = set()
|
||||||
self.random_reqs: set[str] = set()
|
self.random_reqs: set[str] = set()
|
||||||
self._add_scalar_attr("top_p", torch.float32)
|
self.top_p = self._make_param(torch.float32)
|
||||||
self.top_p_reqs: set[str] = set()
|
self.top_p_reqs: set[str] = set()
|
||||||
self._add_scalar_attr("top_k", torch.int32)
|
self.top_k = self._make_param(torch.int32)
|
||||||
self.top_k_reqs: set[str] = set()
|
self.top_k_reqs: set[str] = set()
|
||||||
self._add_scalar_attr("frequency_penalties", torch.float32)
|
self.frequency_penalties = self._make_param(torch.float32)
|
||||||
self.frequency_penalties_reqs: set[str] = set()
|
self.frequency_penalties_reqs: set[str] = set()
|
||||||
self._add_scalar_attr("presence_penalties", torch.float32)
|
self.presence_penalties = self._make_param(torch.float32)
|
||||||
self.presence_penalties_reqs: set[str] = set()
|
self.presence_penalties_reqs: set[str] = set()
|
||||||
self._add_scalar_attr("repetition_penalties", torch.float32)
|
self.repetition_penalties = self._make_param(torch.float32)
|
||||||
self.repetition_penalties_reqs: set[str] = set()
|
self.repetition_penalties_reqs: set[str] = set()
|
||||||
|
|
||||||
# req_idx -> generator
|
# req_idx -> generator
|
||||||
self.generators: dict[int, torch.Generator] = {}
|
self.generators: dict[int, torch.Generator] = {}
|
||||||
|
|
||||||
|
def _make_param(
|
||||||
|
self,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
num_cols: int = 1,
|
||||||
|
cpu_only: bool = False,
|
||||||
|
) -> Param:
|
||||||
|
return Param(
|
||||||
|
self.max_num_cached_reqs,
|
||||||
|
num_cols,
|
||||||
|
self.max_num_reqs if not cpu_only else 0,
|
||||||
|
dtype,
|
||||||
|
self.device,
|
||||||
|
self.pin_memory,
|
||||||
|
is_scalar=num_cols == 1,
|
||||||
|
)
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
req_id: str,
|
req_id: str,
|
||||||
@ -266,7 +283,7 @@ class RequestState:
|
|||||||
no_penalties=no_penalties,
|
no_penalties=no_penalties,
|
||||||
# TODO
|
# TODO
|
||||||
generators={},
|
generators={},
|
||||||
token_ids=self.token_ids.cpu[:batch_size],
|
token_ids=None,
|
||||||
max_num_logprobs=None,
|
max_num_logprobs=None,
|
||||||
allowed_token_ids_mask=None,
|
allowed_token_ids_mask=None,
|
||||||
bad_words_token_ids={},
|
bad_words_token_ids={},
|
||||||
@ -277,32 +294,6 @@ class RequestState:
|
|||||||
def num_cached_reqs(self) -> int:
|
def num_cached_reqs(self) -> int:
|
||||||
return len(self.req_id_to_index)
|
return len(self.req_id_to_index)
|
||||||
|
|
||||||
def _add_scalar_attr(self, name: str, dtype: torch.dtype):
|
|
||||||
attr = RequestAttribute(self.max_num_cached_reqs,
|
|
||||||
1,
|
|
||||||
self.max_num_reqs,
|
|
||||||
dtype,
|
|
||||||
self.device,
|
|
||||||
self.pin_memory,
|
|
||||||
is_scalar=True)
|
|
||||||
setattr(self, name, attr)
|
|
||||||
|
|
||||||
def _add_vector_attr(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
max_len: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
cpu_only: bool = False,
|
|
||||||
):
|
|
||||||
if cpu_only:
|
|
||||||
num_rows_gpu = 0
|
|
||||||
else:
|
|
||||||
num_rows_gpu = self.max_num_reqs
|
|
||||||
attr = RequestAttribute(self.max_num_cached_reqs, max_len,
|
|
||||||
num_rows_gpu, dtype, self.device,
|
|
||||||
self.pin_memory)
|
|
||||||
setattr(self, name, attr)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _make_sampling_metadata_kernel(
|
def _make_sampling_metadata_kernel(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user