mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-16 03:17:02 +08:00
wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
5f95309a6d
commit
787e59629c
3205
vllm/v1/worker/gpu_model_runner copy.py
Normal file
3205
vllm/v1/worker/gpu_model_runner copy.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -47,78 +47,6 @@ class RequestData:
|
||||
]
|
||||
|
||||
|
||||
class SamplingStates:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_cached_reqs: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_cached_reqs = max_num_cached_reqs
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
|
||||
self.temperature = self._make_param(torch.float32)
|
||||
self.greedy_req_indices: set[int] = set()
|
||||
self.top_p = self._make_param(torch.float32)
|
||||
self.top_p_req_indices: set[int] = set()
|
||||
self.top_k = self._make_param(torch.int32)
|
||||
self.top_k_req_indices: set[int] = set()
|
||||
|
||||
self.frequency_penalties = self._make_param(torch.float32)
|
||||
self.presence_penalties = self._make_param(torch.float32)
|
||||
self.repetition_penalties = self._make_param(torch.float32)
|
||||
self.penalty_req_indices: set[int] = set()
|
||||
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
def _make_param(self, dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.zeros(self.max_num_reqs, dtype=dtype, device=self.device)
|
||||
|
||||
def add_requests(
|
||||
self,
|
||||
req_indices: list[int],
|
||||
sampling_params: list[SamplingParams],
|
||||
) -> None:
|
||||
num_reqs = len(req_indices)
|
||||
for i in range(num_reqs):
|
||||
req_idx = req_indices[i]
|
||||
sampling_param = sampling_params[i]
|
||||
|
||||
temp = sampling_param.temperature
|
||||
if temp == 0.0:
|
||||
self.greedy_req_indices.add(req_idx)
|
||||
|
||||
top_p = sampling_param.top_p
|
||||
if top_p < 1.0:
|
||||
self.top_p_req_indices.add(req_idx)
|
||||
top_k = sampling_param.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_req_indices.add(req_idx)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
|
||||
if sampling_param.frequency_penalty != 0.0 or sampling_param.presence_penalty != 0.0 or sampling_param.repetition_penalty != 1.0:
|
||||
self.penalty_req_indices.add(req_idx)
|
||||
|
||||
if sampling_param.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_param.seed)
|
||||
self.generators[req_idx] = generator
|
||||
|
||||
def remove_request(self, req_idx: int) -> None:
|
||||
self.greedy_req_indices.discard(req_idx)
|
||||
self.top_p_req_indices.discard(req_idx)
|
||||
self.top_k_req_indices.discard(req_idx)
|
||||
self.penalty_req_indices.discard(req_idx)
|
||||
self.generators.pop(req_idx, None)
|
||||
|
||||
|
||||
class RequestState:
|
||||
|
||||
def __init__(
|
||||
@ -130,7 +58,6 @@ class RequestState:
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
@ -144,7 +71,6 @@ class RequestState:
|
||||
self.vocab_size = vocab_size
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.pooling_params = None
|
||||
self.block_sizes = block_sizes
|
||||
self.num_prompt_logprobs: dict[int, int] = {}
|
||||
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
@ -160,36 +86,23 @@ class RequestState:
|
||||
dtype=torch.int32,
|
||||
cpu_only=True,
|
||||
)
|
||||
self.num_prompt_tokens = self._make_param(torch.int32)
|
||||
self.num_tokens = self._make_param(torch.int32)
|
||||
self.num_computed_tokens = self._make_param(torch.int32)
|
||||
self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
|
||||
self.sampling_states = SamplingStates(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_cached_reqs=max_num_cached_reqs,
|
||||
device=device,
|
||||
)
|
||||
self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.greedy_req_indices: set[int] = set()
|
||||
self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.top_p_req_indices: set[int] = set()
|
||||
self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
|
||||
self.top_k_req_indices: set[int] = set()
|
||||
|
||||
def _make_param(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
num_cols: int = 1,
|
||||
cpu_only: bool = False,
|
||||
) -> Param:
|
||||
return Param(
|
||||
self.max_num_cached_reqs,
|
||||
num_cols,
|
||||
self.max_num_reqs if not cpu_only else 0,
|
||||
dtype,
|
||||
self.device,
|
||||
self.pin_memory,
|
||||
is_scalar=num_cols == 1,
|
||||
)
|
||||
self.frequency_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.presence_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.repetition_penalties = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
|
||||
self.penalty_req_indices: set[int] = set()
|
||||
|
||||
@property
|
||||
def num_cached_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
@ -204,46 +117,31 @@ class RequestState:
|
||||
self.index_to_req_id[req_idx] = req_id
|
||||
|
||||
prompt_len = len(prompt_token_ids)
|
||||
self.num_prompt_tokens.np[req_idx] = prompt_len
|
||||
self.num_tokens.np[req_idx] = prompt_len
|
||||
self.token_ids.np[req_idx, :prompt_len] = prompt_token_ids
|
||||
self.num_computed_tokens.np[req_idx] = num_computed_tokens
|
||||
self.num_prompt_tokens[req_idx] = prompt_len
|
||||
self.num_tokens[req_idx] = prompt_len
|
||||
self.token_ids[req_idx, :prompt_len] = prompt_token_ids
|
||||
self.num_computed_tokens[req_idx] = num_computed_tokens
|
||||
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# NOTE: Be careful about division by zero.
|
||||
self.greedy_reqs.add(req_id)
|
||||
elif sampling_params.sampling_type == SamplingType.RANDOM:
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p.np[req_idx] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1.0:
|
||||
self.top_p_reqs.add(req_id)
|
||||
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
self.temperature[req_idx] = sampling_params.temperature
|
||||
self.top_p[req_idx] = sampling_params.top_p
|
||||
if 0 < sampling_params.top_k < self.vocab_size:
|
||||
top_k = sampling_params.top_k
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k.np[req_idx] = top_k
|
||||
|
||||
self.frequency_penalties.np[
|
||||
req_idx] = sampling_params.frequency_penalty
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties.np[
|
||||
req_idx] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
self.top_k[req_idx] = top_k
|
||||
self.frequency_penalties[req_idx] = sampling_params.frequency_penalty
|
||||
self.presence_penalties[req_idx] = sampling_params.presence_penalty
|
||||
self.repetition_penalties[req_idx] = sampling_params.repetition_penalty
|
||||
|
||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
self.generators[req_idx] = generator
|
||||
|
||||
@property
|
||||
def num_cached_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
def append_token_ids(
|
||||
self,
|
||||
req_idx: int,
|
||||
@ -262,65 +160,57 @@ class RequestState:
|
||||
self.index_to_req_id.pop(req_idx, None)
|
||||
self.free_indices.append(req_idx)
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_idx, None)
|
||||
|
||||
def make_sampling_metadata(
|
||||
self,
|
||||
batch_idx_to_req_idx: torch.Tensor,
|
||||
idx_mapping: np.ndarray,
|
||||
) -> SamplingMetadata:
|
||||
batch_size = batch_idx_to_req_idx.shape[0]
|
||||
if self.top_p_reqs:
|
||||
top_p_buffer = self.top_p.mirror_to_gpu()
|
||||
top_p = self.top_p.gpu
|
||||
temperature = self.temperature[idx_mapping]
|
||||
all_greedy = np.all(temperature == 0.0)
|
||||
all_random = np.all(temperature != 0.0)
|
||||
temperature = self._copy_np_to_gpu(temperature)
|
||||
|
||||
top_p = self.top_p[idx_mapping]
|
||||
no_top_p = np.all(top_p == 1.0)
|
||||
top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None
|
||||
top_k = self.top_k[idx_mapping]
|
||||
no_top_k = np.all(top_k == self.vocab_size)
|
||||
top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
|
||||
|
||||
frequency_penalties = self.frequency_penalties[idx_mapping]
|
||||
presence_penalties = self.presence_penalties[idx_mapping]
|
||||
repetition_penalties = self.repetition_penalties[idx_mapping]
|
||||
no_penalties = (np.all(frequency_penalties == 0.0) and
|
||||
np.all(presence_penalties == 0.0) and
|
||||
np.all(repetition_penalties == 1.0))
|
||||
if no_penalties:
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
repetition_penalties = None
|
||||
else:
|
||||
top_p_buffer = self.top_p.gpu_buffer
|
||||
top_p = None
|
||||
if self.top_k_reqs:
|
||||
top_k_buffer = self.top_k.mirror_to_gpu()
|
||||
top_k = self.top_k.gpu
|
||||
frequency_penalties = self._copy_np_to_gpu(frequency_penalties)
|
||||
presence_penalties = self._copy_np_to_gpu(presence_penalties)
|
||||
repetition_penalties = self._copy_np_to_gpu(repetition_penalties)
|
||||
|
||||
if self.generators:
|
||||
generators = {
|
||||
req_idx: self.generators[req_idx]
|
||||
for req_idx in idx_mapping
|
||||
if req_idx in self.generators
|
||||
}
|
||||
else:
|
||||
top_k_buffer = self.top_k.gpu_buffer
|
||||
top_k = None
|
||||
# TODO(woosuk): Use UVA to optimize CPU -> GPU copy.
|
||||
_make_sampling_metadata_kernel[(batch_size, )](
|
||||
batch_idx_to_req_idx,
|
||||
self.temperature.mirror_to_gpu(),
|
||||
self.temperature.gpu,
|
||||
top_p_buffer,
|
||||
self.top_p.gpu,
|
||||
top_k_buffer,
|
||||
self.top_k.gpu,
|
||||
self.frequency_penalties.mirror_to_gpu(),
|
||||
self.frequency_penalties.gpu,
|
||||
self.presence_penalties.mirror_to_gpu(),
|
||||
self.presence_penalties.gpu,
|
||||
self.repetition_penalties.mirror_to_gpu(),
|
||||
self.repetition_penalties.gpu,
|
||||
num_warps=1,
|
||||
num_stages=1,
|
||||
)
|
||||
no_penalties = not (self.frequency_penalties_reqs
|
||||
or self.presence_penalties_reqs
|
||||
or self.repetition_penalties_reqs)
|
||||
generators = {}
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=self.temperature.gpu[:batch_size],
|
||||
all_greedy=not self.random_reqs,
|
||||
all_random=not self.greedy_reqs,
|
||||
temperature=temperature,
|
||||
all_greedy=all_greedy,
|
||||
all_random=all_random,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
frequency_penalties=self.frequency_penalties.gpu[:batch_size],
|
||||
presence_penalties=self.presence_penalties.gpu[:batch_size],
|
||||
repetition_penalties=self.repetition_penalties.gpu[:batch_size],
|
||||
frequency_penalties=frequency_penalties,
|
||||
presence_penalties=presence_penalties,
|
||||
repetition_penalties=repetition_penalties,
|
||||
no_penalties=no_penalties,
|
||||
# TODO
|
||||
generators={},
|
||||
generators=generators,
|
||||
token_ids=None,
|
||||
num_tokens=None,
|
||||
num_prompt_tokens=None,
|
||||
@ -330,6 +220,10 @@ class RequestState:
|
||||
logitsprocs=None,
|
||||
)
|
||||
|
||||
def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
|
||||
cpu_tensor = torch.from_numpy(src)
|
||||
return cpu_tensor.to(device=self.device, non_blocking=True)
|
||||
|
||||
def make_spec_decode_metadata(
|
||||
self,
|
||||
query_start_loc: torch.Tensor,
|
||||
@ -369,44 +263,6 @@ class RequestState:
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _make_sampling_metadata_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
src_temperature,
|
||||
dst_temperature,
|
||||
src_top_p,
|
||||
dst_top_p,
|
||||
src_top_k,
|
||||
dst_top_k,
|
||||
src_frequency_penalties,
|
||||
dst_frequency_penalties,
|
||||
src_presence_penalties,
|
||||
dst_presence_penalties,
|
||||
src_repetition_penalties,
|
||||
dst_repetition_penalties,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
temperature = tl.load(src_temperature + req_idx)
|
||||
tl.store(dst_temperature + batch_idx, temperature)
|
||||
|
||||
top_p = tl.load(src_top_p + req_idx)
|
||||
tl.store(dst_top_p + batch_idx, top_p)
|
||||
|
||||
top_k = tl.load(src_top_k + req_idx)
|
||||
tl.store(dst_top_k + batch_idx, top_k)
|
||||
|
||||
frequency_penalties = tl.load(src_frequency_penalties + req_idx)
|
||||
tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)
|
||||
|
||||
presence_penalties = tl.load(src_presence_penalties + req_idx)
|
||||
tl.store(dst_presence_penalties + batch_idx, presence_penalties)
|
||||
|
||||
repetition_penalties = tl.load(src_repetition_penalties + req_idx)
|
||||
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_spec_decode_kernel(
|
||||
query_start_loc, # [B + 1]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user