[Model Runner V2] Minor cleanup for build_attn_metadata (#29576)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit ee80aee1ca
parent 0aeb698b77
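This cleanup narrows build_attn_metadata's interface: instead of receiving the whole query_start_loc CpuGpuBuffer and slicing it internally, the function now takes pre-sliced query_start_loc_gpu and query_start_loc_cpu tensors, and every call site (the CUDA graph capture path in CudaGraphManager and the two paths in GPUModelRunner) performs the [: num_reqs + 1] slicing itself. The now-unused CpuGpuBuffer import is dropped.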
@@ -18,7 +18,6 @@ from vllm.v1.kv_cache_interface import (
     KVCacheConfig,
     KVCacheSpec,
 )
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.utils import bind_kv_cache
 
 
@@ -145,7 +144,8 @@ def build_attn_metadata(
     attn_metadata_builders: list[AttentionMetadataBuilder],
     num_reqs: int,
     num_tokens: int,
-    query_start_loc: CpuGpuBuffer,
+    query_start_loc_gpu: torch.Tensor,
+    query_start_loc_cpu: torch.Tensor,
     seq_lens: torch.Tensor,
     seq_lens_np: np.ndarray,
     num_computed_tokens_cpu: torch.Tensor | None,
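To make the new contract concrete, here is a minimal sketch with a hypothetical stub (build_attn_metadata_stub and the toy tensor values are not from the source; only the parameters touched by this commit are kept):

```python
from typing import Any

import torch


# Hypothetical stand-in for the real build_attn_metadata in vLLM's model
# runner v2, reduced to the arguments changed by this commit.
def build_attn_metadata_stub(
    *,
    num_reqs: int,
    query_start_loc_gpu: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
) -> dict[str, Any]:
    # Inputs now arrive pre-sliced to num_reqs + 1 entries.
    assert query_start_loc_cpu.numel() == num_reqs + 1
    return {"max_query_len": int(query_start_loc_cpu.max())}


# Three requests with cumulative token offsets [0, 3, 8, 8].
qsl = torch.tensor([0, 3, 8, 8], dtype=torch.int32)
meta = build_attn_metadata_stub(
    num_reqs=3, query_start_loc_gpu=qsl, query_start_loc_cpu=qsl
)
assert meta["max_query_len"] == 8
```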
@@ -153,9 +153,7 @@ def build_attn_metadata(
     slot_mappings: torch.Tensor,
     kv_cache_config: KVCacheConfig,
 ) -> dict[str, Any]:
-    query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
-    query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
-    max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
+    max_query_len = int(query_start_loc_cpu.max())
     seq_lens = seq_lens[:num_reqs]
     seq_lens_cpu = torch.from_numpy(seq_lens_np)
     max_seq_len = int(seq_lens_np.max())
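A quick sanity check, with made-up values, that the simplified max_query_len computation matches what the removed lines computed once the caller passes the same [: num_reqs + 1] slice:

```python
import numpy as np
import torch

num_reqs = 3
# Cumulative query offsets for 3 requests with 5, 2, and 7 tokens.
qsl_np = np.array([0, 5, 7, 14], dtype=np.int32)
qsl_cpu = torch.from_numpy(qsl_np)  # already sliced to num_reqs + 1 entries

old = int(qsl_np[: num_reqs + 1].max())  # removed in-function computation
new = int(qsl_cpu.max())                 # new computation on the pre-sliced input
assert old == new == 14
```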
@@ -120,7 +120,8 @@ class CudaGraphManager:
             attn_metadata_builders=attn_metadata_builders,
             num_reqs=batch_size,
             num_tokens=batch_size,
-            query_start_loc=input_buffers.query_start_loc,
+            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
+            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
             seq_lens=input_buffers.seq_lens,
             seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
             num_computed_tokens_cpu=None,  # FIXME
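For illustration, the unchanged seq_lens_np line in the capture path builds a constant array; a toy example with assumed sizes:

```python
import numpy as np

batch_size, max_model_len = 8, 4096  # example values, not from the source
# During CUDA graph capture, every slot is padded to the maximum model
# length, so the per-request sequence lengths are a constant-filled array.
seq_lens_np = np.full(batch_size, max_model_len, dtype=np.int32)
assert seq_lens_np.shape == (batch_size,)
assert int(seq_lens_np.max()) == max_model_len
```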
@@ -226,11 +226,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_computed_tokens = torch.zeros(
             input_batch.num_reqs, dtype=torch.int32, device=self.device
         )
+        query_start_loc = self.input_buffers.query_start_loc
+        query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
+        query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
         attn_metadata = build_attn_metadata(
             attn_metadata_builders=self.attn_metadata_builders,
             num_reqs=input_batch.num_reqs,
             num_tokens=input_batch.num_tokens,
-            query_start_loc=self.input_buffers.query_start_loc,
+            query_start_loc_gpu=query_start_loc_gpu,
+            query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=self.input_buffers.seq_lens,
             seq_lens_np=input_batch.seq_lens_np,
             num_computed_tokens_cpu=num_computed_tokens,
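For context on the [: input_batch.num_reqs + 1] slices: query_start_loc stores cumulative token offsets with a leading zero, so it has one more entry than there are requests. A toy illustration:

```python
import torch

num_reqs = 3
# Cumulative offsets for requests with 2, 3, and 4 query tokens.
query_start_loc = torch.tensor([0, 2, 5, 9], dtype=torch.int32)
assert query_start_loc.numel() == num_reqs + 1
# Per-request query lengths are the consecutive differences.
query_lens = query_start_loc[1:] - query_start_loc[:-1]
assert query_lens.tolist() == [2, 3, 4]
```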
@@ -515,6 +519,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
         self.input_buffers.query_start_loc.copy_to_gpu()
         query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+        query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
         query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
 
         # Copy prefill tokens from CPU to GPU.
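The surrounding code leans on the CpuGpuBuffer pattern, where .np and .cpu alias the same memory and copy_to_gpu() refreshes the device mirror. The real class is vllm.v1.utils.CpuGpuBuffer; below is a minimal CPU-only sketch of that assumed interface, with SketchCpuGpuBuffer as a toy stand-in:

```python
import torch


class SketchCpuGpuBuffer:
    """Toy stand-in: .np and .cpu alias the same memory; .gpu is a mirror."""

    def __init__(self, size: int):
        self.cpu = torch.zeros(size, dtype=torch.int32)
        self.np = self.cpu.numpy()  # shares memory with .cpu
        self.gpu = torch.zeros(size, dtype=torch.int32)  # stand-in device copy

    def copy_to_gpu(self) -> None:
        self.gpu.copy_(self.cpu)


num_reqs, num_tokens = 2, 9
buf = SketchCpuGpuBuffer(8)
buf.np[: num_reqs + 1] = [0, 4, 9]   # cumulative offsets
buf.np[num_reqs + 1 :] = num_tokens  # pad the tail, as the diff does
buf.copy_to_gpu()
query_start_loc_gpu = buf.gpu[: num_reqs + 1]
query_start_loc_cpu = buf.cpu[: num_reqs + 1]
query_start_loc_np = buf.np[: num_reqs + 1]
assert query_start_loc_cpu.tolist() == [0, 4, 9]
```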
@@ -572,7 +577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata_builders=self.attn_metadata_builders,
             num_reqs=num_reqs,
             num_tokens=num_tokens,
-            query_start_loc=self.input_buffers.query_start_loc,
+            query_start_loc_gpu=query_start_loc_gpu,
+            query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=self.input_buffers.seq_lens,
             seq_lens_np=seq_lens_np,
             num_computed_tokens_cpu=num_computed_tokens,