Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-24 18:49:23 -07:00
parent da9cd26c78
commit 7b4b72e551

View File

@ -44,7 +44,7 @@ class RequestData:
]
class RequestAttribute:
class Param:
def __init__(
self,
@ -67,11 +67,11 @@ class RequestAttribute:
dtype=dtype,
device=device)
if is_scalar:
assert num_cols == 1
self.cpu.squeeze_(1)
self.np = self.cpu.numpy()
self.gpu.squeeze_(1)
# TODO(woosuk): Optimize this.
self.gpu_buffer = self.cpu.to(device)
def mirror_to_gpu(self) -> torch.Tensor:
@ -116,32 +116,49 @@ class RequestState:
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory.
self._add_vector_attr("token_ids",
self.max_model_len,
torch.int32,
cpu_only=True)
self._add_scalar_attr("num_prompt_tokens", torch.int32)
self._add_scalar_attr("num_tokens", torch.int32)
self._add_scalar_attr("num_computed_tokens", torch.int32)
self.token_ids = self._make_param(
num_cols=self.max_model_len,
dtype=torch.int32,
cpu_only=True,
)
self.num_prompt_tokens = self._make_param(torch.int32)
self.num_tokens = self._make_param(torch.int32)
self.num_computed_tokens = self._make_param(torch.int32)
# Sampling-related.
self._add_scalar_attr("temperature", torch.float32)
self.temperature = self._make_param(torch.float32)
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self._add_scalar_attr("top_p", torch.float32)
self.top_p = self._make_param(torch.float32)
self.top_p_reqs: set[str] = set()
self._add_scalar_attr("top_k", torch.int32)
self.top_k = self._make_param(torch.int32)
self.top_k_reqs: set[str] = set()
self._add_scalar_attr("frequency_penalties", torch.float32)
self.frequency_penalties = self._make_param(torch.float32)
self.frequency_penalties_reqs: set[str] = set()
self._add_scalar_attr("presence_penalties", torch.float32)
self.presence_penalties = self._make_param(torch.float32)
self.presence_penalties_reqs: set[str] = set()
self._add_scalar_attr("repetition_penalties", torch.float32)
self.repetition_penalties = self._make_param(torch.float32)
self.repetition_penalties_reqs: set[str] = set()
# req_idx -> generator
self.generators: dict[int, torch.Generator] = {}
def _make_param(
self,
dtype: torch.dtype,
num_cols: int = 1,
cpu_only: bool = False,
) -> Param:
return Param(
self.max_num_cached_reqs,
num_cols,
self.max_num_reqs if not cpu_only else 0,
dtype,
self.device,
self.pin_memory,
is_scalar=num_cols == 1,
)
def add_request(
self,
req_id: str,
@ -266,7 +283,7 @@ class RequestState:
no_penalties=no_penalties,
# TODO
generators={},
token_ids=self.token_ids.cpu[:batch_size],
token_ids=None,
max_num_logprobs=None,
allowed_token_ids_mask=None,
bad_words_token_ids={},
@ -277,32 +294,6 @@ class RequestState:
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)
def _add_scalar_attr(self, name: str, dtype: torch.dtype):
attr = RequestAttribute(self.max_num_cached_reqs,
1,
self.max_num_reqs,
dtype,
self.device,
self.pin_memory,
is_scalar=True)
setattr(self, name, attr)
def _add_vector_attr(
self,
name: str,
max_len: int,
dtype: torch.dtype,
cpu_only: bool = False,
):
if cpu_only:
num_rows_gpu = 0
else:
num_rows_gpu = self.max_num_reqs
attr = RequestAttribute(self.max_num_cached_reqs, max_len,
num_rows_gpu, dtype, self.device,
self.pin_memory)
setattr(self, name, attr)
@triton.jit
def _make_sampling_metadata_kernel(