From 7b4b72e551ff1bd17f35d191cb7bc947e21c6e78 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 24 Aug 2025 18:49:23 -0700
Subject: [PATCH] Replace setattr-based request attributes with explicit Param fields

Rename RequestAttribute to Param and replace the setattr-based
_add_scalar_attr()/_add_vector_attr() helpers with a _make_param()
factory, so that token_ids, the token counters, and the sampling
parameters become explicit, statically visible attributes of
RequestState.

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu_worker_states.py | 75 +++++++++++++----------------
 1 file changed, 33 insertions(+), 42 deletions(-)

diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index 6b149701632a2..2c276e0d93730 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -44,7 +44,7 @@ class RequestData:
     ]
 
 
-class RequestAttribute:
+class Param:
 
     def __init__(
         self,
@@ -67,11 +67,11 @@ class RequestAttribute:
                                dtype=dtype,
                                device=device)
         if is_scalar:
-            assert num_cols == 1
             self.cpu.squeeze_(1)
             self.np = self.cpu.numpy()
             self.gpu.squeeze_(1)
 
+        # TODO(woosuk): Optimize this.
         self.gpu_buffer = self.cpu.to(device)
 
     def mirror_to_gpu(self) -> torch.Tensor:
@@ -116,32 +116,49 @@ class RequestState:
         self.req_data: dict[int, RequestData] = {}
         # TODO(woosuk): Because the token_ids tensor can be very big, we only
         # initialize it on CPU memory.
-        self._add_vector_attr("token_ids",
-                              self.max_model_len,
-                              torch.int32,
-                              cpu_only=True)
-        self._add_scalar_attr("num_prompt_tokens", torch.int32)
-        self._add_scalar_attr("num_tokens", torch.int32)
-        self._add_scalar_attr("num_computed_tokens", torch.int32)
+        self.token_ids = self._make_param(
+            num_cols=self.max_model_len,
+            dtype=torch.int32,
+            cpu_only=True,
+        )
+        self.num_prompt_tokens = self._make_param(torch.int32)
+        self.num_tokens = self._make_param(torch.int32)
+        self.num_computed_tokens = self._make_param(torch.int32)
 
         # Sampling-related.
-        self._add_scalar_attr("temperature", torch.float32)
+        self.temperature = self._make_param(torch.float32)
         self.greedy_reqs: set[str] = set()
         self.random_reqs: set[str] = set()
-        self._add_scalar_attr("top_p", torch.float32)
+        self.top_p = self._make_param(torch.float32)
         self.top_p_reqs: set[str] = set()
-        self._add_scalar_attr("top_k", torch.int32)
+        self.top_k = self._make_param(torch.int32)
         self.top_k_reqs: set[str] = set()
-        self._add_scalar_attr("frequency_penalties", torch.float32)
+        self.frequency_penalties = self._make_param(torch.float32)
         self.frequency_penalties_reqs: set[str] = set()
-        self._add_scalar_attr("presence_penalties", torch.float32)
+        self.presence_penalties = self._make_param(torch.float32)
         self.presence_penalties_reqs: set[str] = set()
-        self._add_scalar_attr("repetition_penalties", torch.float32)
+        self.repetition_penalties = self._make_param(torch.float32)
         self.repetition_penalties_reqs: set[str] = set()
 
         # req_idx -> generator
         self.generators: dict[int, torch.Generator] = {}
 
+    def _make_param(
+        self,
+        dtype: torch.dtype,
+        num_cols: int = 1,
+        cpu_only: bool = False,
+    ) -> Param:
+        return Param(
+            self.max_num_cached_reqs,
+            num_cols,
+            self.max_num_reqs if not cpu_only else 0,
+            dtype,
+            self.device,
+            self.pin_memory,
+            is_scalar=num_cols == 1,
+        )
+
     def add_request(
         self,
         req_id: str,
@@ -266,7 +283,7 @@
             no_penalties=no_penalties,
             # TODO
             generators={},
-            token_ids=self.token_ids.cpu[:batch_size],
+            token_ids=None,
             max_num_logprobs=None,
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
         )
@@ -277,32 +294,6 @@
     def num_cached_reqs(self) -> int:
         return len(self.req_id_to_index)
 
-    def _add_scalar_attr(self, name: str, dtype: torch.dtype):
-        attr = RequestAttribute(self.max_num_cached_reqs,
-                                1,
-                                self.max_num_reqs,
-                                dtype,
-                                self.device,
-                                self.pin_memory,
-                                is_scalar=True)
-        setattr(self, name, attr)
-
-    def _add_vector_attr(
-        self,
-        name: str,
-        max_len: int,
-        dtype: torch.dtype,
-        cpu_only: bool = False,
-    ):
-        if cpu_only:
-            num_rows_gpu = 0
-        else:
-            num_rows_gpu = self.max_num_reqs
-        attr = RequestAttribute(self.max_num_cached_reqs, max_len,
-                                num_rows_gpu, dtype, self.device,
-                                self.pin_memory)
-        setattr(self, name, attr)
-
 
 @triton.jit
 def _make_sampling_metadata_kernel(
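Reviewer note (not part of the patch): the hunks above show Param.__init__ and
mirror_to_gpu() only in fragments, so below is a minimal, self-contained sketch
of the pattern the patch moves to — a pinned CPU tensor with a NumPy view and
an optional pre-allocated GPU mirror, produced by a _make_param()-style
factory. The constructor parameter names, the capacity constants, and the copy
semantics of mirror_to_gpu() are assumptions for illustration; they are not
taken from vllm/v1/worker/gpu_worker_states.py.

import torch


class Param:
    """Sketch: per-request state with a pinned CPU buffer, a NumPy view,
    and an optional GPU mirror (zero GPU rows means CPU-only)."""

    def __init__(
        self,
        num_rows_cpu: int,  # capacity for all cached requests
        num_cols: int,      # 1 for scalars, e.g. max_model_len for token_ids
        num_rows_gpu: int,  # 0 => do not mirror on the GPU
        dtype: torch.dtype,
        device: torch.device,
        pin_memory: bool,
        is_scalar: bool = False,
    ):
        # Pinned host memory makes the H2D copy in mirror_to_gpu()
        # eligible for non_blocking=True.
        self.cpu = torch.zeros(num_rows_cpu, num_cols,
                               dtype=dtype, pin_memory=pin_memory)
        self.gpu = torch.zeros(num_rows_gpu, num_cols,
                               dtype=dtype, device=device)
        if is_scalar:
            # Scalars are stored as 1-D vectors of per-request values.
            self.cpu.squeeze_(1)
            # NumPy view aliasing self.cpu: writes through it update the
            # tensor in place, with no extra copy.
            self.np = self.cpu.numpy()
            self.gpu.squeeze_(1)

    def mirror_to_gpu(self) -> torch.Tensor:
        # Assumed semantics: push the first num_rows_gpu rows to the GPU
        # mirror and return it (empty for CPU-only params).
        n = self.gpu.shape[0]
        self.gpu.copy_(self.cpu[:n], non_blocking=True)
        return self.gpu


def make_param(dtype: torch.dtype,
               num_cols: int = 1,
               cpu_only: bool = False) -> Param:
    # _make_param-style factory: capacities, device, and pinning are fixed
    # here, so call sites state only what varies.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return Param(
        num_rows_cpu=1024,                    # stands in for max_num_cached_reqs
        num_cols=num_cols,
        num_rows_gpu=0 if cpu_only else 256,  # stands in for max_num_reqs
        dtype=dtype,
        device=device,
        pin_memory=torch.cuda.is_available(),
        is_scalar=num_cols == 1,
    )


temperature = make_param(torch.float32)             # scalar, GPU-mirrored
token_ids = make_param(torch.int32, num_cols=4096,  # vector, CPU-only
                       cpu_only=True)
temperature.np[:] = 0.7  # cheap per-request updates via the NumPy view
print(temperature.mirror_to_gpu()[:4])

Centralizing the capacity, device, and pin_memory arguments in the factory is
what lets the patch drop the setattr() helpers: each attribute becomes a plain
assignment that IDEs and type checkers can resolve, and CPU-only parameters
such as token_ids simply request zero GPU rows.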