# vllm/v1/worker/gpu/model_runner.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from copy import deepcopy
from typing import Any, Optional, Union
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, is_pin_memory_available)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
init_attn_backend, init_kv_cache)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
evenly_split)
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
prepare_inputs)
from vllm.v1.worker.gpu.sampler import Sampler
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
logger = init_logger(__name__)


class GPUModelRunner:
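    """GPU model runner for the v1 engine.

    Owns the model, the KV caches, the per-request states, and the input
    buffers, and executes one forward pass plus sampling per scheduler step.

    Rough call order expected from the worker (a sketch, not the exact
    worker logic):

        runner = GPUModelRunner(vllm_config, torch.device("cuda"))
        runner.load_model()
        runner.initialize_kv_cache(kv_cache_config)  # after memory profiling
        output = runner.execute_model(scheduler_output)
    """
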
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = self.dtype
if self.cache_config.cache_dtype != "auto":
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.is_pooling_model = False
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.use_async_scheduling = self.scheduler_config.async_scheduling
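        # This runner path currently supports only async scheduling.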
assert self.use_async_scheduling
self.output_copy_stream = torch.cuda.Stream()
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
)
self.sampler = Sampler()
def get_supported_tasks(self) -> tuple[str]:
return ("generate", )
def load_model(self, *args, **kwargs) -> None:
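        """Load the model weights onto the device.

        Logs the GPU memory consumed by the weights and the load time.
        """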
time_before_load = time.perf_counter()
with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
m.consumed_memory / GiB_bytes,
time_after_load - time_before_load)
def get_kv_cache_spec(self):
return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
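        """Allocate the KV caches and set up the attention backends.

        Builds the per-group block tables, instantiates the attention
        backends and metadata builders, and allocates the KV cache tensors.
        """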
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
block_sizes = kv_cache_config.block_sizes
self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
)
self.attn_backends, self.attn_metadata_builders = init_attn_backend(
self.kv_cache_config,
self.vllm_config,
self.device,
)
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
self.device,
)
def _dummy_run(
self,
num_tokens: int,
*args,
input_batch: Optional[InputBatch] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
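        """Run the model on a dummy input batch.

        Used by profile_run for memory profiling. Returns the full hidden
        states and the hidden states gathered at the logits indices.
        """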
if input_batch is None:
input_batch = InputBatch.make_dummy(
num_reqs=min(num_tokens, self.max_num_reqs),
num_tokens=num_tokens,
device=self.device,
)
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
):
hidden_states = self.model(
input_ids=input_batch.input_ids[:num_tokens],
positions=input_batch.positions[:num_tokens],
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
) -> None:
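        """Run the sampler on dummy hidden states for memory profiling."""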
num_reqs = hidden_states.shape[0]
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
logits = self.model.compute_logits(hidden_states, None)
self.sampler(logits, sampling_metadata)
def profile_run(self) -> None:
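        """Profile peak GPU memory with a maximum-size dummy batch."""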
input_batch = InputBatch.make_dummy(
num_reqs=self.max_num_reqs,
num_tokens=self.max_num_tokens,
device=self.device,
)
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens,
input_batch=input_batch,
)
self._dummy_sampler_run(sample_hidden_states)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def update_states(self, scheduler_output: SchedulerOutput) -> None:
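        """Bring the per-request states in sync with the scheduler output.

        Removes finished requests, registers newly scheduled requests,
        updates the number of computed tokens for cached requests, and
        appends any newly allocated block IDs to the block tables.
        """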
# for req_id in scheduler_output.preempted_req_ids:
# self.req_states.remove_request(req_id)
for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = []
cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups))
new_block_ids = tuple(
[] for _ in range(self.block_tables.num_kv_cache_groups))
overwrite: list[bool] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
self.req_states.add_request(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params,
)
req_index = self.req_states.req_id_to_index[req_id]
req_indices.append(req_index)
for i, block_ids in enumerate(new_req_data.block_ids):
x = cu_num_new_blocks[i][-1]
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
# Update the states of the running/resumed requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
req_index = self.req_states.req_id_to_index[req_id]
req_new_block_ids = cached_reqs.new_block_ids[i]
if req_new_block_ids is not None:
req_indices.append(req_index)
for group_id, block_ids in enumerate(req_new_block_ids):
x = cu_num_new_blocks[group_id][-1]
cu_num_new_blocks[group_id].append(x + len(block_ids))
new_block_ids[group_id].extend(block_ids)
overwrite.append(False)
self.req_states.num_computed_tokens[req_index] = (
cached_reqs.num_computed_tokens[i])
if req_indices:
self.block_tables.append_block_ids(
req_indices=req_indices,
cu_num_new_blocks=cu_num_new_blocks,
new_block_ids=new_block_ids,
overwrite=overwrite,
)
def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch:
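        """Build the InputBatch for this scheduler step.

        Requests are ordered so that decodes come before prefills (e.g.,
        requests scheduled for 1, 128, and 1 tokens are batched as the two
        single-token decodes first, then the prefill). Token IDs and
        positions are copied to the GPU, block tables and slot mappings are
        gathered, and per-layer attention metadata is built.
        """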
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
# TODO(woosuk): Support CUDA graphs.
num_tokens_after_padding = num_tokens
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
input_ids = self.input_buffers.input_ids
positions = self.input_buffers.positions
query_start_loc = self.input_buffers.query_start_loc
seq_lens = self.input_buffers.seq_lens
prepare_inputs(
idx_mapping_np,
self.req_states.token_ids,
self.req_states.num_computed_tokens,
num_scheduled_tokens,
input_ids.np,
positions.np,
query_start_loc.np,
seq_lens.np,
)
input_ids.copy_to_gpu(num_tokens)
positions.copy_to_gpu(num_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
max_query_len = int(num_scheduled_tokens.max())
seq_lens.copy_to_gpu()
seq_lens_cpu = seq_lens.cpu[:num_reqs]
seq_lens_np = seq_lens.np[:num_reqs]
max_seq_len = int(seq_lens_np.max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
num_computed_tokens_np = self.req_states.num_computed_tokens[
idx_mapping_np]
num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
is_chunked_prefilling = (seq_lens_np
< self.req_states.num_tokens[idx_mapping_np])
# Slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, positions.gpu[:num_tokens])
logits_indices = query_start_loc_gpu[1:] - 1
num_logits_indices = logits_indices.size(0)
# Layer name -> attention metadata.
attn_metadata: dict[str, Any] = {}
        for i, kv_cache_group in enumerate(
                self.kv_cache_config.kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
logits_indices_padded=None,
num_logits_indices=num_logits_indices,
causal=True,
encoder_seq_lens=None,
)
attn_metadata_builder = self.attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
            for layer_name in kv_cache_group.layer_names:
attn_metadata[layer_name] = metadata
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
is_chunked_prefilling=is_chunked_prefilling,
input_ids=input_ids.gpu,
positions=positions.gpu,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
)
def sample(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
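        """Compute logits at the sampled positions and run the sampler.

        For large batches under tensor parallelism, the requests are sharded
        across TP ranks (data-parallel sampling) and the per-rank outputs
        are all-gathered afterwards.
        """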
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
pos = input_batch.positions[input_batch.logits_indices]
idx_mapping_np = input_batch.idx_mapping_np
num_reqs = logits.shape[0]
# When the batch size is large enough, use DP sampler.
tp_group = get_tp_group()
tp_size = tp_group.world_size
n = (num_reqs + tp_size - 1) // tp_size
use_dp_sampler = tp_size > 1 and n > 32 # TODO(woosuk): Tune.
if use_dp_sampler:
# NOTE(woosuk): Make sure that no rank gets zero requests.
tp_rank = tp_group.rank
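            # evenly_split is expected to return this rank's contiguous
            # [start, end) slice of the batch, giving every rank a
            # near-equal, non-empty share of the requests.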
start, end = evenly_split(num_reqs, tp_size, tp_rank)
logits = logits[start:end]
pos = pos[start:end]
idx_mapping_np = idx_mapping_np[start:end]
sampling_metadata = self.req_states.make_sampling_metadata(
idx_mapping_np, pos)
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
if use_dp_sampler:
# All-gather the outputs.
sampler_output = all_gather_sampler_output(
sampler_output,
num_reqs,
tp_size,
)
return sampler_output
def postprocess(
self,
sampler_output: SamplerOutput,
input_batch: InputBatch,
) -> AsyncOutput:
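        """Wrap the sampler output into an AsyncOutput.

        Requests that are still chunked-prefilling contribute 0 sampled
        tokens; all others contribute 1. The sampled token IDs are carried
        by the AsyncOutput rather than filled into the ModelRunnerOutput
        here.
        """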
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
is_chunked_prefilling = input_batch.is_chunked_prefilling
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
sampled_token_ids=None,
num_sampled_tokens=num_sampled_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
return AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
copy_stream=self.output_copy_stream,
)
def execute_model(
self,
scheduler_output: SchedulerOutput,
    ) -> Union[AsyncOutput, ModelRunnerOutput]:
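        """Run one engine step: update states, run the model, and sample.

        Returns EMPTY_MODEL_RUNNER_OUTPUT when no tokens are scheduled.
        """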
self.update_states(scheduler_output)
if scheduler_output.total_num_scheduled_tokens == 0:
return EMPTY_MODEL_RUNNER_OUTPUT
input_batch = self.prepare_inputs(scheduler_output)
num_tokens = input_batch.num_tokens_after_padding
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
):
hidden_states = self.model(
input_ids=input_batch.input_ids[:num_tokens],
positions=input_batch.positions[:num_tokens],
)
sampler_output = self.sample(hidden_states, input_batch)
return self.postprocess(sampler_output, input_batch)