mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 12:37:22 +08:00

Refactor Worker & InputMetadata (#1843)

This commit is contained in:
parent c782195662
commit 27feead2f8
@@ -161,6 +161,12 @@ class ModelConfig:
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+    def get_sliding_window(self) -> Optional[int]:
+        return getattr(self.hf_config, "sliding_window", None)
+
+    def get_vocab_size(self) -> int:
+        return self.hf_config.vocab_size
+
     def get_hidden_size(self) -> int:
         return self.hf_config.hidden_size
 
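Not part of the diff: the hunk above replaces scattered `getattr(hf_config, "sliding_window", None)` lookups with accessor methods on ModelConfig, and the next hunk updates one call site accordingly. A minimal, self-contained sketch of the pattern (toy classes, not vLLM code):

from typing import Optional


class ToyHFConfig:
    sliding_window = 4096


class ToyModelConfig:

    def __init__(self, hf_config) -> None:
        self.hf_config = hf_config

    def get_sliding_window(self) -> Optional[int]:
        # A missing attribute maps to None, mirroring the new accessor above.
        return getattr(self.hf_config, "sliding_window", None)


print(ToyModelConfig(ToyHFConfig()).get_sliding_window())  # 4096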
@@ -201,9 +201,10 @@ class EngineArgs:
             self.dtype, self.seed, self.revision,
             self.tokenizer_revision, self.max_model_len,
             self.quantization)
-        cache_config = CacheConfig(
-            self.block_size, self.gpu_memory_utilization, self.swap_space,
-            getattr(model_config.hf_config, 'sliding_window', None))
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space,
+                                   model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray,
@@ -88,8 +88,6 @@ class LLMEngine:
 
         self.model_config = model_config
         self.cache_config = cache_config
-        assert self.cache_config.sliding_window == getattr(
-            self.model_config.hf_config, "sliding_window", None)
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
@@ -1,9 +1,11 @@
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 
 __all__ = [
     "InputMetadata",
     "get_model",
+    "SamplingMetadata",
     "set_random_seed",
 ]
@@ -1,91 +1,42 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional
 
 import torch
-from xformers.ops import AttentionBias
-
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData
 
 
 class InputMetadata:
-    """Metadata for input sequences. Used for PagedAttention.
+    """Metadata for input sequences. Used in PagedAttention.
 
     Args:
-        seq_groups: List of (seq_ids, sampling_params).
-        seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         slot_mapping: The address to write the new KV to of each token.
-        context_lens: the length of attention context for each generation token.
         max_context_len: The maximum context length.
+        context_lens: the length of attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
     """
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
         prompt_lens: List[int],
        slot_mapping: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_context_len: int,
-        block_tables: torch.Tensor,
-        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
-        sliding_window: Optional[int] = None,
+        max_context_len: Optional[int],
+        context_lens: Optional[torch.Tensor],
+        block_tables: Optional[torch.Tensor],
     ) -> None:
-        self.seq_groups = seq_groups
-        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
-        self.max_context_len = max_context_len
         self.block_tables = block_tables
-        self.selected_token_indices = selected_token_indices
-        self.categorized_sample_indices = categorized_sample_indices
-
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
-        self.to_cache = None
-        if sliding_window is not None:
-            # We need to keep the positions of sliding windows within
-            # the key / value tables, this is helpful to know which
-            # elements we need to cache.
-            to_cache, start_idx = [], 0
-            for prompt_len in self.prompt_lens:
-                to_cache.extend(
-                    range(
-                        start_idx + max(0, prompt_len - sliding_window),
-                        start_idx + prompt_len,
-                    ))
-                start_idx += self.max_prompt_len
-            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
-            self.to_cache = torch.tensor(to_cache,
-                                         dtype=torch.int32,
-                                         device=self.slot_mapping.device)
-
-        self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
-        self.num_generation_tokens = context_lens.shape[0]
-        if block_tables.numel() > 0:
-            self.max_num_blocks_per_seq = block_tables.shape[1]
-        else:
-            self.max_num_blocks_per_seq = 0
-        assert block_tables.shape[0] == self.num_generation_tokens
 
+        self.is_prompt = len(prompt_lens) > 0
         # Set during the execution of the first attention op.
-        self.attn_bias: Optional[AttentionBias] = None
+        # FIXME(woosuk): This is a hack.
+        self.attn_bias = None
 
     def __repr__(self) -> str:
-        # Print only useful metadata.
-        return (
-            f'InputMetadata('
-            f'num_prompt_tokens={self.num_prompt_tokens}, '
-            f'num_prompts={self.num_prompts}, '
-            f'prompt_lens={self.prompt_lens}, '
-            f'num_generation_tokens={self.num_generation_tokens}, '
-            f'context_lens={self.context_lens}, '
-            f'max_context_len={self.max_context_len}), '
-            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-            f'block_tables={self.block_tables}, '
-            f'selected_token_indices={self.selected_token_indices}, '
-            f'categorized_sample_indices={self.categorized_sample_indices}, '
-            f'slot_mapping={self.slot_mapping})')
+        return ("InputMetadata("
+                f"prompt_lens={self.prompt_lens}, "
+                f"max_context_len={self.max_context_len}, "
+                f"slot_mapping={self.slot_mapping}, "
+                f"context_lens={self.context_lens}, "
+                f"block_tables={self.block_tables})")
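To see what the slimmed-down interface looks like from a caller's side: InputMetadata now carries only attention-side data, while the sampling-side fields (seq_groups, seq_data, selected_token_indices, categorized_sample_indices) move to the new SamplingMetadata class. A hedged usage sketch, assuming a vLLM checkout that includes this commit (values are illustrative, not taken from the worker code):

import torch

from vllm.model_executor import InputMetadata

# A prompt run over two sequences of lengths 4 and 7.
metadata = InputMetadata(
    prompt_lens=[4, 7],
    slot_mapping=torch.arange(11, dtype=torch.long),
    max_context_len=None,   # only needed for generation (decode) runs
    context_lens=None,
    block_tables=None,
)
print(metadata.is_prompt)   # True
print(metadata)             # uses the simplified __repr__ above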
@@ -101,23 +101,15 @@ class PagedAttention(nn.Module):
         # vectors will not be cached. This happens during the initial memory
         # profiling run.
         if key_cache is not None and value_cache is not None:
-            key_to_cache = key
-            value_to_cache = value
-            if input_metadata.to_cache is not None:
-                key_to_cache = key_to_cache[input_metadata.to_cache]
-                value_to_cache = value_to_cache[input_metadata.to_cache]
-                slot_mapping = slot_mapping[input_metadata.to_cache]
-
             cache_ops.reshape_and_cache(
-                key_to_cache,
-                value_to_cache,
+                key,
+                value,
                 key_cache,
                 value_cache,
                 slot_mapping,
             )
 
-        is_prompt = len(input_metadata.prompt_lens) > 0
-        if is_prompt:
+        if input_metadata.is_prompt:
             # Prompt run.
             if self.num_kv_heads != self.num_heads:
                 # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
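With the sliding-window `to_cache` filtering gone, every token's key/value is written to the cache unconditionally and the prompt/decode decision comes from the precomputed `input_metadata.is_prompt` flag. For intuition only, a pure-PyTorch stand-in for the cache write (not vLLM code; the real reshape_and_cache CUDA kernel also handles the paged block layout, and the shapes below are illustrative):

import torch

num_tokens, num_heads, head_size, num_slots = 3, 2, 8, 16
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.zeros(num_slots, num_heads, head_size)
value_cache = torch.zeros(num_slots, num_heads, head_size)
slot_mapping = torch.tensor([5, 6, 7])   # cache slot assigned to each token

key_cache[slot_mapping] = key
value_cache[slot_mapping] = value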
@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
@@ -37,29 +37,30 @@ class Sampler(nn.Module):
         self,
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> SamplerOutput:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, input_metadata)
+        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
 
         # Get the logits for the next tokens.
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
 
         # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, input_metadata)
+        logits = _apply_logits_processors(logits, sampling_metadata)
         # Apply presence and frequency penalties.
         presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(input_metadata))
+            _get_penalties(sampling_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, input_metadata, presence_penalties,
-                                  frequency_penalties, repetition_penalties)
+        logits = _apply_penalties(logits, sampling_metadata,
+                                  presence_penalties, frequency_penalties,
+                                  repetition_penalties)
 
         # Apply temperature scaling.
-        temperatures = _get_temperatures(input_metadata)
+        temperatures = _get_temperatures(sampling_metadata)
         assert len(temperatures) == logits.shape[0]
         if any(t != 1.0 for t in temperatures):
             t = torch.tensor(temperatures,
@@ -70,7 +71,7 @@ class Sampler(nn.Module):
 
         # Apply top-p and top-k truncation.
         top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            input_metadata, self.vocab_size)
+            sampling_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
@@ -89,11 +90,11 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, input_metadata)
+        sample_results = _sample(probs, logprobs, sampling_metadata)
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, input_metadata, sample_results)
-        return _build_sampler_output(sample_results, input_metadata,
+            logprobs, sampling_metadata, sample_results)
+        return _build_sampler_output(sample_results, sampling_metadata,
                                      prompt_logprobs, sample_logprobs)
 
 
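The forward pass keeps its existing order of operations -- penalties, temperature, top-p/top-k truncation, softmax, sampling -- it simply reads the per-sequence parameters from SamplingMetadata instead of InputMetadata. A compact, self-contained sketch of that ordering on toy tensors (illustrative values, not vLLM code):

import torch

logits = torch.randn(2, 32)                  # [num_tokens, vocab_size]
temperatures = torch.tensor([0.7, 1.0])      # one temperature per token row
top_k = 10

logits = logits / temperatures.unsqueeze(dim=1)                    # temperature
kth_largest = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
logits = logits.masked_fill(logits < kth_largest, -float("inf"))   # top-k
probs = torch.softmax(logits, dim=-1)
logprobs = torch.log_softmax(logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)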
@@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
 
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, input_metadata.selected_token_indices)
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
 
 
 def _get_penalties(
-    input_metadata: InputMetadata
+    sampling_metadata: SamplingMetadata
 ) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
     repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
         r = sampling_params.repetition_penalty
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
             # prompt token positions where we don't sample new tokens.
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
             repetition_penalties += [1] * (prompt_len - 1)
@@ -145,21 +147,21 @@ def _get_penalties(
 
 
 def _get_prompt_and_output_tokens(
-    input_metadata: InputMetadata
+    sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:
     prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
-            seq_data = input_metadata.seq_data[seq_id]
+            seq_data = sampling_metadata.seq_data[seq_id]
             prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
     return prompt_tokens, output_tokens
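These helpers all walk `sampling_metadata.seq_groups` and expand one value per logits row, inserting extra entries for prompt positions whenever prompt_logprobs is requested. A small standalone example of that expansion (toy types and values, not vLLM code):

from typing import List, Optional, Tuple


class ToySamplingParams:

    def __init__(self,
                 presence_penalty: float = 0.0,
                 prompt_logprobs: Optional[int] = None) -> None:
        self.presence_penalty = presence_penalty
        self.prompt_logprobs = prompt_logprobs


# (seq_ids, sampling_params) pairs, mirroring sampling_metadata.seq_groups.
seq_groups: List[Tuple[List[int], ToySamplingParams]] = [
    ([0], ToySamplingParams(presence_penalty=0.5, prompt_logprobs=5)),
    ([1, 2], ToySamplingParams(presence_penalty=1.0)),
]
prompt_lens = [4, 3]
num_prompts = 2

presence_penalties: List[float] = []
for i, (seq_ids, params) in enumerate(seq_groups):
    if i < num_prompts and params.prompt_logprobs is not None:
        # Prompt positions get a neutral penalty, as in _get_penalties above.
        presence_penalties += [0.0] * (prompt_lens[i] - 1)
    presence_penalties += [params.presence_penalty] * len(seq_ids)

print(presence_penalties)  # [0.0, 0.0, 0.0, 0.5, 1.0, 1.0]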
@@ -191,17 +193,19 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
-def _apply_logits_processors(logits: torch.Tensor,
-                             input_metadata: InputMetadata) -> torch.Tensor:
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
     logits_row_idx = 0
     found_logits_processors = False
-    for seq_ids, sampling_params in input_metadata.seq_groups:
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
         logits_processors = sampling_params.logits_processors
         if logits_processors:
             found_logits_processors = True
             for seq_id in seq_ids:
                 logits_row = logits[logits_row_idx]
-                token_ids = input_metadata.seq_data[seq_id].output_token_ids
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                 for logits_processor in logits_processors:
                     logits_row = logits_processor(token_ids, logits_row)
                 logits[logits_row_idx] = logits_row
@@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor,
 
 def _apply_penalties(
     logits: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
@@ -234,7 +238,7 @@ def _apply_penalties(
         return logits
 
     prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(input_metadata))
+        _get_prompt_and_output_tokens(sampling_metadata))
     assert len(prompt_tokens) == logits.shape[0]
     assert len(output_tokens) == logits.shape[0]
 
@@ -265,10 +269,10 @@ def _apply_penalties(
     return logits
 
 
-def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
+def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             temperatures += [temperature] * (prompt_len - 1)
         temperatures += [temperature] * len(seq_ids)
     return temperatures
 
 
 def _get_top_p_top_k_min_p(
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     vocab_size: int,
 ) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
     min_ps: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         min_p = sampling_params.min_p
@@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p(
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
             min_ps += [min_p] * (prompt_len - 1)
@@ -471,11 +475,11 @@ def _beam_search_sample(
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = input_metadata.categorized_sample_indices
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    categorized_sample_indices = sampling_metadata.categorized_sample_indices
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
@@ -483,8 +487,8 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
@@ -499,21 +503,22 @@ def _sample(
         elif sampling_type == SamplingType.BEAM:
             category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 input_metadata.seq_data,
+                                                 sampling_metadata.seq_data,
                                                  category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
         sample_results_dict.update(zip(seq_group_ids, sample_results))
 
     sample_results = [
-        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
+        sample_results_dict[i]
+        for i in range(len(sampling_metadata.seq_groups))
     ]
     return sample_results
 
 
 def _get_logprobs(
     logprobs: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
 ) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
         int, float]]]]:
@@ -523,16 +528,16 @@ def _get_logprobs(
     largest_num_logprobs = 0
     sample_idx = 0
     for i, (seq_group, sample_result) in enumerate(
-            zip(input_metadata.seq_groups, sample_results)):
+            zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
         num_parent_seqs = len(seq_ids)
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.prompt_logprobs)
-            prompt_len = input_metadata.prompt_lens[i]
-            prompt_tokens = input_metadata.seq_data[
+            prompt_len = sampling_metadata.prompt_lens[i]
+            prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             batched_logprobs_query_seq_indices.extend(
                 sample_idx + j for j in range(prompt_len - 1))
@@ -570,16 +575,16 @@ def _get_logprobs(
     sample_idx = 0
     query_result_idx = 0
     for i, (seq_group, sample_result) in enumerate(
-            zip(input_metadata.seq_groups, sample_results)):
+            zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
 
         # Prompt logprobs
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = input_metadata.prompt_lens[i]
-            prompt_tokens = input_metadata.seq_data[
+            prompt_len = sampling_metadata.prompt_lens[i]
+            prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
             for token_id in prompt_tokens[1:]:
@@ -625,13 +630,13 @@ def _get_logprobs(
 
 def _build_sampler_output(
     sample_results: List[Tuple[List[int], List[int]]],
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
 ) -> SamplerOutput:
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
-         group_sample_logprobs) in zip(input_metadata.seq_groups,
+         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                        sample_results, prompt_logprobs,
                                        sample_logprobs):
         seq_ids, _ = seq_group
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
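Every decoder-only model in the rest of this commit gets the same mechanical change: forward() stops sampling and returns hidden states, and a new sample() method feeds them to the Sampler together with SamplingMetadata. A self-contained toy module that mirrors the shape of that split (not vLLM code; the pruning step stands in for what the Sampler does with selected_token_indices):

import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Model execution only; no sampling happens here anymore.
        return self.embed(input_ids)

    def sample(self, hidden_states: torch.Tensor,
               selected_token_indices: torch.Tensor) -> torch.Tensor:
        # Stand-in for the Sampler call: keep only the rows that need a
        # next token, compute logits, then pick a token per row.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = hidden_states.index_select(0, selected_token_indices)
        logits = self.lm_head(hidden_states)
        return torch.argmax(logits, dim=-1)


model = ToyCausalLM()
hidden = model(torch.tensor([[1, 2, 3, 4]]))          # [1, 4, hidden_size]
next_token = model.sample(hidden, torch.tensor([3]))  # last prompt position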
@@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
             positions,
@@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
             input_metadata,
             cache_events,
         )
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata, self.lm_head.bias)
+                                   sampling_metadata, self.lm_head.bias)
         return next_tokens
 
     def load_weights(self,
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                       input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@ -308,11 +309,18 @@ class OPTForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens

     def load_weights(self,
@ -54,6 +54,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@ -210,28 +211,6 @@ class PhiLayer(nn.Module):
         return hidden_states


-class PhiCausalLMHead(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
-        self.sampler = Sampler(config.vocab_size)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ):
-        hidden_states = self.ln(hidden_states)
-        next_tokens = self.sampler(self.linear.weight, hidden_states,
-                                   input_metadata, self.linear.bias)
-        return next_tokens
-
-
 class PhiModel(nn.Module):

     def __init__(self,
@ -253,7 +232,7 @@ class PhiModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embd(input_ids)
         for i in range(self.config.num_hidden_layers):
             cache_event = None if cache_events is None else cache_events[i]
@ -268,6 +247,17 @@ class PhiModel(nn.Module):
         return hidden_states


+class PhiCausalLMHead(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.ln = nn.LayerNorm(config.hidden_size,
+                               eps=config.layer_norm_epsilon)
+        self.linear = ParallelLMHead(config.vocab_size,
+                                     config.hidden_size,
+                                     bias=True)
+
+
 class PhiForCausalLM(nn.Module):

     def __init__(self,
@ -279,6 +269,7 @@ class PhiForCausalLM(nn.Module):

         self.transformer = PhiModel(config, linear_method)
         self.lm_head = PhiCausalLMHead(config)
+        self.sampler = Sampler(config.vocab_size)

     def forward(
         self,
@ -287,11 +278,21 @@ class PhiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
-        lm_logits = self.lm_head(hidden_states, input_metadata)
-        return lm_logits
+        hidden_states = self.lm_head.ln(hidden_states)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        head = self.lm_head.linear
+        next_tokens = self.sampler(head.weight, hidden_states,
+                                   sampling_metadata, head.bias)
+        return next_tokens

     def load_weights(self,
                      model_name_or_path: str,
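Phi is the one model where the head itself changes shape: the old PhiCausalLMHead ran LayerNorm, projection, and sampling in a single forward(), while after the refactor only the LayerNorm stays in PhiForCausalLM.forward() and the ParallelLMHead weight and bias are handed to the Sampler inside sample(). A toy, self-contained illustration of that split (plain torch, not vLLM code):

import torch
import torch.nn as nn

# Toy stand-ins for the Phi head: LayerNorm happens in forward(); the vocab
# projection (with bias) is applied only at sampling time, and only on the
# rows that actually need a next token.
hidden_size, vocab_size = 8, 32
ln = nn.LayerNorm(hidden_size)
linear = nn.Linear(hidden_size, vocab_size, bias=True)

hidden_states = ln(torch.randn(5, hidden_size))   # "forward()": 5 token positions
selected = torch.tensor([4])                      # e.g. last prompt token only
logits = torch.nn.functional.linear(              # "sample()": weight + bias -> logits
    hidden_states[selected], linear.weight, linear.bias)
next_token = torch.argmax(logits, dim=-1)
print(next_token.shape)  # torch.Size([1])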
@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@ -246,11 +247,18 @@ class QWenLMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens

     def load_weights(self,
@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@ -284,11 +285,18 @@ class YiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens

     def load_weights(self,
vllm/model_executor/sampling_metadata.py (new file, 43 lines)
@ -0,0 +1,43 @@
from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices

        self.num_prompts = len(prompt_lens)

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"seq_data={self.seq_data}, "
            f"prompt_lens={self.prompt_lens}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")
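For reference, a hand-built SamplingMetadata for a batch of two prompts, mirroring what ModelRunner._prepare_sample (added below) produces. The token ids and parameter values are made up, and the tensors stay on the CPU here even though the runner builds them on CUDA:

import torch

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData

params = SamplingParams(temperature=0.0)   # greedy for both sequences
prompt_lens = [3, 5]                       # prompts get padded to 5 positions each

categorized = {t: [] for t in SamplingType}
categorized[SamplingType.GREEDY] = [0, 1]  # both groups sample greedily

sampling_metadata = SamplingMetadata(
    seq_groups=[([0], params), ([1], params)],   # one sequence per group
    seq_data={0: SequenceData([11, 12, 13]),
              1: SequenceData([21, 22, 23, 24, 25])},
    prompt_lens=prompt_lens,
    # Last token of prompt 0 sits at flat index 3 - 1 = 2; prompt 1 starts at
    # row 5 (the padded length), so its last token is at 5 + 5 - 1 = 9.
    selected_token_indices=torch.tensor([2, 9], dtype=torch.long),
    categorized_sample_indices={
        t: torch.tensor(ids, dtype=torch.int)
        for t, ids in categorized.items()
    },
)
print(sampling_metadata.num_prompts)  # 2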
vllm/worker/model_runner.py (new file, 334 lines)
@ -0,0 +1,334 @@
from typing import Dict, List, Optional, Tuple

import torch

from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata

logger = init_logger(__name__)

_PAD_SLOT_ID = -1


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        self.sliding_window = model_config.get_sliding_window()
        self.model = None
        self.block_size = None  # Set after initial profiling.

    def load_model(self) -> None:
        self.model = get_model(self.model_config)

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []

        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
                continue

            # Compute the slot mapping.
            slot_mapping.append([])
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
                                             dtype=torch.long)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_prompt_len,
                                                pad=0,
                                                dtype=torch.long)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long)

        input_metadata = InputMetadata(
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping,
            max_context_len=None,
            context_lens=None,
            block_tables=None,
        )
        return input_tokens, input_positions, input_metadata

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                context_len = seq_data.get_len()
                if self.sliding_window is not None:
                    context_len = min(context_len, self.sliding_window)
                context_lens.append(context_len)

                position = context_len - 1
                input_positions.append([position])

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_len=1,
                                             pad=0,
                                             dtype=torch.long)
        input_positions = _make_tensor_with_pad(input_positions,
                                                max_len=1,
                                                pad=0,
                                                dtype=torch.long)
        slot_mapping = _make_tensor_with_pad(slot_mapping,
                                             max_len=1,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long)
        max_context_len = max(context_lens)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device="cuda")
        max_block_table_len = max([len(t) for t in block_tables])
        block_tables = _make_tensor_with_pad(block_tables,
                                             max_len=max_block_table_len,
                                             pad=0,
                                             dtype=torch.int)

        input_metadata = InputMetadata(
            prompt_lens=[],
            slot_mapping=slot_mapping,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
        )
        return input_tokens, input_positions, input_metadata

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        selected_token_start_idx = 0
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_prompt_len = max(prompt_lens) if prompt_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += prompt_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        categorized_sample_indices_start_idx)
                categorized_sample_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += max_prompt_len
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        range(categorized_sample_indices_start_idx,
                              categorized_sample_indices_start_idx + num_seqs))
                categorized_sample_indices_start_idx += num_seqs

        selected_token_indices = torch.tensor(selected_token_indices,
                                              dtype=torch.long,
                                              device="cuda")
        categorized_sample_indices = {
            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
        )
        return sampling_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        cache_events: Optional[List[torch.cuda.Event]] = None,
    ) -> SamplerOutput:
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        # Prepare input tensors.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
            input_tokens, input_positions, input_metadata = inputs
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 input_metadata.prompt_lens)

        # Execute the model.
        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )

        # Sample the next token.
        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=sampling_metadata,
        )
        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        vocab_size = self.model_config.get_vocab_size()
        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [(None, None)] * num_layers
        self.execute_model(seqs, kv_caches)
        return


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device="cuda")
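The slot-mapping comment inside _prepare_prompt packs a lot into one example, so here is the same arithmetic rerun standalone. The block table is hypothetical: with an 8-token window and 4-token blocks, the third logical block can reuse physical block 0 because the first two tokens are already outside the window.

# Reproduces the example from _prepare_prompt: prompt_len=10, sliding_window=8,
# block_size=4, block_table=[0, 1, 0] (made-up physical block numbers).
_PAD_SLOT_ID = -1
prompt_len, sliding_window, block_size = 10, 8, 4
block_table = [0, 1, 0]

start_idx = max(0, prompt_len - sliding_window)   # = 2 -> first two tokens masked
slot_mapping = []
for i in range(prompt_len):
    if i < start_idx:
        slot_mapping.append(_PAD_SLOT_ID)
        continue
    block_number = block_table[i // block_size]
    slot_mapping.append(block_number * block_size + i % block_size)

print(slot_mapping)   # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]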
@ -1,18 +1,18 @@
 """A GPU worker class."""
 import os
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

 import torch
 import torch.distributed

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
-from vllm.model_executor import get_model, InputMetadata, set_random_seed
+from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.model_runner import ModelRunner
 from vllm.utils import get_gpu_memory
@ -38,11 +38,11 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method

+        self.model_runner = ModelRunner(model_config, parallel_config,
+                                        scheduler_config)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
-        self.block_size = None
-        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@ -69,7 +69,7 @@ class Worker:
         set_random_seed(self.model_config.seed)

     def load_model(self):
-        self.model = get_model(self.model_config)
+        self.model_runner.load_model()

     @torch.inference_mode()
     def profile_num_available_blocks(
@ -83,40 +83,9 @@ class Worker:
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()

-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-        # Enable top-k sampling to reflect the accurate memory usage.
-        vocab_size = self.model.config.vocab_size
-        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        seqs = []
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            seq_data = SequenceData([0] * seq_len)
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-            )
-            seqs.append(seq)
-
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(
-            seqs)
-
-        # Execute the model.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=[(None, None)] * num_layers,
-            input_metadata=input_metadata,
-            cache_events=None,
-        )
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@ -140,197 +109,11 @@ class Worker:

     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
-        self.block_size = cache_config.block_size
-        self.sliding_window = cache_config.sliding_window
-
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
+        self.model_runner.set_block_size(self.cache_engine.block_size)

-    def _prepare_inputs(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
-        seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
-        slot_mapping: List[List[int]] = []
-        selected_token_indices: List[int] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
-        categorized_sample_indices_start_idx = 0
-
-        # Add prompt tokens.
-        prompt_lens: List[int] = []
-        for seq_group_metadata in seq_group_metadata_list:
-            if not seq_group_metadata.is_prompt:
-                continue
-
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            # Use any sequence in the group.
-            seq_id = seq_ids[0]
-
-            seq_data = seq_group_metadata.seq_data[seq_id]
-            prompt_tokens = seq_data.get_token_ids()
-            prompt_len = len(prompt_tokens)
-            prompt_lens.append(prompt_len)
-
-            if sampling_params.prompt_logprobs is not None:
-                # NOTE: prompt token positions do not need sample, skip
-                categorized_sample_indices_start_idx += prompt_len - 1
-
-            categorized_sample_indices[sampling_params.sampling_type].append(
-                categorized_sample_indices_start_idx)
-            categorized_sample_indices_start_idx += 1
-
-            input_tokens.append(prompt_tokens)
-            # NOTE(woosuk): Here we assume that the first token in the prompt
-            # is always the first token in the sequence.
-            input_positions.append(list(range(prompt_len)))
-
-            if seq_group_metadata.block_tables is None:
-                # During memory profiling, the block tables are not initialized
-                # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.append([0] * prompt_len)
-                continue
-
-            # Compute the slot mapping.
-            slot_mapping.append([])
-            block_table = seq_group_metadata.block_tables[seq_id]
-            for i in range(prompt_len):
-                block_number = block_table[i // self.block_size]
-                block_offset = i % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
-
-        # Add generation tokens.
-        max_context_len = 0
-        max_num_blocks_per_seq = 0
-        context_lens: List[int] = []
-        generation_block_tables: List[List[int]] = []
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            if seq_group_metadata.is_prompt:
-                # We need to do this in this loop as we need to know max_seq_len
-                assert len(
-                    seq_ids) == 1, "Prompt input should have only one seq."
-                sampling_params = seq_group_metadata.sampling_params
-                assert len(prompt_lens) == len(seq_group_metadata_list)
-                prompt_len = prompt_lens[i]
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_seq_len
-                continue
-
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(selected_token_start_idx,
-                      selected_token_start_idx + num_seqs))
-            selected_token_start_idx += num_seqs
-
-            categorized_sample_indices[sampling_params.sampling_type].extend(
-                range(categorized_sample_indices_start_idx,
-                      categorized_sample_indices_start_idx + num_seqs))
-            categorized_sample_indices_start_idx += num_seqs
-
-            for seq_id in seq_ids:
-                seq_data = seq_group_metadata.seq_data[seq_id]
-                generation_token = seq_data.get_last_token_id()
-                input_tokens.append([generation_token])
-
-                context_len = seq_data.get_len()
-                position = context_len - 1
-                if self.sliding_window is not None:
-                    context_len = min(context_len, self.sliding_window)
-                input_positions.append([position])
-
-                block_table = seq_group_metadata.block_tables[seq_id]
-
-                max_context_len = max(max_context_len, context_len)
-                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
-                                             len(block_table))
-                context_lens.append(context_len)
-
-                block_number = block_table[position // self.block_size]
-                block_offset = position % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping.append([slot])
-
-                if self.sliding_window is not None:
-                    sliding_window_blocks = (self.sliding_window //
-                                             self.block_size)
-                    block_table = block_table[-sliding_window_blocks:]
-                generation_block_tables.append(block_table)
-
-        padded_input_tokens = [
-            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
-        ]
-        padded_input_positions = [
-            _pad_to_max(positions, max_seq_len, pad=0)
-            for positions in input_positions
-        ]
-        padded_slot_mapping = [
-            _pad_to_max(mapping, max_seq_len, pad=-1)
-            for mapping in slot_mapping
-        ]
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
-            for block_table in generation_block_tables
-        ]
-
-        # Convert to tensors.
-        tokens_tensor = torch.tensor(padded_input_tokens,
-                                     dtype=torch.long,
-                                     device="cuda")
-        positions_tensor = torch.tensor(padded_input_positions,
-                                        dtype=torch.long,
-                                        device="cuda")
-        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
-                                           dtype=torch.long,
-                                           device="cuda")
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device="cuda")
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
-        categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-        block_tables_tensor = torch.tensor(padded_block_tables,
-                                           dtype=torch.int,
-                                           device="cuda")
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        input_metadata = InputMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-            slot_mapping=slot_mapping_tensor,
-            context_lens=context_lens_tensor,
-            max_context_len=max_context_len,
-            block_tables=block_tables_tensor,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            sliding_window=self.sliding_window,
-        )
-        return tokens_tensor, positions_tensor, input_metadata
-
     @torch.inference_mode()
     def execute_model(
@ -361,18 +144,8 @@ class Worker:
             event.wait()
         return {}

-        # Prepare input tensors.
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(
-            seq_group_metadata_list)
-
-        # Execute the model.
-        output = self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=self.gpu_cache,
-            input_metadata=input_metadata,
-            cache_events=cache_events,
-        )
+        output = self.model_runner.execute_model(seq_group_metadata_list,
+                                                 self.gpu_cache, cache_events)
         return output
@ -407,14 +180,6 @@ def _init_distributed_environment(
         parallel_config.pipeline_parallel_size)


-def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
-    return x + [pad] * ((-len(x)) % multiple_of)
-
-
-def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
-    return x + [pad] * (max_len - len(x))
-
-
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
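Putting the worker changes together: the Worker no longer builds inputs or touches the model directly; it owns a ModelRunner and forwards every model-related call to it. A rough sketch of that delegation, using only the calls that appear in this diff; the config objects, cache engine, GPU cache, and scheduler output are assumed to come from the engine as placeholders:

# Assumed setup: model_config, parallel_config, scheduler_config, cache_engine,
# gpu_cache and seq_group_metadata_list are provided by the surrounding engine.
runner = ModelRunner(model_config, parallel_config, scheduler_config)
runner.load_model()                              # replaces get_model() inside Worker
runner.profile_run()                             # dummy batch for peak-memory profiling
runner.set_block_size(cache_engine.block_size)   # once the cache engine exists
sampler_output = runner.execute_model(seq_group_metadata_list,
                                      kv_caches=gpu_cache,
                                      cache_events=None)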