Refactor Worker & InputMetadata (#1843)

Woosuk Kwon 2023-11-29 22:16:37 -08:00 committed by GitHub
parent c782195662
commit 27feead2f8
27 changed files with 668 additions and 443 deletions

View File

@ -161,6 +161,12 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
def get_sliding_window(self) -> Optional[int]:
return getattr(self.hf_config, "sliding_window", None)
def get_vocab_size(self) -> int:
return self.hf_config.vocab_size
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size

View File

@ -201,9 +201,10 @@ class EngineArgs:
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization)
cache_config = CacheConfig(
self.block_size, self.gpu_memory_utilization, self.swap_space,
getattr(model_config.hf_config, 'sliding_window', None))
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space,
model_config.get_sliding_window())
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,

View File

@ -88,8 +88,6 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
assert self.cache_config.sliding_window == getattr(
self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats

View File

@ -1,9 +1,11 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",
"SamplingMetadata",
"set_random_seed",
]

View File

@ -1,91 +1,42 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional
import torch
from xformers.ops import AttentionBias
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
"""Metadata for input sequences. Used in PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to, for each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
context_lens: The length of the attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical blocks)
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
sliding_window: Optional[int] = None,
max_context_len: Optional[int],
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.to_cache = None
if sliding_window is not None:
# We need to keep the positions of sliding windows within
# the key / value tables, this is helpful to know which
# elements we need to cache.
to_cache, start_idx = [], 0
for prompt_len in self.prompt_lens:
to_cache.extend(
range(
start_idx + max(0, prompt_len - sliding_window),
start_idx + prompt_len,
))
start_idx += self.max_prompt_len
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
self.to_cache = torch.tensor(to_cache,
dtype=torch.int32,
device=self.slot_mapping.device)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
self.num_generation_tokens = context_lens.shape[0]
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
self.max_num_blocks_per_seq = 0
assert block_tables.shape[0] == self.num_generation_tokens
self.is_prompt = len(prompt_lens) > 0
# Set during the execution of the first attention op.
self.attn_bias: Optional[AttentionBias] = None
# FIXME(woosuk): This is a hack.
self.attn_bias = None
def __repr__(self) -> str:
# Print only useful metadata.
return (
f'InputMetadata('
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}, '
f'selected_token_indices={self.selected_token_indices}, '
f'categorized_sample_indices={self.categorized_sample_indices}, '
f'slot_mapping={self.slot_mapping})')
return ("InputMetadata("
f"prompt_lens={self.prompt_lens}, "
f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables})")

View File

@ -101,23 +101,15 @@ class PagedAttention(nn.Module):
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
key_to_cache = key
value_to_cache = value
if input_metadata.to_cache is not None:
key_to_cache = key_to_cache[input_metadata.to_cache]
value_to_cache = value_to_cache[input_metadata.to_cache]
slot_mapping = slot_mapping[input_metadata.to_cache]
cache_ops.reshape_and_cache(
key_to_cache,
value_to_cache,
key,
value,
key_cache,
value_cache,
slot_mapping,
)
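For intuition only, a simplified pure-PyTorch stand-in for what reshape_and_cache achieves: each token's key/value vector is written into the cache slot named by slot_mapping. The flat cache layout here is an assumption for illustration; the real CUDA kernel uses a blocked layout and presumably skips padded slots.

import torch

def reference_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping):
    # key/value: [num_tokens, num_heads, head_size]
    # key_cache/value_cache (toy layout): [num_slots, num_heads, head_size]
    key_cache[slot_mapping] = key
    value_cache[slot_mapping] = value

num_tokens, num_heads, head_size, num_slots = 3, 2, 4, 16
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.zeros(num_slots, num_heads, head_size)
value_cache = torch.zeros(num_slots, num_heads, head_size)
slot_mapping = torch.tensor([7, 8, 9])

reference_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
assert torch.equal(key_cache[7], key[0])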
is_prompt = len(input_metadata.prompt_lens) > 0
if is_prompt:
if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,

View File

@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutput, SequenceOutput)
@ -37,29 +37,30 @@ class Sampler(nn.Module):
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
logits = _apply_logits_processors(logits, sampling_metadata)
# Apply presence and frequency penalties.
presence_penalties, frequency_penalties, repetition_penalties = (
_get_penalties(input_metadata))
_get_penalties(sampling_metadata))
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(logits, input_metadata, presence_penalties,
frequency_penalties, repetition_penalties)
logits = _apply_penalties(logits, sampling_metadata,
presence_penalties, frequency_penalties,
repetition_penalties)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
temperatures = _get_temperatures(sampling_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(temperatures,
@ -70,7 +71,7 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
input_metadata, self.vocab_size)
sampling_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
@ -89,11 +90,11 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
sample_results = _sample(probs, logprobs, input_metadata)
sample_results = _sample(probs, logprobs, sampling_metadata)
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, input_metadata, sample_results)
return _build_sampler_output(sample_results, input_metadata,
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)
@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0, input_metadata.selected_token_indices)
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _get_penalties(
input_metadata: InputMetadata
sampling_metadata: SamplingMetadata
) -> Tuple[List[float], List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
r = sampling_params.repetition_penalty
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: We do not apply presence and frequency penalties for the
# prompt token positions where we don't sample new tokens.
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
@ -145,21 +147,21 @@ def _get_penalties(
def _get_prompt_and_output_tokens(
input_metadata: InputMetadata
sampling_metadata: SamplingMetadata,
) -> Tuple[List[List[int]], List[List[int]]]:
prompt_tokens: List[List[int]] = []
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need output tokens to
# compute penalties.
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
return prompt_tokens, output_tokens
@ -191,17 +193,19 @@ def _get_bin_counts_and_mask(
return bin_counts, mask
def _apply_logits_processors(logits: torch.Tensor,
input_metadata: InputMetadata) -> torch.Tensor:
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in input_metadata.seq_groups:
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = input_metadata.seq_data[seq_id].output_token_ids
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor,
def _apply_penalties(
logits: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
@ -234,7 +238,7 @@ def _apply_penalties(
return logits
prompt_tokens, output_tokens = (
_get_prompt_and_output_tokens(input_metadata))
_get_prompt_and_output_tokens(sampling_metadata))
assert len(prompt_tokens) == logits.shape[0]
assert len(output_tokens) == logits.shape[0]
@ -265,10 +269,10 @@ def _apply_penalties(
return logits
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS:
@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
temperatures += [temperature] * len(seq_ids)
return temperatures
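A toy trace of the (prompt_len - 1) padding above (the same pattern appears in the penalty and top-p/top-k helpers), assuming a single prompt group of length 4 with prompt_logprobs requested and temperature 0.8: the pruned logits for that group have three prompt-logprob rows plus one sampled row, and the list lines up with them.

prompt_len, temperature = 4, 0.8   # assumed values
seq_ids = [0]                      # a prompt group holds a single sequence
temperatures = []
temperatures += [temperature] * (prompt_len - 1)   # rows that only report prompt logprobs
temperatures += [temperature] * len(seq_ids)       # the row that is actually sampled
assert temperatures == [0.8, 0.8, 0.8, 0.8]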
def _get_top_p_top_k_min_p(
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
vocab_size: int,
) -> Tuple[List[float], List[int], List[float]]:
top_ps: List[float] = []
top_ks: List[int] = []
min_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p
min_p = sampling_params.min_p
@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p(
top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation.
top_k = vocab_size if top_k == -1 else top_k
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
@ -471,11 +475,11 @@ def _beam_search_sample(
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = input_metadata.categorized_sample_indices
for i, seq_group in enumerate(input_metadata.seq_groups):
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
@ -483,8 +487,8 @@ def _sample(
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
for sampling_type in SamplingType:
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
@ -499,21 +503,22 @@ def _sample(
elif sampling_type == SamplingType.BEAM:
category_logprobs = logprobs[sample_indices]
sample_results = _beam_search_sample(seq_groups, is_prompts,
input_metadata.seq_data,
sampling_metadata.seq_data,
category_logprobs)
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results = [
sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
def _get_logprobs(
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
int, float]]]]:
@ -523,16 +528,16 @@ def _get_logprobs(
largest_num_logprobs = 0
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
batched_logprobs_query_seq_indices.extend(
sample_idx + j for j in range(prompt_len - 1))
@ -570,16 +575,16 @@ def _get_logprobs(
sample_idx = 0
query_result_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
# Prompt logprobs
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]:
@ -625,13 +630,13 @@ def _get_logprobs(
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(input_metadata.seq_groups,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids, _ = seq_group

View File

@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
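The forward/sample split shown above is repeated across the model classes touched by this commit. As an editorial summary, here is a hedged sketch of the shared interface expressed as a typing.Protocol; the Protocol itself and the KVCache alias are additions for illustration, not part of the diff.

from typing import List, Optional, Protocol, Tuple
import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]

class SupportsSampling(Protocol):
    # Hypothetical protocol: forward() now returns hidden states only,
    # and sample() turns them into a SamplerOutput using SamplingMetadata.
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: "InputMetadata",
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor: ...

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: "SamplingMetadata",
    ) -> "SamplerOutput": ...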

View File

@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata, self.lm_head.bias)
sampling_metadata, self.lm_head.bias)
return next_tokens
def load_weights(self,

View File

@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -308,11 +309,18 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -54,6 +54,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -210,28 +211,6 @@ class PhiLayer(nn.Module):
return hidden_states
class PhiCausalLMHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
):
hidden_states = self.ln(hidden_states)
next_tokens = self.sampler(self.linear.weight, hidden_states,
input_metadata, self.linear.bias)
return next_tokens
class PhiModel(nn.Module):
def __init__(self,
@ -253,7 +232,7 @@ class PhiModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
for i in range(self.config.num_hidden_layers):
cache_event = None if cache_events is None else cache_events[i]
@ -268,6 +247,17 @@ class PhiModel(nn.Module):
return hidden_states
class PhiCausalLMHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
class PhiForCausalLM(nn.Module):
def __init__(self,
@ -279,6 +269,7 @@ class PhiForCausalLM(nn.Module):
self.transformer = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
@ -287,11 +278,21 @@ class PhiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
lm_logits = self.lm_head(hidden_states, input_metadata)
return lm_logits
hidden_states = self.lm_head.ln(hidden_states)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
head = self.lm_head.linear
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens
def load_weights(self,
model_name_or_path: str,

View File

@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -246,11 +247,18 @@ class QWenLMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -284,11 +285,18 @@ class YiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,

View File

@ -0,0 +1,43 @@
from typing import Dict, List, Tuple
import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample.
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = len(prompt_lens)
def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices})")

vllm/worker/model_runner.py (new file, 334 lines)
View File

@ -0,0 +1,334 @@
from typing import Dict, List, Optional, Tuple
import torch
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
class ModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.sliding_window = model_config.get_sliding_window()
self.model = None
self.block_size = None # Set after initial profiling.
def load_model(self) -> None:
self.model = get_model(self.model_config)
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(prompt_len)))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_prompt_len = max(prompt_lens)
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
input_metadata = InputMetadata(
prompt_lens=prompt_lens,
slot_mapping=slot_mapping,
max_context_len=None,
context_lens=None,
block_tables=None,
)
return input_tokens, input_positions, input_metadata
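The example in the comment above (prompt_len 10, sliding_window 8, block_size 4) can be reproduced in a few standalone lines; the block table [0, 1, 0] is an assumed allocation in which block 0 is reused once the first two tokens fall outside the window:

_PAD_SLOT_ID = -1
prompt_len, sliding_window, block_size = 10, 8, 4
block_table = [0, 1, 0]   # assumed allocation; block 0 is recycled

start_idx = max(0, prompt_len - sliding_window)   # = 2
slot_mapping = []
for i in range(prompt_len):
    if i < start_idx:
        slot_mapping.append(_PAD_SLOT_ID)         # masked: outside the sliding window
        continue
    block_number = block_table[i // block_size]
    block_offset = i % block_size
    slot_mapping.append(block_number * block_size + block_offset)

assert slot_mapping == [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]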
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
context_len = seq_data.get_len()
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
context_lens.append(context_len)
position = context_len - 1
input_positions.append([position])
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long)
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
max_block_table_len = max([len(t) for t in block_tables])
block_tables = _make_tensor_with_pad(block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int)
input_metadata = InputMetadata(
prompt_lens=[],
slot_mapping=slot_mapping,
max_context_len=max_context_len,
context_lens=context_lens,
block_tables=block_tables,
)
return input_tokens, input_positions, input_metadata
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
max_prompt_len = max(prompt_lens) if prompt_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need to be sampled; skip them.
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_prompt_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device="cuda")
categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
)
return sampling_metadata
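A worked trace of the index bookkeeping above for a prompt batch with prompt_lens [4, 2] and prompt_logprobs disabled: prompts are padded to max_prompt_len, so only the last real token of each prompt is selected, and the sampler later keeps exactly those rows (toy hidden size 2 assumed):

import torch

prompt_lens = [4, 2]
max_prompt_len = max(prompt_lens)          # prompts are padded to this length
selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
    # Without prompt_logprobs, only the last prompt token produces a sample.
    selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
    selected_token_start_idx += max_prompt_len
assert selected_token_indices == [3, 5]

# _prune_hidden_states then keeps rows 3 and 5 of the flattened [2 * 4, hidden] states.
hidden = torch.arange(16, dtype=torch.float).view(2, 4, 2)   # toy [batch, max_len, hidden]
pruned = hidden.view(-1, 2).index_select(0, torch.tensor(selected_token_indices))
assert pruned.shape == (2, 2)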
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
cache_events: Optional[List[torch.cuda.Event]] = None,
) -> SamplerOutput:
# NOTE: We assume that the sequences in the group are either all prompts
# or all decodes.
# Prepare input tensors.
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
input_metadata.prompt_lens)
# Execute the model.
hidden_states = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
cache_events=cache_events,
)
# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
return output
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model_config.get_vocab_size()
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# Profile memory usage with max_num_seqs sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [(None, None)] * num_layers
self.execute_model(seqs, kv_caches)
return
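The dummy sequence lengths above spread max_num_batched_tokens as evenly as possible over max_num_seqs groups; for example, assuming max_num_batched_tokens = 2050 and max_num_seqs = 256 (illustrative scheduler settings):

max_num_batched_tokens, max_num_seqs = 2050, 256
seq_lens = [
    max_num_batched_tokens // max_num_seqs +
    (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
assert sum(seq_lens) == max_num_batched_tokens   # the first 2 groups get 9 tokens, the rest 8
assert seq_lens[:3] == [9, 9, 8]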
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def _make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device="cuda")
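A small usage sketch of the padding helpers above, re-implemented standalone so it runs without a GPU (the real _make_tensor_with_pad pins the tensor to "cuda"):

import torch

def _pad_to_max(x, max_len, pad):
    # Same behavior as the helper above: right-pad x to max_len with the pad value.
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))

prompts = [[11, 12, 13], [21]]
padded = [_pad_to_max(p, max_len=3, pad=0) for p in prompts]
tokens = torch.tensor(padded, dtype=torch.long)   # device left at the default here
assert tokens.tolist() == [[11, 12, 13], [21, 0, 0]]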

View File

@ -1,18 +1,18 @@
"""A GPU worker class."""
import os
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.utils import get_gpu_memory
@ -38,11 +38,11 @@ class Worker:
self.rank = rank
self.distributed_init_method = distributed_init_method
self.model_runner = ModelRunner(model_config, parallel_config,
scheduler_config)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.sliding_window = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
@ -69,7 +69,7 @@ class Worker:
set_random_seed(self.model_config.seed)
def load_model(self):
self.model = get_model(self.model_config)
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
@ -83,40 +83,9 @@ class Worker:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
seqs = []
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seqs)
# Execute the model.
num_layers = self.model_config.get_num_layers(self.parallel_config)
self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=[(None, None)] * num_layers,
input_metadata=input_metadata,
cache_events=None,
)
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@ -140,197 +109,11 @@ class Worker:
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.block_size = cache_config.block_size
self.sliding_window = cache_config.sliding_window
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
def _prepare_inputs(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
# Add prompt tokens.
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
if not seq_group_metadata.is_prompt:
continue
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
# Use any sequence in the group.
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(prompt_len)))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([0] * prompt_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(prompt_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
context_lens: List[int] = []
generation_block_tables: List[List[int]] = []
max_seq_len = max(prompt_lens) if prompt_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.is_prompt:
# We need to do this in this loop as we need to know max_seq_len
assert len(
seq_ids) == 1, "Prompt input should have only one seq."
sampling_params = seq_group_metadata.sampling_params
assert len(prompt_lens) == len(seq_group_metadata_list)
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
continue
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
context_len = seq_data.get_len()
position = context_len - 1
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
input_positions.append([position])
block_table = seq_group_metadata.block_tables[seq_id]
max_context_len = max(max_context_len, context_len)
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
len(block_table))
context_lens.append(context_len)
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
generation_block_tables.append(block_table)
padded_input_tokens = [
_pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
]
padded_input_positions = [
_pad_to_max(positions, max_seq_len, pad=0)
for positions in input_positions
]
padded_slot_mapping = [
_pad_to_max(mapping, max_seq_len, pad=-1)
for mapping in slot_mapping
]
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
for block_table in generation_block_tables
]
# Convert to tensors.
tokens_tensor = torch.tensor(padded_input_tokens,
dtype=torch.long,
device="cuda")
positions_tensor = torch.tensor(padded_input_positions,
dtype=torch.long,
device="cuda")
slot_mapping_tensor = torch.tensor(padded_slot_mapping,
dtype=torch.long,
device="cuda")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device="cuda")
categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
for t, seq_ids in categorized_sample_indices.items()
}
block_tables_tensor = torch.tensor(padded_block_tables,
dtype=torch.int,
device="cuda")
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
block_tables=block_tables_tensor,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
sliding_window=self.sliding_window,
)
return tokens_tensor, positions_tensor, input_metadata
self.model_runner.set_block_size(self.cache_engine.block_size)
@torch.inference_mode()
def execute_model(
@ -361,18 +144,8 @@ class Worker:
event.wait()
return {}
# Prepare input tensors.
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seq_group_metadata_list)
# Execute the model.
output = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=self.gpu_cache,
input_metadata=input_metadata,
cache_events=cache_events,
)
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache, cache_events)
return output
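Schematically, the worker's execution path after this refactor (names as in the diff, arguments elided; this is an editorial summary, not code from the commit):

# Worker.execute_model(seq_group_metadata_list, ...)
#   -> ModelRunner.execute_model(seq_group_metadata_list, self.gpu_cache, cache_events)
#        -> _prepare_prompt(...) or _prepare_decode(...)    # builds InputMetadata
#        -> _prepare_sample(...)                            # builds SamplingMetadata
#        -> model(input_ids, positions, kv_caches, input_metadata, cache_events)
#        -> model.sample(hidden_states, sampling_metadata)  # returns SamplerOutput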
@ -407,14 +180,6 @@ def _init_distributed_environment(
parallel_config.pipeline_parallel_size)
def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
return x + [pad] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return x + [pad] * (max_len - len(x))
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: