Refactor Worker & InputMetadata (#1843)

Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Parent: c782195662
Commit: 27feead2f8
@@ -161,6 +161,12 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")

def get_sliding_window(self) -> Optional[int]:
return getattr(self.hf_config, "sliding_window", None)

def get_vocab_size(self) -> int:
return self.hf_config.vocab_size

def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
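Note: the hunk above adds explicit accessors so callers stop reaching into hf_config with getattr (the EngineArgs change below is the first such caller). A minimal, hypothetical sketch of the same pattern as a standalone class, for illustration only:

from typing import Optional

class ConfigAccessors:
    # Hypothetical wrapper showing the accessor pattern added to ModelConfig.
    def __init__(self, hf_config):
        self.hf_config = hf_config

    def get_sliding_window(self) -> Optional[int]:
        # Not every architecture defines sliding_window, so default to None.
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size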
@@ -201,9 +201,10 @@ class EngineArgs:
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization)
cache_config = CacheConfig(
self.block_size, self.gpu_memory_utilization, self.swap_space,
getattr(model_config.hf_config, 'sliding_window', None))
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space,
model_config.get_sliding_window())
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,

@@ -88,8 +88,6 @@ class LLMEngine:

self.model_config = model_config
self.cache_config = cache_config
assert self.cache_config.sliding_window == getattr(
self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
@@ -1,9 +1,11 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed

__all__ = [
"InputMetadata",
"get_model",
"SamplingMetadata",
"set_random_seed",
]
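Note: with SamplingMetadata added to __all__, callers can import both metadata classes from the package root in one line, as the new ModelRunner further down in this diff does:

from vllm.model_executor import get_model, InputMetadata, SamplingMetadata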
@@ -1,91 +1,42 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional

import torch
from xformers.ops import AttentionBias

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData

class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
"""Metadata for input sequences. Used in PagedAttention.

Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
"""

def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
sliding_window: Optional[int] = None,
max_context_len: Optional[int],
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices

self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.to_cache = None
if sliding_window is not None:
# We need to keep the positions of sliding windows within
# the key / value tables, this is helpful to know which
# elements we need to cache.
to_cache, start_idx = [], 0
for prompt_len in self.prompt_lens:
to_cache.extend(
range(
start_idx + max(0, prompt_len - sliding_window),
start_idx + prompt_len,
))
start_idx += self.max_prompt_len
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
self.to_cache = torch.tensor(to_cache,
dtype=torch.int32,
device=self.slot_mapping.device)

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
self.num_generation_tokens = context_lens.shape[0]
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
self.max_num_blocks_per_seq = 0
assert block_tables.shape[0] == self.num_generation_tokens

self.is_prompt = len(prompt_lens) > 0
# Set during the execution of the first attention op.
self.attn_bias: Optional[AttentionBias] = None
# FIXME(woosuk): This is a hack.
self.attn_bias = None

def __repr__(self) -> str:
# Print only useful metadata.
return (
f'InputMetadata('
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}, '
f'selected_token_indices={self.selected_token_indices}, '
f'categorized_sample_indices={self.categorized_sample_indices}, '
f'slot_mapping={self.slot_mapping})')
return ("InputMetadata("
f"prompt_lens={self.prompt_lens}, "
f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables})")
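Note: after this change InputMetadata carries only attention-related fields; everything the sampler needs moves to the new SamplingMetadata class later in this diff. A sketch of constructing the slimmed-down class for a decode step, with purely illustrative values:

import torch

# Two running sequences, one new token each (values are made up).
decode_metadata = InputMetadata(
    prompt_lens=[],                       # empty list means a decode step
    slot_mapping=torch.tensor([[17], [42]], dtype=torch.long),
    max_context_len=9,
    context_lens=torch.tensor([7, 9], dtype=torch.int),
    block_tables=torch.tensor([[0, 1, 0], [2, 3, 4]], dtype=torch.int),
)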
@@ -101,23 +101,15 @@ class PagedAttention(nn.Module):
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
key_to_cache = key
value_to_cache = value
if input_metadata.to_cache is not None:
key_to_cache = key_to_cache[input_metadata.to_cache]
value_to_cache = value_to_cache[input_metadata.to_cache]
slot_mapping = slot_mapping[input_metadata.to_cache]

cache_ops.reshape_and_cache(
key_to_cache,
value_to_cache,
key,
value,
key_cache,
value_cache,
slot_mapping,
)

is_prompt = len(input_metadata.prompt_lens) > 0
if is_prompt:
if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
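Note: the removed to_cache indexing is replaced by padding the slot mapping itself: tokens that fall outside the sliding window get _PAD_SLOT_ID and are never written to the KV cache (see ModelRunner._prepare_prompt further down). A worked example reproducing the numbers from the comment in that method, with a hypothetical block table:

_PAD_SLOT_ID = -1
prompt_len, sliding_window, block_size = 10, 8, 4
block_table = [0, 1, 0]          # hypothetical physical block numbers
start_idx = max(0, prompt_len - sliding_window)
slot_mapping = []
for i in range(prompt_len):
    if i < start_idx:
        slot_mapping.append(_PAD_SLOT_ID)      # outside the window: not cached
    else:
        block_number = block_table[i // block_size]
        slot_mapping.append(block_number * block_size + i % block_size)
assert slot_mapping == [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]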
@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutput, SequenceOutput)

@@ -37,29 +37,30 @@ class Sampler(nn.Module):
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)

# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
logits = _apply_logits_processors(logits, sampling_metadata)
# Apply presence and frequency penalties.
presence_penalties, frequency_penalties, repetition_penalties = (
_get_penalties(input_metadata))
_get_penalties(sampling_metadata))
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(logits, input_metadata, presence_penalties,
frequency_penalties, repetition_penalties)
logits = _apply_penalties(logits, sampling_metadata,
presence_penalties, frequency_penalties,
repetition_penalties)

# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
temperatures = _get_temperatures(sampling_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(temperatures,

@@ -70,7 +71,7 @@ class Sampler(nn.Module):

# Apply top-p and top-k truncation.
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
input_metadata, self.vocab_size)
sampling_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)

@@ -89,11 +90,11 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

# Sample the next tokens.
sample_results = _sample(probs, logprobs, input_metadata)
sample_results = _sample(probs, logprobs, sampling_metadata)
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, input_metadata, sample_results)
return _build_sampler_output(sample_results, input_metadata,
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)

@@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,

def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0, input_metadata.selected_token_indices)
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)

def _get_penalties(
input_metadata: InputMetadata
sampling_metadata: SamplingMetadata
) -> Tuple[List[float], List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
r = sampling_params.repetition_penalty
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: We do not apply presence and frequency penalties for the
# prompt token positions where we don't sample new tokens.
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)

@@ -145,21 +147,21 @@ def _get_penalties(

def _get_prompt_and_output_tokens(
input_metadata: InputMetadata
sampling_metadata: SamplingMetadata,
) -> Tuple[List[List[int]], List[List[int]]]:
prompt_tokens: List[List[int]] = []
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need output tokens to
# compute penalties.
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
return prompt_tokens, output_tokens

@@ -191,17 +193,19 @@ def _get_bin_counts_and_mask(
return bin_counts, mask

def _apply_logits_processors(logits: torch.Tensor,
input_metadata: InputMetadata) -> torch.Tensor:
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in input_metadata.seq_groups:
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = input_metadata.seq_data[seq_id].output_token_ids
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
@@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor,

def _apply_penalties(
logits: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],

@@ -234,7 +238,7 @@ def _apply_penalties(
return logits

prompt_tokens, output_tokens = (
_get_prompt_and_output_tokens(input_metadata))
_get_prompt_and_output_tokens(sampling_metadata))
assert len(prompt_tokens) == logits.shape[0]
assert len(output_tokens) == logits.shape[0]

@@ -265,10 +269,10 @@ def _apply_penalties(
return logits

def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS:

@@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
temperatures += [temperature] * len(seq_ids)
return temperatures

def _get_top_p_top_k_min_p(
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
vocab_size: int,
) -> Tuple[List[float], List[int], List[float]]:
top_ps: List[float] = []
top_ks: List[int] = []
min_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p
min_p = sampling_params.min_p

@@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p(
top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation.
top_k = vocab_size if top_k == -1 else top_k
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)

@@ -471,11 +475,11 @@ def _beam_search_sample(
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = input_metadata.categorized_sample_indices
for i, seq_group in enumerate(input_metadata.seq_groups):
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)

@@ -483,8 +487,8 @@ def _sample(
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
for sampling_type in SamplingType:
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:

@@ -499,21 +503,22 @@ def _sample(
elif sampling_type == SamplingType.BEAM:
category_logprobs = logprobs[sample_indices]
sample_results = _beam_search_sample(seq_groups, is_prompts,
input_metadata.seq_data,
sampling_metadata.seq_data,
category_logprobs)
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sample_results_dict.update(zip(seq_group_ids, sample_results))

sample_results = [
sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results

def _get_logprobs(
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
int, float]]]]:

@@ -523,16 +528,16 @@ def _get_logprobs(
largest_num_logprobs = 0
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
batched_logprobs_query_seq_indices.extend(
sample_idx + j for j in range(prompt_len - 1))

@@ -570,16 +575,16 @@ def _get_logprobs(
sample_idx = 0
query_result_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result

# Prompt logprobs
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]:

@@ -625,13 +630,13 @@ def _get_logprobs(

def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(input_metadata.seq_groups,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids, _ = seq_group
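Note: with the Sampler now driven by SamplingMetadata, running a model becomes a two-step call: forward() produces hidden states using attention-only InputMetadata, and sample() turns them into tokens. A sketch of the flow (it mirrors ModelRunner.execute_model further down; the variable names are illustrative):

hidden_states = model(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=kv_caches,
    input_metadata=input_metadata,        # attention/KV-cache metadata only
    cache_events=None,
)
output = model.sample(
    hidden_states=hidden_states,
    sampling_metadata=sampling_metadata,  # sampling-only metadata
)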
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,
@@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,

@@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata, self.lm_head.bias)
sampling_metadata, self.lm_head.bias)
return next_tokens

def load_weights(self,

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -308,11 +309,18 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,
@@ -54,6 +54,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -210,28 +211,6 @@ class PhiLayer(nn.Module):
return hidden_states

class PhiCausalLMHead(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)

def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
):
hidden_states = self.ln(hidden_states)
next_tokens = self.sampler(self.linear.weight, hidden_states,
input_metadata, self.linear.bias)
return next_tokens

class PhiModel(nn.Module):

def __init__(self,

@@ -253,7 +232,7 @@ class PhiModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
for i in range(self.config.num_hidden_layers):
cache_event = None if cache_events is None else cache_events[i]

@@ -268,6 +247,17 @@ class PhiModel(nn.Module):
return hidden_states

class PhiCausalLMHead(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)

class PhiForCausalLM(nn.Module):

def __init__(self,

@@ -279,6 +269,7 @@ class PhiForCausalLM(nn.Module):

self.transformer = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
self.sampler = Sampler(config.vocab_size)

def forward(
self,

@@ -287,11 +278,21 @@ class PhiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
lm_logits = self.lm_head(hidden_states, input_metadata)
return lm_logits
hidden_states = self.lm_head.ln(hidden_states)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
head = self.lm_head.linear
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens

def load_weights(self,
model_name_or_path: str,
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -246,11 +247,18 @@ class QWenLMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

@@ -284,11 +285,18 @@ class YiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens

def load_weights(self,
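Note: every model file above gets the same mechanical change: forward() now returns raw hidden states and a new sample() method runs the Sampler with SamplingMetadata. A condensed skeleton of the resulting per-model interface (hypothetical model; the real classes also define __init__ and load_weights):

class ExampleForCausalLM(nn.Module):

    def forward(self, input_ids, positions, kv_caches, input_metadata,
                cache_events) -> torch.Tensor:
        # No sampling here anymore; just run the transformer stack.
        return self.model(input_ids, positions, kv_caches, input_metadata,
                          cache_events)

    def sample(self, hidden_states: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        # Sampling is a separate step keyed off SamplingMetadata.
        return self.sampler(self.lm_head.weight, hidden_states,
                            sampling_metadata)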
vllm/model_executor/sampling_metadata.py (new file, 43 lines)

@@ -0,0 +1,43 @@
from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData

class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.

Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample.
"""

def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices

self.num_prompts = len(prompt_lens)

def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices})")
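Note: a sketch of building SamplingMetadata by hand for a single prompt of four tokens where only the last position is sampled; seq_data is assumed to be a SequenceData built elsewhere, and the dict comprehension mirrors how ModelRunner._prepare_sample fills categorized_sample_indices:

import torch
from vllm.sampling_params import SamplingParams, SamplingType

params = SamplingParams()                 # default sampling parameters
sampling_metadata = SamplingMetadata(
    seq_groups=[([0], params)],           # one group, sequence id 0
    seq_data={0: seq_data},
    prompt_lens=[4],
    selected_token_indices=torch.tensor([3], dtype=torch.long),
    categorized_sample_indices={
        t: torch.tensor([0] if t == params.sampling_type else [],
                        dtype=torch.int)
        for t in SamplingType
    },
)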
vllm/worker/model_runner.py (new file, 334 lines)

@@ -0,0 +1,334 @@
from typing import Dict, List, Optional, Tuple

import torch

from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata

logger = init_logger(__name__)

_PAD_SLOT_ID = -1

class ModelRunner:

def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config

self.sliding_window = model_config.get_sliding_window()
self.model = None
self.block_size = None # Set after initial profiling.

def load_model(self) -> None:
self.model = get_model(self.model_config)

def set_block_size(self, block_size: int) -> None:
self.block_size = block_size

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []

prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]

seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)

input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(prompt_len)))

if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
continue

# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue

block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)

max_prompt_len = max(prompt_lens)
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)

input_metadata = InputMetadata(
prompt_lens=prompt_lens,
slot_mapping=slot_mapping,
max_context_len=None,
context_lens=None,
block_tables=None,
)
return input_tokens, input_positions, input_metadata

def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
block_tables: List[List[int]] = []

for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt

seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])

context_len = seq_data.get_len()
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
context_lens.append(context_len)

position = context_len - 1
input_positions.append([position])

block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])

if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)

input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long)
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
max_block_table_len = max([len(t) for t in block_tables])
block_tables = _make_tensor_with_pad(block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int)

input_metadata = InputMetadata(
prompt_lens=[],
slot_mapping=slot_mapping,
max_context_len=max_context_len,
context_lens=context_lens,
block_tables=block_tables,
)
return input_tokens, input_positions, input_metadata
def _prepare_sample(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
) -> SamplingMetadata:
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
||||
selected_token_indices: List[int] = []
|
||||
selected_token_start_idx = 0
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices_start_idx = 0
|
||||
|
||||
max_prompt_len = max(prompt_lens) if prompt_lens else 1
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
|
||||
if seq_group_metadata.is_prompt:
|
||||
assert len(seq_ids) == 1
|
||||
prompt_len = prompt_lens[i]
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
# NOTE: prompt token positions do not need sample, skip
|
||||
categorized_sample_indices_start_idx += prompt_len - 1
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].append(
|
||||
categorized_sample_indices_start_idx)
|
||||
categorized_sample_indices_start_idx += 1
|
||||
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
selected_token_indices.extend(
|
||||
range(selected_token_start_idx,
|
||||
selected_token_start_idx + prompt_len - 1))
|
||||
selected_token_indices.append(selected_token_start_idx +
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += max_prompt_len
|
||||
else:
|
||||
num_seqs = len(seq_ids)
|
||||
selected_token_indices.extend(
|
||||
range(selected_token_start_idx,
|
||||
selected_token_start_idx + num_seqs))
|
||||
selected_token_start_idx += num_seqs
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].extend(
|
||||
range(categorized_sample_indices_start_idx,
|
||||
categorized_sample_indices_start_idx + num_seqs))
|
||||
categorized_sample_indices_start_idx += num_seqs
|
||||
|
||||
selected_token_indices = torch.tensor(selected_token_indices,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
categorized_sample_indices = {
|
||||
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
|
||||
for t, seq_ids in categorized_sample_indices.items()
|
||||
}
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=categorized_sample_indices,
|
||||
)
|
||||
return sampling_metadata
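
# Illustrative sketch, not part of this commit: how selected_token_indices
# skips prompt padding in the loop above. The two prompt lengths are assumed
# for illustration; prompts are padded to max_prompt_len, and only the last
# real token of each prompt feeds the sampler.
prompt_lens = [3, 5]
max_prompt_len = max(prompt_lens)

selected_token_indices = []
start = 0
for prompt_len in prompt_lens:
    # Hidden states for this prompt occupy [start, start + max_prompt_len)
    # in the flattened batch; the sampled token comes from the last real
    # (non-padding) position.
    selected_token_indices.append(start + prompt_len - 1)
    start += max_prompt_len

print(selected_token_indices)  # [2, 9]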

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        cache_events: Optional[List[torch.cuda.Event]] = None,
    ) -> SamplerOutput:
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        # Prepare input tensors.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 input_metadata.prompt_lens)

        # Execute the model.
        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        return
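
# Illustrative sketch, not part of this commit: how profile_run splits
# max_num_batched_tokens across max_num_seqs dummy sequences. The two config
# values are assumed for illustration; the first `remainder` sequences get one
# extra token so the lengths sum exactly to the token budget.
max_num_batched_tokens = 2560
max_num_seqs = 6

seq_lens = [
    max_num_batched_tokens // max_num_seqs +
    (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
print(seq_lens, sum(seq_lens))  # [427, 427, 427, 427, 426, 426] 2560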


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device="cuda")
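
# Illustrative sketch, not part of this commit: the slot-mapping arithmetic
# used for prompt tokens (block_number * block_size + block_offset), written
# as a standalone function. The prompt length, block size, and block table
# below are assumed for illustration.
from typing import List


def prompt_slot_mapping(prompt_len: int, block_size: int,
                        block_table: List[int]) -> List[int]:
    slots = []
    for i in range(prompt_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots


# A 6-token prompt with block_size = 4 spans two physical blocks (5 and 2):
# the first four tokens land in block 5, the next two in block 2.
print(prompt_slot_mapping(6, 4, [5, 2]))  # [20, 21, 22, 23, 8, 9]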

@ -1,18 +1,18 @@
"""A GPU worker class."""
import os
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.utils import get_gpu_memory


@ -38,11 +38,11 @@ class Worker:
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        self.model_runner = ModelRunner(model_config, parallel_config,
                                        scheduler_config)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.sliding_window = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None
@ -69,7 +69,7 @@ class Worker:
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model = get_model(self.model_config)
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
@ -83,40 +83,9 @@ class Worker:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model.config.vocab_size
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        seqs = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seqs)

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )
        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
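
# Illustrative sketch, not part of this commit: the rough shape of the block
# count derived from the profiled peak. The function, names, and the 2 MiB
# block size below are assumptions for illustration only, not the exact
# formula used in vllm.

def estimate_num_gpu_blocks(total_gpu_memory: int,
                            gpu_memory_utilization: float,
                            peak_memory: int,
                            cache_block_size: int) -> int:
    # Whatever fits in the memory budget left over after the forward pass,
    # divided by the per-block KV-cache footprint.
    free_for_cache = total_gpu_memory * gpu_memory_utilization - peak_memory
    return max(int(free_for_cache) // cache_block_size, 0)


GiB = 1 << 30
print(estimate_num_gpu_blocks(80 * GiB, 0.9, 30 * GiB, 2 << 20))  # 21504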
@ -140,197 +109,11 @@ class Worker:

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.sliding_window = cache_config.sliding_window

        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            if sampling_params.prompt_logprobs is not None:
                # NOTE: prompt token positions do not need sample, skip
                categorized_sample_indices_start_idx += prompt_len - 1

            categorized_sample_indices[sampling_params.sampling_type].append(
                categorized_sample_indices_start_idx)
            categorized_sample_indices_start_idx += 1

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([0] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        max_seq_len = max(prompt_lens) if prompt_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            if seq_group_metadata.is_prompt:
                # We need to do this in this loop as we need to know max_seq_len
                assert len(
                    seq_ids) == 1, "Prompt input should have only one seq."
                sampling_params = seq_group_metadata.sampling_params
                assert len(prompt_lens) == len(seq_group_metadata_list)
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += max_seq_len
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            num_seqs = len(seq_ids)
            selected_token_indices.extend(
                range(selected_token_start_idx,
                      selected_token_start_idx + num_seqs))
            selected_token_start_idx += num_seqs

            categorized_sample_indices[sampling_params.sampling_type].extend(
                range(categorized_sample_indices_start_idx,
                      categorized_sample_indices_start_idx + num_seqs))
            categorized_sample_indices_start_idx += num_seqs

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                context_len = seq_data.get_len()
                position = context_len - 1
                if self.sliding_window is not None:
                    context_len = min(context_len, self.sliding_window)
                input_positions.append([position])

                block_table = seq_group_metadata.block_tables[seq_id]

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
                                             len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                generation_block_tables.append(block_table)

        padded_input_tokens = [
            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
        ]
        padded_input_positions = [
            _pad_to_max(positions, max_seq_len, pad=0)
            for positions in input_positions
        ]
        padded_slot_mapping = [
            _pad_to_max(mapping, max_seq_len, pad=-1)
            for mapping in slot_mapping
        ]
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
            for block_table in generation_block_tables
        ]

        # Convert to tensors.
        tokens_tensor = torch.tensor(padded_input_tokens,
                                     dtype=torch.long,
                                     device="cuda")
        positions_tensor = torch.tensor(padded_input_positions,
                                        dtype=torch.long,
                                        device="cuda")
        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
                                           dtype=torch.long,
                                           device="cuda")
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device="cuda")
        selected_token_indices = torch.tensor(selected_token_indices,
                                              dtype=torch.long,
                                              device="cuda")
        categorized_sample_indices = {
            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
            for t, seq_ids in categorized_sample_indices.items()
        }
        block_tables_tensor = torch.tensor(padded_block_tables,
                                           dtype=torch.int,
                                           device="cuda")

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            sliding_window=self.sliding_window,
        )
        return tokens_tensor, positions_tensor, input_metadata
        self.model_runner.set_block_size(self.cache_engine.block_size)

    @torch.inference_mode()
    def execute_model(
@ -361,18 +144,8 @@ class Worker:
                event.wait()
        return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache, cache_events)
        return output


@ -407,14 +180,6 @@ def _init_distributed_environment(
                              parallel_config.pipeline_parallel_size)


def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
    return x + [pad] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))
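
# Illustrative sketch, not part of this commit: what the two padding helpers
# above do. The inputs are assumed for illustration; _pad_to_alignment rounds
# a list length up to a multiple, while _pad_to_max pads to an exact length.
from typing import List


def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
    return x + [pad] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))


print(_pad_to_alignment([1, 2, 3], multiple_of=8, pad=0))
# [1, 2, 3, 0, 0, 0, 0, 0]
print(_pad_to_max([1, 2, 3], max_len=5, pad=0))
# [1, 2, 3, 0, 0]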


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16: