Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[MISC] Use non-blocking transfer in prepare_input (#7172)
This commit is contained in: parent 89b8db6bb2, commit ef527be06c
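The diff below replaces blocking torch.tensor(..., device=device) construction with vLLM's async_tensor_h2d helper, which stages the data on the CPU (pinned when the runner allows it) and then issues a non-blocking host-to-device copy. A minimal sketch of that pattern follows, assuming roughly the signature the diff calls it with; the body is an approximation for illustration, not vLLM's actual implementation.

from typing import List

import torch


def async_tensor_h2d_sketch(data: List[int], dtype: torch.dtype,
                            target_device: torch.device,
                            pin_memory: bool) -> torch.Tensor:
    """Approximate sketch of an async H2D transfer helper (not vLLM's code)."""
    # Build the tensor on the CPU first; pinned (page-locked) memory lets the
    # copy engine DMA directly from it, which is what makes non_blocking=True
    # actually asynchronous.
    cpu_tensor = torch.tensor(data, dtype=dtype, pin_memory=pin_memory,
                              device="cpu")
    # Enqueue the copy on the current CUDA stream and return immediately;
    # kernels launched later on the same stream are ordered after the copy.
    return cpu_tensor.to(device=target_device, non_blocking=True)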
vllm/attention/backends/flash_attn.py
@@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.utils import make_tensor_with_pad
+from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
@@ -310,7 +310,8 @@ class FlashAttentionMetadataBuilder(
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device=device, non_blocking=True)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
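In the CUDA-graph path above, the padded block table already lives in a NumPy buffer (self.runner.graph_block_tables), so the change goes through torch.from_numpy, which wraps that buffer without a host-side copy, before the non-blocking device copy. A small illustration with made-up shapes is below; the copy only truly overlaps with host work when the source memory is pinned.

import numpy as np
import torch

# Hypothetical padded block-table buffer, analogous to graph_block_tables.
input_block_tables = np.zeros((8, 16), dtype=np.int32)
input_block_tables[0, :3] = [10, 11, 12]

if torch.cuda.is_available():
    # from_numpy() shares memory with the NumPy array (no host-side copy);
    # .to(..., non_blocking=True) enqueues the H2D copy without waiting for it.
    block_tables = torch.from_numpy(input_block_tables).to(
        device="cuda", non_blocking=True)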
@@ -320,15 +321,15 @@ class FlashAttentionMetadataBuilder(
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        context_lens_tensor = torch.tensor(self.context_lens,
-                                           dtype=torch.int,
-                                           device=device)
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
+        assert device is not None
+        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                               device, self.runner.pin_memory)
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=device)
@@ -344,10 +345,6 @@ class FlashAttentionMetadataBuilder(
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])

-        slot_mapping_tensor = torch.tensor(self.slot_mapping,
-                                           dtype=torch.long,
-                                           device=device)
-
         return FlashAttentionMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
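Because these copies are issued on the current CUDA stream, any kernel launched later on that stream sees the transferred data without an explicit synchronize; the main caveat is that the host buffer must not be mutated before the copy has actually completed. A hedged sketch of that ordering, using a hypothetical metadata tensor:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    seq_lens = torch.tensor([7, 3, 5], dtype=torch.int,
                            pin_memory=True, device="cpu")
    # The non-blocking copy is enqueued on the current (default) stream.
    seq_lens_gpu = seq_lens.to(device, non_blocking=True)
    # This op runs on the same stream, so it is ordered after the copy and
    # reads the fully transferred values -- no torch.cuda.synchronize() needed.
    total_tokens = seq_lens_gpu.sum()
    # Only block when the host actually needs the result (e.g. .item()).
    print(total_tokens.item())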

vllm/attention/backends/flashinfer.py
@@ -21,7 +21,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
+from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
+                        make_tensor_with_pad)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
@@ -356,7 +357,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device, non_blocking=True)

             last_paged_kv_indptr = self.paged_kv_indptr[-1]
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
@@ -371,12 +373,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )
         assert max_query_len > 0, ("query_lens: {}".format(query_lens))

-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
+        assert device is not None
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=device)
@@ -392,10 +395,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])

-        slot_mapping_tensor = torch.tensor(self.slot_mapping,
-                                           dtype=torch.long,
-                                           device=device)
-
         if len(self.paged_kv_indptr) > 0:
             paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                    device="cpu",
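To see the effect of the change in isolation, a rough micro-benchmark along these lines compares constructing a tensor directly on the GPU from a Python list against staging it in pinned memory and copying non-blocking. The sizes and timing setup are arbitrary; this is a measurement sketch, not part of the commit.

import torch


def time_ms(fn, iters: int = 100) -> float:
    """Rough GPU timing with CUDA events (measurement sketch only)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


if torch.cuda.is_available():
    data = list(range(16384))

    def blocking() -> torch.Tensor:
        # Builds the CPU tensor and copies it to the GPU in one blocking call.
        return torch.tensor(data, dtype=torch.long, device="cuda")

    def non_blocking() -> torch.Tensor:
        # Stages in pinned memory, then enqueues an async H2D copy.
        cpu = torch.tensor(data, dtype=torch.long, pin_memory=True)
        return cpu.to("cuda", non_blocking=True)

    print(f"blocking:     {time_ms(blocking):.3f} ms/iter")
    print(f"non-blocking: {time_ms(non_blocking):.3f} ms/iter")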

vllm/attention/backends/utils.py
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
 import torch

 from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
-from vllm.utils import make_tensor_with_pad
+from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
@@ -181,7 +181,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.from_numpy(input_block_tables).to(
+                device, non_blocking=True)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
@@ -191,15 +192,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             )
         assert max_query_len > 0, "query_lens: {}".format(query_lens)

-        context_lens_tensor = torch.tensor(self.context_lens,
-                                           dtype=torch.int,
-                                           device=device)
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=device)
-        query_lens_tensor = torch.tensor(query_lens,
-                                         dtype=torch.long,
-                                         device=device)
+        assert device is not None
+        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                               device, self.runner.pin_memory)
+        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                           self.runner.pin_memory)
+        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
+                                             self.runner.pin_memory)
+        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                               device, self.runner.pin_memory)
         query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=device)
@@ -215,10 +216,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])

-        slot_mapping_tensor = torch.tensor(self.slot_mapping,
-                                           dtype=torch.long,
-                                           device=device)
-
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
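The helpers take a pin_memory flag from the runner rather than pinning unconditionally, since page-locked allocations are not always available or desirable (vllm.utils exposes is_pin_memory_available, which the model_runner hunk further down imports). A hedged sketch of how a runner-like object might decide that flag; the class and attribute names here are illustrative stand-ins, not vLLM's API.

import torch


class TinyRunnerConfig:
    """Toy stand-in for the runner state that build() reads (illustrative only)."""

    def __init__(self) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pinning only helps when we actually copy to a CUDA device; a real
        # runner may also disable it in environments where page-locked memory
        # is slow or unavailable (e.g. some virtualized setups).
        self.pin_memory = self.device.type == "cuda"


config = TinyRunnerConfig()
lens = torch.tensor([4, 9, 2], dtype=torch.int, pin_memory=config.pin_memory)
lens = lens.to(config.device, non_blocking=True)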

vllm/worker/model_runner.py
@@ -50,7 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
+from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists,
                         get_kv_cache_torch_dtype, is_hip,
                         is_pin_memory_available)
 from vllm.worker.model_runner_base import (
@@ -549,12 +549,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         # Tokens and positions.
         input_tokens.extend([0] * cuda_graph_pad_size)
         input_positions.extend([0] * cuda_graph_pad_size)
-        input_tokens_tensor = torch.tensor(input_tokens,
-                                           dtype=torch.long,
-                                           device=self.runner.device)
-        input_positions_tensor = torch.tensor(input_positions,
-                                              dtype=torch.long,
-                                              device=self.runner.device)
+        assert self.runner.device is not None
+        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
+                                               self.runner.device,
+                                               self.runner.pin_memory)
+        input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
+                                                  self.runner.device,
+                                                  self.runner.pin_memory)

         # Sequence and query lengths.
         seq_lens.extend([1] * cuda_graph_pad_size)
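Putting it together, a prepare_input-style builder ends up issuing all of its small metadata copies asynchronously and letting the first kernel on the stream act as the natural synchronization point. A condensed, hypothetical version of that flow; the function and parameter names are illustrative, not vLLM's API.

from typing import Dict, List

import torch


def build_input_tensors(input_tokens: List[int], input_positions: List[int],
                        seq_lens: List[int], device: torch.device,
                        pin_memory: bool) -> Dict[str, torch.Tensor]:
    """Hypothetical prepare_input-style helper using non-blocking H2D copies."""

    def h2d(data: List[int], dtype: torch.dtype) -> torch.Tensor:
        cpu = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
        return cpu.to(device, non_blocking=True)

    # Each copy is enqueued without blocking the Python thread, so the CPU can
    # keep preparing the next tensor while earlier copies are still in flight.
    return {
        "input_tokens": h2d(input_tokens, torch.long),
        "input_positions": h2d(input_positions, torch.long),
        "seq_lens": h2d(seq_lens, torch.int),
    }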