working
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 92f337faeb
commit cbdb47dc01
@@ -16,6 +16,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.sampler import SamplerOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
@@ -356,15 +357,21 @@ class GPUModelRunner:
         self,
         sampler_output: SamplerOutput,
         input_batch: InputBatch,
-    ) -> np.ndarray:
+    ) -> tuple[np.ndarray, np.ndarray]:
         # Get the number of sampled tokens.
         # 0 if chunked-prefilling, 1 if not.
         is_chunked_prefilling = input_batch.is_chunked_prefilling
         num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
-        # Increment the number of tokens.
-        idx_mapping_np = input_batch.idx_mapping_np
-        self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
-        return num_sampled_tokens
+
+        # Update the token IDs and the number of tokens.
+        sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
+        sampled_token_ids_np = sampled_token_ids_cpu.numpy()
+        self.req_states.append_token_ids(
+            input_batch.idx_mapping_np,
+            sampled_token_ids_np,
+            num_sampled_tokens=num_sampled_tokens,
+        )
+        return sampled_token_ids_np, num_sampled_tokens
 
     def execute_model(
         self,
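For readers skimming the diff, here is a minimal NumPy sketch of the masking used in postprocess above; the batch contents are made up, but it shows why a request that is still chunk-prefilling contributes zero sampled tokens while every other request contributes exactly one:

```python
import numpy as np

# Hypothetical batch of four requests; True means the request is still in a
# chunked-prefill step and therefore samples no token this iteration.
is_chunked_prefilling = np.array([False, True, False, True])

# 0 sampled tokens while chunk-prefilling, 1 otherwise.
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
print(num_sampled_tokens)  # [1 0 1 0]
```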
@@ -372,17 +379,19 @@ class GPUModelRunner:
     ):
         self.update_states(scheduler_output)
         if scheduler_output.total_num_scheduled_tokens == 0:
-            return
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         input_batch = self.prepare_inputs(scheduler_output)
+        num_tokens = input_batch.num_tokens_after_padding
 
         with set_forward_context(
             input_batch.attn_metadata,
             self.vllm_config,
+            num_tokens=num_tokens,
         ):
             hidden_states = self.model(
-                input_ids=input_batch.input_ids,
-                positions=input_batch.positions,
+                input_ids=input_batch.input_ids[:num_tokens],
+                positions=input_batch.positions[:num_tokens],
             )
 
         # Compute logits to sample next tokens.
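The [:num_tokens] slicing above can be illustrated with a standalone sketch (the buffer size and padded count are invented for the example): the runner keeps over-allocated input buffers, and each step only the padded prefix of those buffers is passed to the model:

```python
import torch

MAX_NUM_TOKENS = 8192          # capacity of the persistent buffers (made-up value)
num_tokens_after_padding = 96  # scheduled tokens rounded up for this step (made up)

# Preallocated buffers that outlive a single forward pass.
input_ids = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)
positions = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)

# Only the padded prefix is handed to the model, mirroring
# input_batch.input_ids[:num_tokens] in the diff.
step_input_ids = input_ids[:num_tokens_after_padding]
step_positions = positions[:num_tokens_after_padding]
assert step_input_ids.shape[0] == num_tokens_after_padding
```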
@@ -393,5 +402,19 @@ class GPUModelRunner:
         prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
                                                        input_batch)
 
-        output = self.postprocess(sampler_output, input_batch)
-        return output
+        sampled_token_ids_np, num_sampled_tokens = self.postprocess(
+            sampler_output, input_batch)
+        req_id_to_index = {
+            req_id: i
+            for i, req_id in enumerate(input_batch.req_ids)
+        }
+        return ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids_np.tolist(),
+            logprobs=sampler_output.logprobs_tensors,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )
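The req_id_to_index mapping built above is simply the position of each request ID within the current batch; a toy example with invented IDs:

```python
# In the runner these IDs come from input_batch.req_ids.
req_ids = ["req-0", "req-3", "req-7"]
req_id_to_index = {req_id: i for i, req_id in enumerate(req_ids)}
print(req_id_to_index)  # {'req-0': 0, 'req-3': 1, 'req-7': 2}
```

The hunks that follow appear to be from the request-state module that owns the per-request token buffers updated by append_token_ids.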
@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
+import numba
+import numba.types as types
 import numpy as np
 import torch
 
@@ -160,3 +162,47 @@ class RequestState:
         if self.pin_memory:
             cpu_tensor = cpu_tensor.pin_memory()
         return cpu_tensor.to(device=self.device, non_blocking=True)
+
+    def append_token_ids(
+        self,
+        req_indices: np.ndarray,
+        sampled_ids: np.ndarray,
+        num_sampled_tokens: np.ndarray,
+    ) -> None:
+        _append_token_ids(
+            req_indices,
+            sampled_ids,
+            num_sampled_tokens,
+            self.token_ids,
+            self.num_tokens,
+        )
+
+
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],
+            types.int64[:, :],
+            types.int32[:],
+            types.int32[:, :],
+            types.int32[:],
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def _append_token_ids(
+    req_indices: np.ndarray,
+    sampled_ids: np.ndarray,
+    num_sampled_tokens: np.ndarray,
+    token_ids: np.ndarray,
+    num_tokens: np.ndarray,
+) -> None:
+    num_reqs = num_sampled_tokens.shape[0]
+    for i in range(num_reqs):
+        req_idx = req_indices[i]
+        n = num_sampled_tokens[i]
+        start_idx = num_tokens[req_idx]
+        end_idx = start_idx + n
+        token_ids[req_idx, start_idx:end_idx] = sampled_ids[i, :n]
+        num_tokens[req_idx] = end_idx
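To make the in-place update concrete, here is a self-contained NumPy sketch that runs the same loop as _append_token_ids on toy data, without the numba compilation; the shapes and values are invented, but the dtypes follow the signature above (int32 index/count buffers, int64 sampled IDs):

```python
import numpy as np

# Toy state for three requests with room for eight tokens each.
token_ids = np.zeros((3, 8), dtype=np.int32)
num_tokens = np.array([4, 2, 6], dtype=np.int32)        # tokens already stored per request
req_indices = np.array([0, 2], dtype=np.int32)          # rows touched in this step
sampled_ids = np.array([[101], [202]], dtype=np.int64)  # one sampled token per batch slot
num_sampled_tokens = np.array([1, 0], dtype=np.int32)   # second request is still prefilling

# The same loop as _append_token_ids, interpreted by plain Python.
for i in range(num_sampled_tokens.shape[0]):
    req_idx = req_indices[i]
    n = num_sampled_tokens[i]
    start_idx = num_tokens[req_idx]
    end_idx = start_idx + n
    token_ids[req_idx, start_idx:end_idx] = sampled_ids[i, :n]
    num_tokens[req_idx] = end_idx

print(num_tokens)  # [5 2 6] -> only the first request appended its sampled token
```

Compiling this loop with numba.jit(nopython=True, cache=True) keeps the per-request bookkeeping out of the Python interpreter, and cache=True persists the compiled code to disk so the compilation cost is not paid again on every process start.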