diff --git a/vllm/config.py b/vllm/config.py index 1adf830ffcc12..cd92d361d33c5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -161,6 +161,12 @@ class ModelConfig: "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") + def get_sliding_window(self) -> Optional[int]: + return getattr(self.hf_config, "sliding_window", None) + + def get_vocab_size(self) -> int: + return self.hf_config.vocab_size + def get_hidden_size(self) -> int: return self.hf_config.hidden_size diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 746b0e64ece7b..8dec696e7fb6c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -201,9 +201,10 @@ class EngineArgs: self.dtype, self.seed, self.revision, self.tokenizer_revision, self.max_model_len, self.quantization) - cache_config = CacheConfig( - self.block_size, self.gpu_memory_utilization, self.swap_space, - getattr(model_config.hf_config, 'sliding_window', None)) + cache_config = CacheConfig(self.block_size, + self.gpu_memory_utilization, + self.swap_space, + model_config.get_sliding_window()) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index db1ee606a6dcc..b7dd60dfc70e2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -88,8 +88,6 @@ class LLMEngine: self.model_config = model_config self.cache_config = cache_config - assert self.cache_config.sliding_window == getattr( - self.model_config.hf_config, "sliding_window", None) self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 36fc30f9c1e3c..0d5b2004ad7cb 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,9 +1,11 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ "InputMetadata", "get_model", + "SamplingMetadata", "set_random_seed", ] diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index b3b5852e48769..e4ddf08cd9a03 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,91 +1,42 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional import torch -from xformers.ops import AttentionBias - -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SequenceData class InputMetadata: - """Metadata for input sequences. Used for PagedAttention. + """Metadata for input sequences. Used in PagedAttention. Args: - seq_groups: List of (seq_ids, sampling_params). - seq_data: Seq_id -> SequenceData. prompt_lens: Lengths of prompts. slot_mapping: The address to write the new KV to of each token. - context_lens: the length of attention context for each generation token. max_context_len: The maximum context length. + context_lens: the length of attention context for each sequence. block_tables: The block tables. 
(Seq id -> list of physical block) """ def __init__( self, - seq_groups: List[Tuple[List[int], SamplingParams]], - seq_data: Dict[int, SequenceData], prompt_lens: List[int], slot_mapping: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, - block_tables: torch.Tensor, - selected_token_indices: torch.Tensor, - categorized_sample_indices: Dict[SamplingType, torch.Tensor], - sliding_window: Optional[int] = None, + max_context_len: Optional[int], + context_lens: Optional[torch.Tensor], + block_tables: Optional[torch.Tensor], ) -> None: - self.seq_groups = seq_groups - self.seq_data = seq_data self.prompt_lens = prompt_lens + self.max_context_len = max_context_len self.slot_mapping = slot_mapping self.context_lens = context_lens - self.max_context_len = max_context_len self.block_tables = block_tables - self.selected_token_indices = selected_token_indices - self.categorized_sample_indices = categorized_sample_indices - - self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 - self.to_cache = None - if sliding_window is not None: - # We need to keep the positions of sliding windows within - # the key / value tables, this is helpful to know which - # elements we need to cache. - to_cache, start_idx = [], 0 - for prompt_len in self.prompt_lens: - to_cache.extend( - range( - start_idx + max(0, prompt_len - sliding_window), - start_idx + prompt_len, - )) - start_idx += self.max_prompt_len - to_cache.extend(range(start_idx, slot_mapping.shape[0])) - self.to_cache = torch.tensor(to_cache, - dtype=torch.int32, - device=self.slot_mapping.device) - - self.num_prompts = len(prompt_lens) - self.num_prompt_tokens = self.num_prompts * self.max_prompt_len - self.num_generation_tokens = context_lens.shape[0] - if block_tables.numel() > 0: - self.max_num_blocks_per_seq = block_tables.shape[1] - else: - self.max_num_blocks_per_seq = 0 - assert block_tables.shape[0] == self.num_generation_tokens + self.is_prompt = len(prompt_lens) > 0 # Set during the execution of the first attention op. - self.attn_bias: Optional[AttentionBias] = None + # FIXME(woosuk): This is a hack. + self.attn_bias = None def __repr__(self) -> str: - # Print only useful metadata. - return ( - f'InputMetadata(' - f'num_prompt_tokens={self.num_prompt_tokens}, ' - f'num_prompts={self.num_prompts}, ' - f'prompt_lens={self.prompt_lens}, ' - f'num_generation_tokens={self.num_generation_tokens}, ' - f'context_lens={self.context_lens}, ' - f'max_context_len={self.max_context_len}), ' - f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' - f'block_tables={self.block_tables}, ' - f'selected_token_indices={self.selected_token_indices}, ' - f'categorized_sample_indices={self.categorized_sample_indices}, ' - f'slot_mapping={self.slot_mapping})') + return ("InputMetadata(" + f"prompt_lens={self.prompt_lens}, " + f"max_context_len={self.max_context_len}, " + f"slot_mapping={self.slot_mapping}, " + f"context_lens={self.context_lens}, " + f"block_tables={self.block_tables})") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 55b48fc5c7cca..b84af362efca6 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -101,23 +101,15 @@ class PagedAttention(nn.Module): # vectors will not be cached. This happens during the initial memory # profiling run. 
if key_cache is not None and value_cache is not None: - key_to_cache = key - value_to_cache = value - if input_metadata.to_cache is not None: - key_to_cache = key_to_cache[input_metadata.to_cache] - value_to_cache = value_to_cache[input_metadata.to_cache] - slot_mapping = slot_mapping[input_metadata.to_cache] - cache_ops.reshape_and_cache( - key_to_cache, - value_to_cache, + key, + value, key_cache, value_cache, slot_mapping, ) - is_prompt = len(input_metadata.prompt_lens) > 0 - if is_prompt: + if input_metadata.is_prompt: # Prompt run. if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index b545587fd2044..13da9aa38af03 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) @@ -37,29 +37,30 @@ class Sampler(nn.Module): self, embedding: torch.Tensor, hidden_states: torch.Tensor, - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> SamplerOutput: # Get the hidden states that we use for sampling. - hidden_states = _prune_hidden_states(hidden_states, input_metadata) + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) # Get the logits for the next tokens. logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) # Apply logits processors (if any). - logits = _apply_logits_processors(logits, input_metadata) + logits = _apply_logits_processors(logits, sampling_metadata) # Apply presence and frequency penalties. presence_penalties, frequency_penalties, repetition_penalties = ( - _get_penalties(input_metadata)) + _get_penalties(sampling_metadata)) assert len(presence_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0] assert len(repetition_penalties) == logits.shape[0] - logits = _apply_penalties(logits, input_metadata, presence_penalties, - frequency_penalties, repetition_penalties) + logits = _apply_penalties(logits, sampling_metadata, + presence_penalties, frequency_penalties, + repetition_penalties) # Apply temperature scaling. - temperatures = _get_temperatures(input_metadata) + temperatures = _get_temperatures(sampling_metadata) assert len(temperatures) == logits.shape[0] if any(t != 1.0 for t in temperatures): t = torch.tensor(temperatures, @@ -70,7 +71,7 @@ class Sampler(nn.Module): # Apply top-p and top-k truncation. top_ps, top_ks, min_ps = _get_top_p_top_k_min_p( - input_metadata, self.vocab_size) + sampling_metadata, self.vocab_size) assert len(top_ps) == len(top_ks) == logits.shape[0] do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) do_top_k = any(k != self.vocab_size for k in top_ks) @@ -89,11 +90,11 @@ class Sampler(nn.Module): logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. 
- sample_results = _sample(probs, logprobs, input_metadata) + sample_results = _sample(probs, logprobs, sampling_metadata) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, input_metadata, sample_results) - return _build_sampler_output(sample_results, input_metadata, + logprobs, sampling_metadata, sample_results) + return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) @@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, def _prune_hidden_states( hidden_states: torch.Tensor, - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, ) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - return hidden_states.index_select(0, input_metadata.selected_token_indices) + return hidden_states.index_select(0, + sampling_metadata.selected_token_indices) def _get_penalties( - input_metadata: InputMetadata + sampling_metadata: SamplingMetadata ) -> Tuple[List[float], List[float], List[float]]: # Collect the presence and frequency penalties. presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] - for i, seq_group in enumerate(input_metadata.seq_groups): + for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group p = sampling_params.presence_penalty f = sampling_params.frequency_penalty r = sampling_params.repetition_penalty - if (i < input_metadata.num_prompts + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): # NOTE: We do not apply presence and frequency penalties for the # prompt token positions where we don't sample new tokens. - prompt_len = input_metadata.prompt_lens[i] + prompt_len = sampling_metadata.prompt_lens[i] presence_penalties += [0] * (prompt_len - 1) frequency_penalties += [0] * (prompt_len - 1) repetition_penalties += [1] * (prompt_len - 1) @@ -145,21 +147,21 @@ def _get_penalties( def _get_prompt_and_output_tokens( - input_metadata: InputMetadata + sampling_metadata: SamplingMetadata, ) -> Tuple[List[List[int]], List[List[int]]]: prompt_tokens: List[List[int]] = [] output_tokens: List[List[int]] = [] - for i, seq_group in enumerate(input_metadata.seq_groups): + for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group - if (i < input_metadata.num_prompts + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): # NOTE: prompt token positions do not need output tokens to # compute penalties. 
- prompt_len = input_metadata.prompt_lens[i] + prompt_len = sampling_metadata.prompt_lens[i] prompt_tokens.extend([] for _ in range(prompt_len - 1)) output_tokens.extend([] for _ in range(prompt_len - 1)) for seq_id in seq_ids: - seq_data = input_metadata.seq_data[seq_id] + seq_data = sampling_metadata.seq_data[seq_id] prompt_tokens.append(seq_data.prompt_token_ids) output_tokens.append(seq_data.output_token_ids) return prompt_tokens, output_tokens @@ -191,17 +193,19 @@ def _get_bin_counts_and_mask( return bin_counts, mask -def _apply_logits_processors(logits: torch.Tensor, - input_metadata: InputMetadata) -> torch.Tensor: +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: logits_row_idx = 0 found_logits_processors = False - for seq_ids, sampling_params in input_metadata.seq_groups: + for seq_ids, sampling_params in sampling_metadata.seq_groups: logits_processors = sampling_params.logits_processors if logits_processors: found_logits_processors = True for seq_id in seq_ids: logits_row = logits[logits_row_idx] - token_ids = input_metadata.seq_data[seq_id].output_token_ids + token_ids = sampling_metadata.seq_data[seq_id].output_token_ids for logits_processor in logits_processors: logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row @@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor, def _apply_penalties( logits: torch.Tensor, - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], @@ -234,7 +238,7 @@ def _apply_penalties( return logits prompt_tokens, output_tokens = ( - _get_prompt_and_output_tokens(input_metadata)) + _get_prompt_and_output_tokens(sampling_metadata)) assert len(prompt_tokens) == logits.shape[0] assert len(output_tokens) == logits.shape[0] @@ -265,10 +269,10 @@ def _apply_penalties( return logits -def _get_temperatures(input_metadata: InputMetadata) -> List[float]: +def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]: # Collect the temperatures for the logits. temperatures: List[float] = [] - for i, seq_group in enumerate(input_metadata.seq_groups): + for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature if temperature < _SAMPLING_EPS: @@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]: # (i.e., greedy sampling or beam search). # Set the temperature to 1 to avoid division by zero. temperature = 1.0 - if (i < input_metadata.num_prompts + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): - prompt_len = input_metadata.prompt_lens[i] + prompt_len = sampling_metadata.prompt_lens[i] temperatures += [temperature] * (prompt_len - 1) temperatures += [temperature] * len(seq_ids) return temperatures def _get_top_p_top_k_min_p( - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, vocab_size: int, ) -> Tuple[List[float], List[int], List[float]]: top_ps: List[float] = [] top_ks: List[int] = [] min_ps: List[float] = [] - for i, seq_group in enumerate(input_metadata.seq_groups): + for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group top_p = sampling_params.top_p min_p = sampling_params.min_p @@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p( top_k = min(sampling_params.top_k, vocab_size) # k=-1 means no truncation. 
top_k = vocab_size if top_k == -1 else top_k - if (i < input_metadata.num_prompts + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): - prompt_len = input_metadata.prompt_lens[i] + prompt_len = sampling_metadata.prompt_lens[i] top_ps += [top_p] * (prompt_len - 1) top_ks += [top_k] * (prompt_len - 1) min_ps += [min_p] * (prompt_len - 1) @@ -471,11 +475,11 @@ def _beam_search_sample( def _sample( probs: torch.Tensor, logprobs: torch.Tensor, - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, ) -> List[Tuple[List[int], List[int]]]: categorized_seq_group_ids = {t: [] for t in SamplingType} - categorized_sample_indices = input_metadata.categorized_sample_indices - for i, seq_group in enumerate(input_metadata.seq_groups): + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): _, sampling_params = seq_group sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -483,8 +487,8 @@ def _sample( sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} for sampling_type in SamplingType: seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: @@ -499,21 +503,22 @@ def _sample( elif sampling_type == SamplingType.BEAM: category_logprobs = logprobs[sample_indices] sample_results = _beam_search_sample(seq_groups, is_prompts, - input_metadata.seq_data, + sampling_metadata.seq_data, category_logprobs) else: raise ValueError(f"Unsupported sampling type: {sampling_type}") sample_results_dict.update(zip(seq_group_ids, sample_results)) sample_results = [ - sample_results_dict[i] for i in range(len(input_metadata.seq_groups)) + sample_results_dict[i] + for i in range(len(sampling_metadata.seq_groups)) ] return sample_results def _get_logprobs( logprobs: torch.Tensor, - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, sample_results: List[Tuple[List[int], List[int]]], ) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ int, float]]]]: @@ -523,16 +528,16 @@ def _get_logprobs( largest_num_logprobs = 0 sample_idx = 0 for i, (seq_group, sample_result) in enumerate( - zip(input_metadata.seq_groups, sample_results)): + zip(sampling_metadata.seq_groups, sample_results)): seq_ids, sampling_params = seq_group next_token_ids, parent_ids = sample_result num_parent_seqs = len(seq_ids) - if (i < input_metadata.num_prompts + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): largest_num_logprobs = max(largest_num_logprobs, sampling_params.prompt_logprobs) - prompt_len = input_metadata.prompt_lens[i] - prompt_tokens = input_metadata.seq_data[ + prompt_len = sampling_metadata.prompt_lens[i] + prompt_tokens = sampling_metadata.seq_data[ seq_ids[0]].prompt_token_ids batched_logprobs_query_seq_indices.extend( sample_idx + j for j in range(prompt_len - 1)) @@ -570,16 +575,16 @@ def _get_logprobs( sample_idx = 0 query_result_idx = 0 for i, (seq_group, sample_result) in enumerate( - zip(input_metadata.seq_groups, sample_results)): + zip(sampling_metadata.seq_groups, 
sample_results)): seq_ids, sampling_params = seq_group next_token_ids, parent_ids = sample_result # Prompt logprobs - if (i < input_metadata.num_prompts + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): num_logprobs = sampling_params.prompt_logprobs - prompt_len = input_metadata.prompt_lens[i] - prompt_tokens = input_metadata.seq_data[ + prompt_len = sampling_metadata.prompt_lens[i] + prompt_tokens = sampling_metadata.seq_data[ seq_ids[0]].prompt_token_ids group_prompt_logprobs: PromptLogprobs = [None] for token_id in prompt_tokens[1:]: @@ -625,13 +630,13 @@ def _get_logprobs( def _build_sampler_output( sample_results: List[Tuple[List[int], List[int]]], - input_metadata: InputMetadata, + sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], ) -> SamplerOutput: sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(input_metadata.seq_groups, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs): seq_ids, _ = seq_group diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index ba2af445b1364..f8c4d643294b1 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index e1de6bbefbbc6..d4a32e8e21a6d 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, 
hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1703d1cdb3670..9da0490104b6a 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5c08a1a823685..60ec4d9b4018a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b7af514661a68..8890d29b1267b 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer( input_ids, positions, @@ -397,9 +398,15 @@ class 
FalconForCausalLM(nn.Module): input_metadata, cache_events, ) - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + return hidden_states + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 5dce59f77eea2..5fe678ecc9d5d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 9b69fc90b13aa..2007c264f0cb9 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 7db6edd110f27..1ad344fd6cc0d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module): 
kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata, self.lm_head.bias) + sampling_metadata, self.lm_head.bias) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 1d21d06e21a62..df5c86bf103ad 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.embed_out.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 8b20462c18c15..ebb96c75736cd 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index cd39d9059eeca..40d13ab061d72 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from 
vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 8470020006f3c..6e67d8b82ef1d 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index caa169b57d481..c7be7a922915f 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 2bbcd0030f6fc..1c698c20f35db 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( 
get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -308,11 +309,18 @@ class OPTForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index bd4afbf2ca973..ac441e476bb82 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -54,6 +54,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -210,28 +211,6 @@ class PhiLayer(nn.Module): return hidden_states -class PhiCausalLMHead(nn.Module): - - def __init__(self, config: PretrainedConfig): - super().__init__() - self.ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.linear = ParallelLMHead(config.vocab_size, - config.hidden_size, - bias=True) - self.sampler = Sampler(config.vocab_size) - - def forward( - self, - hidden_states: torch.Tensor, - input_metadata: InputMetadata, - ): - hidden_states = self.ln(hidden_states) - next_tokens = self.sampler(self.linear.weight, hidden_states, - input_metadata, self.linear.bias) - return next_tokens - - class PhiModel(nn.Module): def __init__(self, @@ -253,7 +232,7 @@ class PhiModel(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.embd(input_ids) for i in range(self.config.num_hidden_layers): cache_event = None if cache_events is None else cache_events[i] @@ -268,6 +247,17 @@ class PhiModel(nn.Module): return hidden_states +class PhiCausalLMHead(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.ln = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.linear = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=True) + + class PhiForCausalLM(nn.Module): def __init__(self, @@ -279,6 +269,7 @@ class PhiForCausalLM(nn.Module): self.transformer = PhiModel(config, linear_method) self.lm_head = PhiCausalLMHead(config) + self.sampler = Sampler(config.vocab_size) def forward( self, @@ -287,11 +278,21 @@ class PhiForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) - lm_logits = self.lm_head(hidden_states, input_metadata) - return lm_logits + hidden_states = self.lm_head.ln(hidden_states) + return hidden_states + + 
def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + head = self.lm_head.linear + next_tokens = self.sampler(head.weight, hidden_states, + sampling_metadata, head.bias) + return next_tokens def load_weights(self, model_name_or_path: str, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e6c089b3f8289..33bae61f6016d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -246,11 +247,18 @@ class QWenLMHeadModel(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index af83241412727..889cc3f0b5fcb 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -284,11 +285,18 @@ class YiForCausalLM(nn.Module): kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> SamplerOutput: + ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: next_tokens = self.sampler(self.lm_head.weight, hidden_states, - input_metadata) + sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py new file mode 100644 index 0000000000000..deb779f537c69 --- /dev/null +++ b/vllm/model_executor/sampling_metadata.py @@ -0,0 +1,43 @@ +from typing import Dict, List, Tuple + +import torch + +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SequenceData + + +class SamplingMetadata: + """Metadata for input sequences. Used in sampler. + + Args: + seq_groups: List of (seq_ids, sampling_params). + seq_data: Seq_id -> SequenceData. + prompt_lens: Lengths of prompts. + selected_token_indices: Token indices selected for sampling. + categorized_sample_indices: SamplingType -> token indicies to sample. 
+ """ + + def __init__( + self, + seq_groups: List[Tuple[List[int], SamplingParams]], + seq_data: Dict[int, SequenceData], + prompt_lens: List[int], + selected_token_indices: torch.Tensor, + categorized_sample_indices: Dict[SamplingType, torch.Tensor], + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + self.selected_token_indices = selected_token_indices + self.categorized_sample_indices = categorized_sample_indices + + self.num_prompts = len(prompt_lens) + + def __repr__(self) -> str: + return ( + "SamplingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens}, " + f"selected_token_indices={self.selected_token_indices}, " + f"categorized_sample_indices={self.categorized_sample_indices})") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py new file mode 100644 index 0000000000000..e0e381b369e4b --- /dev/null +++ b/vllm/worker/model_runner.py @@ -0,0 +1,334 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.model_executor import get_model, InputMetadata, SamplingMetadata +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + + self.sliding_window = model_config.get_sliding_window() + self.model = None + self.block_size = None # Set after initial profiling. + + def load_model(self) -> None: + self.model = get_model(self.model_config) + + def set_block_size(self, block_size: int) -> None: + self.block_size = block_size + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + + prompt_lens: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append(list(range(prompt_len))) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). 
+ # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + for i in range(prompt_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_prompt_len = max(prompt_lens) + input_tokens = _make_tensor_with_pad(input_tokens, + max_prompt_len, + pad=0, + dtype=torch.long) + input_positions = _make_tensor_with_pad(input_positions, + max_prompt_len, + pad=0, + dtype=torch.long) + slot_mapping = _make_tensor_with_pad(slot_mapping, + max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long) + + input_metadata = InputMetadata( + prompt_lens=prompt_lens, + slot_mapping=slot_mapping, + max_context_len=None, + context_lens=None, + block_tables=None, + ) + return input_tokens, input_positions, input_metadata + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + context_len = seq_data.get_len() + if self.sliding_window is not None: + context_len = min(context_len, self.sliding_window) + context_lens.append(context_len) + + position = context_len - 1 + input_positions.append([position]) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + input_tokens = _make_tensor_with_pad(input_tokens, + max_len=1, + pad=0, + dtype=torch.long) + input_positions = _make_tensor_with_pad(input_positions, + max_len=1, + pad=0, + dtype=torch.long) + slot_mapping = _make_tensor_with_pad(slot_mapping, + max_len=1, + pad=_PAD_SLOT_ID, + dtype=torch.long) + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + max_block_table_len = max([len(t) for t in block_tables]) + block_tables = _make_tensor_with_pad(block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int) + + input_metadata = InputMetadata( + prompt_lens=[], + slot_mapping=slot_mapping, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + ) + return input_tokens, input_positions, input_metadata + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + 
selected_token_indices: List[int] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + + max_prompt_len = max(prompt_lens) if prompt_lens else 1 + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + prompt_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += prompt_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append( + categorized_sample_indices_start_idx) + categorized_sample_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + prompt_len - 1)) + selected_token_indices.append(selected_token_start_idx + + prompt_len - 1) + selected_token_start_idx += max_prompt_len + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + range(categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices_start_idx += num_seqs + + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long, + device="cuda") + categorized_sample_indices = { + t: torch.tensor(seq_ids, dtype=torch.int, device="cuda") + for t, seq_ids in categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + ) + return sampling_metadata + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + cache_events: Optional[List[torch.cuda.Event]] = None, + ) -> SamplerOutput: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + # Prepare input tensors. + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + input_tokens, input_positions, input_metadata = inputs + else: + inputs = self._prepare_decode(seq_group_metadata_list) + input_tokens, input_positions, input_metadata = inputs + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + input_metadata.prompt_lens) + + # Execute the model. + hidden_states = self.model( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + cache_events=cache_events, + ) + + # Sample the next token. + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + return output + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. 
+ vocab_size = self.model_config.get_vocab_size() + sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + seq_data = SequenceData([0] * seq_len) + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [(None, None)] * num_layers + self.execute_model(seqs, kv_caches) + return + + +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + return x + [pad] * (max_len - len(x)) + + +def _make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, +) -> torch.Tensor: + padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] + return torch.tensor(padded_x, dtype=dtype, device="cuda") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 702767ebd8d09..6f5e16f0011f6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,18 +1,18 @@ """A GPU worker class.""" import os -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Optional, Tuple import torch import torch.distributed from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.model_executor import get_model, InputMetadata, set_random_seed +from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine +from vllm.worker.model_runner import ModelRunner from vllm.utils import get_gpu_memory @@ -38,11 +38,11 @@ class Worker: self.rank = rank self.distributed_init_method = distributed_init_method + self.model_runner = ModelRunner(model_config, parallel_config, + scheduler_config) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None - self.block_size = None - self.sliding_window = None self.cache_engine = None self.cache_events = None self.gpu_cache = None @@ -69,7 +69,7 @@ class Worker: set_random_seed(self.model_config.seed) def load_model(self): - self.model = get_model(self.model_config) + self.model_runner.load_model() @torch.inference_mode() def profile_num_available_blocks( @@ -83,40 +83,9 @@ class Worker: torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() - # Profile memory usage with max_num_sequences sequences and the total - # number of tokens equal to max_num_batched_tokens. - - # Enable top-k sampling to reflect the accurate memory usage. 
-        vocab_size = self.model.config.vocab_size
-        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        seqs = []
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            seq_data = SequenceData([0] * seq_len)
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-            )
-            seqs.append(seq)
-
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(
-            seqs)
-
-        # Execute the model.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=[(None, None)] * num_layers,
-            input_metadata=input_metadata,
-            cache_events=None,
-        )
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -140,197 +109,11 @@ class Worker:

     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
-        self.block_size = cache_config.block_size
-        self.sliding_window = cache_config.sliding_window
-
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
-
-    def _prepare_inputs(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
-        seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
-        slot_mapping: List[List[int]] = []
-        selected_token_indices: List[int] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
-        categorized_sample_indices_start_idx = 0
-
-        # Add prompt tokens.
-        prompt_lens: List[int] = []
-        for seq_group_metadata in seq_group_metadata_list:
-            if not seq_group_metadata.is_prompt:
-                continue
-
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            # Use any sequence in the group.
-            seq_id = seq_ids[0]
-
-            seq_data = seq_group_metadata.seq_data[seq_id]
-            prompt_tokens = seq_data.get_token_ids()
-            prompt_len = len(prompt_tokens)
-            prompt_lens.append(prompt_len)
-
-            if sampling_params.prompt_logprobs is not None:
-                # NOTE: prompt token positions do not need sample, skip
-                categorized_sample_indices_start_idx += prompt_len - 1
-
-            categorized_sample_indices[sampling_params.sampling_type].append(
-                categorized_sample_indices_start_idx)
-            categorized_sample_indices_start_idx += 1
-
-            input_tokens.append(prompt_tokens)
-            # NOTE(woosuk): Here we assume that the first token in the prompt
-            # is always the first token in the sequence.
-            input_positions.append(list(range(prompt_len)))
-
-            if seq_group_metadata.block_tables is None:
-                # During memory profiling, the block tables are not initialized
-                # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.append([0] * prompt_len)
-                continue
-
-            # Compute the slot mapping.
-            slot_mapping.append([])
-            block_table = seq_group_metadata.block_tables[seq_id]
-            for i in range(prompt_len):
-                block_number = block_table[i // self.block_size]
-                block_offset = i % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
-
-        # Add generation tokens.
-        max_context_len = 0
-        max_num_blocks_per_seq = 0
-        context_lens: List[int] = []
-        generation_block_tables: List[List[int]] = []
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            if seq_group_metadata.is_prompt:
-                # We need to do this in this loop as we need to know max_seq_len
-                assert len(
-                    seq_ids) == 1, "Prompt input should have only one seq."
-                sampling_params = seq_group_metadata.sampling_params
-                assert len(prompt_lens) == len(seq_group_metadata_list)
-                prompt_len = prompt_lens[i]
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_seq_len
-                continue
-
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(selected_token_start_idx,
-                      selected_token_start_idx + num_seqs))
-            selected_token_start_idx += num_seqs
-
-            categorized_sample_indices[sampling_params.sampling_type].extend(
-                range(categorized_sample_indices_start_idx,
-                      categorized_sample_indices_start_idx + num_seqs))
-            categorized_sample_indices_start_idx += num_seqs
-
-            for seq_id in seq_ids:
-                seq_data = seq_group_metadata.seq_data[seq_id]
-                generation_token = seq_data.get_last_token_id()
-                input_tokens.append([generation_token])
-
-                context_len = seq_data.get_len()
-                position = context_len - 1
-                if self.sliding_window is not None:
-                    context_len = min(context_len, self.sliding_window)
-                input_positions.append([position])
-
-                block_table = seq_group_metadata.block_tables[seq_id]
-
-                max_context_len = max(max_context_len, context_len)
-                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
-                                             len(block_table))
-                context_lens.append(context_len)
-
-                block_number = block_table[position // self.block_size]
-                block_offset = position % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping.append([slot])
-
-                if self.sliding_window is not None:
-                    sliding_window_blocks = (self.sliding_window //
-                                             self.block_size)
-                    block_table = block_table[-sliding_window_blocks:]
-                generation_block_tables.append(block_table)
-
-        padded_input_tokens = [
-            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
-        ]
-        padded_input_positions = [
-            _pad_to_max(positions, max_seq_len, pad=0)
-            for positions in input_positions
-        ]
-        padded_slot_mapping = [
-            _pad_to_max(mapping, max_seq_len, pad=-1)
-            for mapping in slot_mapping
-        ]
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
-            for block_table in generation_block_tables
-        ]
-
-        # Convert to tensors.
-        tokens_tensor = torch.tensor(padded_input_tokens,
-                                     dtype=torch.long,
-                                     device="cuda")
-        positions_tensor = torch.tensor(padded_input_positions,
-                                        dtype=torch.long,
-                                        device="cuda")
-        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
-                                           dtype=torch.long,
-                                           device="cuda")
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device="cuda")
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
-        categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-        block_tables_tensor = torch.tensor(padded_block_tables,
-                                           dtype=torch.int,
-                                           device="cuda")
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        input_metadata = InputMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-            slot_mapping=slot_mapping_tensor,
-            context_lens=context_lens_tensor,
-            max_context_len=max_context_len,
-            block_tables=block_tables_tensor,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            sliding_window=self.sliding_window,
-        )
-        return tokens_tensor, positions_tensor, input_metadata
+        self.model_runner.set_block_size(self.cache_engine.block_size)

     @torch.inference_mode()
     def execute_model(
@@ -361,18 +144,8 @@ class Worker:
                     event.wait()
             return {}

-        # Prepare input tensors.
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(
-            seq_group_metadata_list)
-
-        # Execute the model.
-        output = self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=self.gpu_cache,
-            input_metadata=input_metadata,
-            cache_events=cache_events,
-        )
+        output = self.model_runner.execute_model(seq_group_metadata_list,
+                                                 self.gpu_cache, cache_events)
         return output

@@ -407,14 +180,6 @@ def _init_distributed_environment(
         parallel_config.pipeline_parallel_size)


-def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
-    return x + [pad] * ((-len(x)) % multiple_of)
-
-
-def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
-    return x + [pad] * (max_len - len(x))
-
-
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
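Note on the padding helpers moved into vllm/worker/model_runner.py: the sketch below is a minimal, self-contained illustration of their behavior, not part of the diff. It swaps device="cuda" for device="cpu" purely so it runs without a GPU, and the sample slot_mapping values are invented for illustration; pad=-1 mirrors the value the old Worker._prepare_inputs used for padded slot-mapping entries.

    from typing import List

    import torch


    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        # Right-pad a single sequence with `pad` until it reaches max_len.
        assert len(x) <= max_len
        return x + [pad] * (max_len - len(x))


    def _make_tensor_with_pad(
        x: List[List[int]],
        max_len: int,
        pad: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # Pad every row to the same length and stack into a 2-D tensor.
        # The diff places the tensor on "cuda"; "cpu" is used here only so
        # the sketch is runnable anywhere.
        padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
        return torch.tensor(padded_x, dtype=dtype, device="cpu")


    # Example: two slot mappings of different lengths batched together.
    slot_mapping = [[5, 6, 7], [12]]
    print(_make_tensor_with_pad(slot_mapping, max_len=3, pad=-1,
                                dtype=torch.long))
    # tensor([[ 5,  6,  7],
    #         [12, -1, -1]])

The new _prepare_prompt/_prepare_decode paths (whose bodies are outside this hunk) are presumably the callers that use these helpers to turn variable-length token, position, and slot-mapping lists into rectangular tensors.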