Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 01:35:01 +08:00)
parent ceca060501
commit 1bed891f72
@@ -7,6 +7,7 @@ import torch

 from vllm.v1.outputs import (
     AsyncModelRunnerOutput,
+    LogprobsTensors,
     ModelRunnerOutput,
     SamplerOutput,
 )
@@ -46,15 +47,18 @@ class AsyncOutput(AsyncModelRunnerOutput):
             "cpu", non_blocking=True
         )
         if sampler_output.logprobs_tensors is not None:
-            self.logprobs_tensors = (
+            self.logprobs_tensors: LogprobsTensors | None = (
                 sampler_output.logprobs_tensors.to_cpu_nonblocking()
             )
         else:
             self.logprobs_tensors = None
-        self.prompt_logprobs_dict = {}
+        self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
         if self.model_runner_output.prompt_logprobs_dict:
             for k, v in self.model_runner_output.prompt_logprobs_dict.items():
-                self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                if v is not None:
+                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                else:
+                    self.prompt_logprobs_dict[k] = None
         self.copy_event.record(self.copy_stream)

     def get_output(self) -> ModelRunnerOutput:
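Note on the hunk above: AsyncOutput stages device-to-host copies with non_blocking=True / to_cpu_nonblocking() and then records copy_event on a dedicated copy stream, which is presumably waited on before the copied tensors are read back in get_output(). Below is a minimal sketch of that general pattern under stated assumptions; async_d2h_copy, its signature, and the pinned destination buffer are illustrative, not vLLM APIs.

import torch

def async_d2h_copy(
    src: torch.Tensor, copy_stream: torch.cuda.Stream
) -> tuple[torch.Tensor, torch.cuda.Event]:
    # Pinned host memory is required for the copy to overlap with GPU work.
    dst = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
    # Order the copy stream after the stream that produced `src`.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        dst.copy_(src, non_blocking=True)
        done = torch.cuda.Event()
        done.record(copy_stream)
    return dst, done

# Later, before reading `dst` on the host: done.synchronize()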
@@ -64,12 +68,10 @@ class AsyncOutput(AsyncModelRunnerOutput):
         # the existing model runner.
         # Going forward, we should keep the data structures as NumPy arrays
         # rather than Python lists.
-        sampled_token_ids_np = self.sampled_token_ids.numpy()
-        num_reqs = sampled_token_ids_np.shape[0]
-        sampled_token_ids: list[np.ndarray] = [
-            sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
-            for i in range(num_reqs)
-        ]
+        sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
+        num_reqs = len(sampled_token_ids)
+        for i in range(num_reqs):
+            del sampled_token_ids[i][self.num_sampled_tokens[i] :]
         self.model_runner_output.sampled_token_ids = sampled_token_ids

         if self.logprobs_tensors is not None:
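The replacement above converts the padded sampled-token tensor to plain Python lists and then trims each row to its request's actual sample count, matching the list[list[int]] annotation used in the new code. A toy illustration with made-up values:

import torch

# Hypothetical batch: 3 requests, padded to 4 sampled tokens per row.
sampled = torch.tensor([[11, 12, 13, 0], [21, 0, 0, 0], [31, 32, 0, 0]])
num_sampled_tokens = [3, 1, 2]

sampled_token_ids: list[list[int]] = sampled.tolist()
for i in range(len(sampled_token_ids)):
    # Drop the padding beyond each request's real sample count.
    del sampled_token_ids[i][num_sampled_tokens[i]:]

print(sampled_token_ids)  # [[11, 12, 13], [21], [31, 32]]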
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import Any
+from typing import Any, cast

 import torch

@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
 )
 from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
 )
@@ -22,7 +23,8 @@ from vllm.v1.worker.utils import bind_kv_cache

 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
     kv_cache_spec: dict[str, KVCacheSpec] = {}
-    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
+    layer_type = cast(type[Any], AttentionLayerBase)
+    attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
     for layer_name, attn_module in attn_layers.items():
         # Skip modules that don't need KV cache (eg encoder-only attention)
         if spec := attn_module.get_kv_cache_spec(vllm_config):
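Both this hunk and the next one route the abstract AttentionLayerBase class through cast(type[Any], ...) before handing it to get_layers_from_vllm_config. This is a common way to satisfy type checkers such as mypy, which reject passing an abstract class where a concrete type[T] is expected, without changing runtime behavior. A self-contained sketch with stand-in names (get_instances and the registry are illustrative, not vLLM APIs):

from abc import ABC, abstractmethod
from typing import Any, TypeVar, cast

T = TypeVar("T")

def get_instances(registry: dict[str, object], layer_type: type[T]) -> dict[str, T]:
    # Toy stand-in for get_layers_from_vllm_config: filter a registry by type.
    return {name: obj for name, obj in registry.items() if isinstance(obj, layer_type)}

class AttentionLayerBase(ABC):  # abstract base, like the real class of the same name
    @abstractmethod
    def forward(self) -> None: ...

class MyAttention(AttentionLayerBase):
    def forward(self) -> None:
        pass

registry: dict[str, object] = {"layers.0.attn": MyAttention()}

# Passing AttentionLayerBase directly can trigger "Only concrete class can be
# given where 'type[T]' is expected"; the cast silences that check only.
layer_type = cast(type[Any], AttentionLayerBase)
attn_layers = get_instances(registry, layer_type)
print(list(attn_layers))  # ['layers.0.attn']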
@@ -35,16 +37,15 @@
     vllm_config: VllmConfig,
     device: torch.device,
 ):
-    attn_backends: dict[str, AttentionBackend] = {}
+    attn_backends: dict[str, type[AttentionBackend]] = {}
     attn_metadata_builders: list[AttentionMetadataBuilder] = []
     flashinfer_workspace: torch.Tensor | None = None
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         layer_names = kv_cache_group_spec.layer_names
         any_layer_name = next(iter(layer_names))

-        attn_layers = get_layers_from_vllm_config(
-            vllm_config, AttentionLayerBase, layer_names
-        )
+        layer_type = cast(type[Any], AttentionLayerBase)
+        attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
         attn_backend = attn_layers[any_layer_name].get_attn_backend()
         for layer_name in layer_names:
             attn_backends[layer_name] = attn_backend
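The annotation change above reflects that attn_backends maps layer names to backend classes rather than backend instances, which is what get_attn_backend() appears to return here. A small sketch of why type[...] is the appropriate value type; AttentionBackendStub is a made-up placeholder:

class AttentionBackendStub:
    # Placeholder for an attention backend class; real backends carry
    # kernels and metadata-builder hooks.
    pass

# The dict stores classes, not instances, so the value type is type[...].
attn_backends: dict[str, type[AttentionBackendStub]] = {
    "model.layers.0.self_attn": AttentionBackendStub,
    "model.layers.1.self_attn": AttentionBackendStub,
}

for name, backend_cls in attn_backends.items():
    assert isinstance(backend_cls, type)  # values are classes
    print(name, backend_cls.__name__)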
@@ -93,6 +94,7 @@ def _reshape_kv_cache(
     kv_caches: dict[str, torch.Tensor] = {}
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+        assert isinstance(kv_cache_spec, AttentionSpec)
         for layer_name in kv_cache_group_spec.layer_names:
             raw_tensor = kv_cache_raw_tensors[layer_name]
             assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
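The added assert narrows kv_cache_spec to AttentionSpec for the type checker (and guards that assumption at runtime) before the attention-specific handling below. A minimal sketch of isinstance-based narrowing, using stub classes rather than the real vLLM spec types:

from dataclasses import dataclass

@dataclass
class KVCacheSpecStub:
    """Stand-in for a generic KV-cache spec base class."""
    block_size: int

@dataclass
class AttentionSpecStub(KVCacheSpecStub):
    """Stand-in for an attention spec that knows its page size."""
    page_size_bytes: int

def num_pages(spec: KVCacheSpecStub, raw_numel: int) -> int:
    # Without the assert, checkers would reject `spec.page_size_bytes`
    # because the base stub does not define it; the assert narrows the type.
    assert isinstance(spec, AttentionSpecStub)
    assert raw_numel % spec.page_size_bytes == 0
    return raw_numel // spec.page_size_bytes

print(num_pages(AttentionSpecStub(block_size=16, page_size_bytes=4096), 8192))  # 2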
@@ -34,8 +34,16 @@ class CudaGraphManager:
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None

-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+        if self.compilation_config.cudagraph_mode is None:
+            self.cudagraph_mode = CUDAGraphMode.NONE
+        else:
+            self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        if self.compilation_config.cudagraph_capture_sizes is not None:
+            self.cudagraph_sizes = sorted(
+                self.compilation_config.cudagraph_capture_sizes
+            )
+        else:
+            self.cudagraph_sizes = []
         self.padded_sizes = self._init_padded_sizes()

         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
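The rewrite above makes the manager tolerate a compilation config whose cudagraph_mode or cudagraph_capture_sizes is None, falling back to CUDAGraphMode.NONE and an empty size list. A compressed sketch of the same normalize-optionals-up-front idea, using simplified stand-ins for the config and enum:

from dataclasses import dataclass
from enum import Enum

class CUDAGraphModeStub(Enum):  # simplified stand-in for vLLM's CUDAGraphMode
    NONE = 0
    FULL = 1

@dataclass
class CompilationConfigStub:  # both fields may legitimately be None
    cudagraph_mode: CUDAGraphModeStub | None = None
    cudagraph_capture_sizes: list[int] | None = None

cfg = CompilationConfigStub(cudagraph_capture_sizes=[8, 1, 4])

# Resolve the optionals once so later code can assume concrete types.
mode = cfg.cudagraph_mode if cfg.cudagraph_mode is not None else CUDAGraphModeStub.NONE
sizes = sorted(cfg.cudagraph_capture_sizes) if cfg.cudagraph_capture_sizes is not None else []
print(mode, sizes)  # CUDAGraphModeStub.NONE [1, 4, 8]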
@@ -329,8 +329,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         torch.cuda.synchronize()

     def update_states(self, scheduler_output: SchedulerOutput) -> None:
-        for req_id in scheduler_output.preempted_req_ids:
-            self.req_states.remove_request(req_id)
+        if scheduler_output.preempted_req_ids is not None:
+            for req_id in scheduler_output.preempted_req_ids:
+                self.req_states.remove_request(req_id)
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)

@@ -346,6 +347,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Add new requests.
         for new_req_data in scheduler_output.scheduled_new_reqs:
+            assert new_req_data.prompt_token_ids is not None
+            assert new_req_data.prefill_token_ids is not None
+            assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
             self.req_states.add_request(
                 req_id=req_id,
@@ -398,8 +402,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Decode first, then prefill.
         # batch_idx -> req_id
         req_ids = sorted(
-            scheduler_output.num_scheduled_tokens,
-            key=scheduler_output.num_scheduled_tokens.get,
+            scheduler_output.num_scheduled_tokens.keys(),
+            key=lambda k: scheduler_output.num_scheduled_tokens[k],
         )
         num_scheduled_tokens = np.array(
             [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
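The sort above orders request IDs by their scheduled token count so single-token decode requests come before multi-token prefills; the explicit lambda likely sidesteps the Optional return type of dict.get that type checkers flag for sort keys. A tiny example with made-up request IDs and counts:

num_scheduled_tokens = {"req-a": 7, "req-b": 1, "req-c": 1}

# Decodes schedule one token, prefills schedule many; ascending order puts
# decodes first. The lambda returns int, whereas dict.get returns int | None.
req_ids = sorted(num_scheduled_tokens.keys(), key=lambda k: num_scheduled_tokens[k])
print(req_ids)  # ['req-b', 'req-c', 'req-a']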
@@ -637,9 +641,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         model_runner_output = ModelRunnerOutput(
             req_ids=input_batch.req_ids,
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
-            sampled_token_ids=None,
+            sampled_token_ids=None,  # type: ignore
             logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
             pooler_output=[],
             kv_connector_output=None,
             num_nans_in_logits=None,
@@ -8,8 +8,8 @@ import triton.language as tl

 from vllm.config.model import LogprobsMode
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.states import SamplingMetadata


 class Sampler: