Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 08:45:01 +08:00)
Make _prepare_sample non-blocking and use pinned memory for input buffers (#2207)
This commit is contained in:
parent ba4f826738
commit 31bff69151
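The change relies on two standard CUDA behaviors: a host-to-device copy can only overlap with other work when the source host memory is pinned (page-locked), and copy_(..., non_blocking=True) from pageable memory silently falls back to a synchronous copy. A minimal sketch of the pattern, using hypothetical buffer names rather than code from this commit:

import torch

# Pinned (page-locked) host staging tensor: eligible for async H2D copies.
staging = torch.zeros(1024, dtype=torch.long, pin_memory=True)

# Persistent GPU destination buffer (e.g. a CUDA graph input buffer).
gpu_buf = torch.zeros(1024, dtype=torch.long, device="cuda")

# Non-blocking copy: returns immediately; the copy is ordered on the current
# CUDA stream, so later kernels on that stream see the new data.
gpu_buf.copy_(staging, non_blocking=True)

# With a pageable source (pin_memory=False), non_blocking=True has no effect
# and the copy blocks the CPU until it finishes.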
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl
 
 logger = init_logger(__name__)
 
@@ -52,6 +53,8 @@ class ModelRunner:
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # cache in_wsl result
+        self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config)
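The cached in_wsl flag exists because pinned-memory allocation is known to be unreliable under WSL, so the code falls back to pageable memory there. A rough sketch of what such a check can look like; the actual vllm.utils.in_wsl implementation may differ:

import platform

def in_wsl() -> bool:
    # WSL kernels report "microsoft" in their release string,
    # e.g. "5.15.90.1-microsoft-standard-WSL2".
    return "microsoft" in platform.uname().release.lower()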
@@ -203,24 +206,29 @@ class ModelRunner:
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
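The device/pin_memory pairing above is deliberate: with CUDA graphs the tensors are built on the CPU as staging copies for the graph's fixed input buffers, which is the only case where pinning pays off; in eager mode the tensors go straight to the GPU and the flag is irrelevant. A condensed sketch of that decision with hypothetical names:

import torch

def make_input_tensor(values, use_captured_graph: bool, in_wsl: bool) -> torch.Tensor:
    # With CUDA graphs, the tensor is a CPU staging buffer that is later
    # copied into the captured graph's input buffer, so pin it (unless WSL).
    device = "cpu" if use_captured_graph else "cuda"
    pin_memory = use_captured_graph and not in_wsl
    return torch.tensor(values, dtype=torch.long, device=device,
                        pin_memory=pin_memory and device == "cpu")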
@@ -229,7 +237,7 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
@@ -297,11 +305,11 @@ class ModelRunner:
                     categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs
 
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -334,8 +342,6 @@ class ModelRunner:
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@ class ModelRunner:
             input_metadata=input_metadata,
         )
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
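Moving _prepare_sample after the forward pass is what makes it non-blocking in practice: kernel launches (and graph replays) are asynchronous, so the Python-side sampling-metadata preparation now overlaps with GPU execution instead of delaying the launch. A schematic sketch of the overlap, using hypothetical helpers rather than the ModelRunner API:

import torch

def run_step(model, gpu_inputs, prepare_sample_metadata):
    # Kernel launches are asynchronous: this queues GPU work and returns
    # control to Python almost immediately.
    hidden_states = model(gpu_inputs)

    # CPU-side work done here overlaps with the GPU forward pass.
    sampling_metadata = prepare_sample_metadata()

    # The first op that actually consumes hidden_states (e.g. sampling) is
    # stream-ordered after the forward pass, so no explicit sync is needed.
    return hidden_states, sampling_metadata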
@@ -502,11 +511,14 @@ class CUDAGraphRunner:
         del kv_caches
 
         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)
 
         # Run the graph.
         self.graph.replay()
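A captured CUDA graph always reads from the same memory addresses, so each step only has to refresh those fixed input buffers and replay; with pinned host sources, the copies above can be issued without stalling the CPU. A minimal capture/replay sketch, not the CUDAGraphRunner code:

import torch

static_in = torch.zeros(256, device="cuda")
static_out = torch.zeros(256, device="cuda")

# Capture: record the kernels once, bound to static_in/static_out.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in)
    static_out.mul_(2.0)

# Replay: refresh the input buffer (non-blocking, from pinned host memory),
# then re-run the recorded kernels on the same addresses.
host_in = torch.rand(256, pin_memory=True)
static_in.copy_(host_in, non_blocking=True)
g.replay()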
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
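A usage sketch for the helper above (hypothetical call site): the returned GPU tensor is safe to use from ops issued later on the same CUDA stream, since they are ordered after the copy; only a CPU-side read forces a synchronization.

# Assumes the _async_h2d helper above and an available CUDA device.
selected = _async_h2d([0, 5, 9, 12], dtype=torch.long, pin_memory=True)

# Kernels launched afterwards on the default stream (e.g. an index_select
# in the sampler) see the completed copy; no explicit synchronize needed.
# A host read such as selected.tolist() would implicitly synchronize.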