diff --git a/README.md b/README.md
index 372984e03871e..ffa890bd30e31 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,7 @@ vLLM is flexible and easy to use with:
 
 vLLM seamlessly supports many Huggingface models, including the following architectures:
 
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
diff --git a/csrc/attention.cpp b/csrc/attention.cpp
index bb2766c1d6b67..b0ee4c906b83a 100644
--- a/csrc/attention.cpp
+++ b/csrc/attention.cpp
@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <c10/util/Optional.h>
 
 void single_query_cached_kv_attention(
   torch::Tensor& out,
@@ -9,7 +10,8 @@ void single_query_cached_kv_attention(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int block_size,
-  int max_context_len);
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index c975785b68774..203b8644a6626 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -80,6 +80,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,  // [num_seqs]
   const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes,  // [num_heads]
   const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
@@ -91,6 +92,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
   const int seq_idx = blockIdx.y;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
   // The vector size is configured in such a way that the threads in a thread group
@@ -167,12 +169,14 @@ __global__ void single_query_cached_kv_attention_kernel(
 
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
-      const bool mask = token_idx >= context_len;
-
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
+
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -328,6 +332,7 @@ __global__ void single_query_cached_kv_attention_kernel(
     block_tables_ptr, \
     context_lens_ptr, \
     max_num_blocks_per_seq, \
+    alibi_slopes_ptr, \
     query_stride);
 
 // TODO(woosuk): Tune NUM_THREADS.
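For readers of the kernel hunk above: with ALiBi, the bias added to a query-key logit is the head's slope times the signed distance from the key to the query token, so for single-query decoding it collapses to `alibi_slope * (token_idx - context_len)`. Using `context_len` rather than the query position `context_len - 1` only shifts every logit in the head by a constant, which the softmax cancels. A minimal Python sketch of the same arithmetic, for reference only (the function and argument names are illustrative, not part of the diff):

```python
def alibi_biased_logits(raw_logits, alibi_slope, context_len):
    """Reference for the decode-time ALiBi bias applied in the CUDA kernel.

    raw_logits[token_idx] is scale * dot(q, k[token_idx]) for the single
    query token at position context_len - 1 of one attention head.
    """
    biased = []
    for token_idx in range(context_len):
        qk = raw_logits[token_idx]
        # The canonical ALiBi bias is slope * (token_idx - (context_len - 1));
        # using (token_idx - context_len) instead shifts every logit in this
        # head by -slope, which softmax normalization cancels out.
        qk += alibi_slope * (token_idx - context_len)
        biased.append(qk)
    return biased
```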
@@ -343,7 +348,8 @@ void single_query_cached_kv_attention_launcher(
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
-  int max_context_len) {
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -353,6 +359,11 @@ void single_query_cached_kv_attention_launcher(
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
 
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
@@ -411,7 +422,8 @@ void single_query_cached_kv_attention_launcher(
     scale, \
     block_tables, \
     context_lens, \
-    max_context_len);
+    max_context_len, \
+    alibi_slopes);
 
 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
@@ -458,7 +470,8 @@ void single_query_cached_kv_attention(
   torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,  // [num_seqs]
   int block_size,
-  int max_context_len) {
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   if (query.dtype() == at::ScalarType::Float) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
   } else if (query.dtype() == at::ScalarType::Half) {
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index b8f7f9e000b10..e29f27fcd70fd 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -14,6 +14,9 @@ Alongside each architecture, we include some popular models that use it.
   * - Architecture
     - Models
     - Example HuggingFace Models
+  * - :code:`BloomForCausalLM`
+    - BLOOM, BLOOMZ, BLOOMChat
+    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 96316110d460d..6b88cd1096896 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -216,6 +216,7 @@ def run_single_query_cached_kv_attention(
         context_lens,
         block_size,
         max_context_len,
+        None,  # ALiBi slopes.
     )
 
     ref_output = torch.empty_like(query)
diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index 99c988730dbec..1b0bc7327f7a9 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, Tuple
 
 import torch
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+from xformers.ops import AttentionBias
 
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceData
@@ -38,7 +38,6 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables
 
-        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
@@ -50,6 +49,9 @@ class InputMetadata:
         assert block_tables.shape[0] == self.num_generation_tokens
         assert context_lens.shape[0] == self.num_generation_tokens
 
+        # Set during the execution of the first attention op.
+        self.attn_bias: List[AttentionBias] = []
+
     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index 6528626b9b255..b350de8f59412 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -1,9 +1,11 @@
 """Multi-head attention."""
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)
 
 from vllm import attention_ops
 from vllm import cache_ops
@@ -53,13 +55,21 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        prompt_lens = input_metadata.prompt_lens
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        input_metadata.attn_bias.append(attn_bias)
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_bias: xops.AttentionBias,
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         """Normal attention for the prompt tokens.
 
@@ -68,13 +78,14 @@ class PagedAttention(nn.Module):
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_heads, head_size]
             value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
         """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=attn_bias,
+            attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
             op=self.attn_op,
@@ -112,6 +123,7 @@ class PagedAttention(nn.Module):
             input_metadata.context_lens,
             block_size,
             input_metadata.max_context_len,
+            None,  # alibi_slopes
         )
 
     def forward(
@@ -154,12 +166,13 @@ class PagedAttention(nn.Module):
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.attn_bias,
+                input_metadata,
             )
 
         # Wait until the cache op is done.
@@ -219,7 +232,8 @@ class PagedAttentionWithRoPE(PagedAttention):
         cache = torch.cat((cos, sin), dim=-1)
 
         # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model. Make it more robust.
+        # initializing the model.
+        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
@@ -271,3 +285,112 @@ class PagedAttentionWithRoPE(PagedAttention):
             input_metadata,
             cache_event,
         )
+
+
+class PagedAttentionWithALiBi(PagedAttention):
+    """PagedAttention with ALiBi attention bias."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        slopes: List[float],
+    ) -> None:
+        super().__init__(num_heads, head_size, scale)
+        assert len(slopes) == num_heads
+
+        slopes = torch.tensor(slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", slopes, persistent=False)
+
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        # Generates ALiBi mask for each prompt.
+        for prompt_len in input_metadata.prompt_lens:
+            bias = torch.arange(prompt_len)
+            bias = bias[None, :] - bias[:, None]
+            bias = bias.to(self.alibi_slopes.device)
+
+            # When using custom attention bias, xformers requires the bias to
+            # be sliced from a tensor whose length is a multiple of 8.
+            padded_len = (prompt_len + 7) // 8 * 8
+            bias = torch.empty(
+                self.num_heads,
+                padded_len,
+                padded_len,
+                device=self.alibi_slopes.device,
+            )[:, :prompt_len, :prompt_len].copy_(bias)
+            bias.mul_(self.alibi_slopes[:, None, None])
+            attn_bias = LowerTriangularMaskWithTensorBias(bias)
+            input_metadata.attn_bias.append(attn_bias)
+
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Attention with ALiBi bias for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
+        """
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        start = 0
+        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=input_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale,
+                op=self.attn_op,
+            )
+            # TODO(woosuk): Unnecessary copy. Optimize.
+            output[start:end].copy_(out.squeeze(0))
+            start += prompt_len
+        return output
+
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        """PagedAttention with ALiBi bias for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            self.alibi_slopes,
+        )
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 931fadc0027b1..f5e2793b33303 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -6,13 +6,12 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
-                                        GPTNeoXForCausalLM, LlamaForCausalLM,
-                                        OPTForCausalLM)
+from vllm.model_executor.models import *  # pylint: disable=wildcard-import
 from vllm.model_executor.weight_utils import initialize_dummy_weights
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
+    "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 023974614a4a3..251052a296201 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -1,3 +1,4 @@
+from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
@@ -5,6 +6,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
+    "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
new file mode 100644
index 0000000000000..12b17e4a4e1e6
--- /dev/null
+++ b/vllm/model_executor/models/bloom.py
@@ -0,0 +1,316 @@
+# coding=utf-8
+# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Copyright 2023 The CacheFlow team.
+# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only BLOOM model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+""" +import math +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import BloomConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) +from vllm.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BloomAttention(nn.Module): + + def __init__(self, config: BloomConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.n_head + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + self.query_key_value = ColumnParallelLinear( + self.hidden_size, + 3 * self.hidden_size, + bias=True, + gather_output=False, + perform_initialization=False, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + input_is_parallel=True, + perform_initialization=False, + ) + + # Create the alibi slopes and slice them. + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim, + scaling, alibi_slopes) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + del position_ids # Unused. 
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class BloomMLP(nn.Module):
+
+    def __init__(self, config: BloomConfig):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
+                                                  4 * hidden_size,
+                                                  gather_output=False,
+                                                  perform_initialization=False)
+        self.act = get_act_fn("gelu")
+        self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
+                                               hidden_size,
+                                               input_is_parallel=True,
+                                               perform_initialization=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.dense_h_to_4h(x)
+        x = self.act(x)
+        x, _ = self.dense_4h_to_h(x)
+        return x
+
+
+class BloomBlock(nn.Module):
+
+    def __init__(self, config: BloomConfig):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nn.LayerNorm(hidden_size,
+                                            eps=config.layer_norm_epsilon)
+        self.self_attention = BloomAttention(config)
+        self.post_attention_layernorm = nn.LayerNorm(
+            hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = BloomMLP(config)
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+
+        # Layer norm post the self attention.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        # Self attention.
+        attention_output = self.self_attention(
+            position_ids=position_ids,
+            hidden_states=layernorm_output,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        attention_output = attention_output + residual
+        layernorm_output = self.post_attention_layernorm(attention_output)
+
+        # Get residual
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = attention_output
+
+        # MLP.
+        output = self.mlp(layernorm_output) + residual
+        return output
+
+
+class BloomModel(nn.Module):
+
+    def __init__(self, config: BloomConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+
+        # Embedding + LN Embedding
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size, self.embed_dim, perform_initialization=False)
+        self.word_embeddings_layernorm = nn.LayerNorm(
+            self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Transformer blocks
+        self.h = nn.ModuleList(
+            [BloomBlock(config) for _ in range(config.num_hidden_layers)])
+
+        # Final Layer Norm
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        hidden_states = self.word_embeddings(input_ids)
+        hidden_states = self.word_embeddings_layernorm(hidden_states)
+        for i in range(len(self.h)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.h[i]
+            hidden_states = layer(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                cache_event,
+            )
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class BloomForCausalLM(nn.Module):
+
+    def __init__(self, config: BloomConfig):
+        super().__init__()
+        self.config = config
+        self.transformer = BloomModel(config)
+        # TODO(zhuohan): create a new weight after implementing pipeline
+        # parallelism
+        self.lm_head_weight = self.transformer.word_embeddings.weight
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = [
+        "word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
+    ]
+    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
+        tp_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if not name.startswith("transformer."):
+                name = "transformer." + name
+
+            param = state_dict[name]
+            if "query_key_value" in name:
+                # NOTE(woosuk): BLOOM's fused QKV has the shape of
+                # [num_heads * 3 * head_size, hidden_size], while the
+                # required shape is [3 * num_heads * head_size, hidden_size].
+                # Thus, we need weight conversion.
+                shard_size = param.shape[0]
+                start = shard_size * tp_rank
+                end = shard_size * (tp_rank + 1)
+                loaded_weight = loaded_weight[start:end]
+
+                num_heads = self.config.num_attention_heads
+                hidden_size = self.config.hidden_size
+                head_size = hidden_size // num_heads
+                if "query_key_value.weight" in name:
+                    loaded_weight = loaded_weight.view(-1, 3, head_size,
+                                                       hidden_size)
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
+                elif "query_key_value.bias" in name:
+                    loaded_weight = loaded_weight.view(-1, 3, head_size)
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1)
+                else:
+                    raise ValueError(f"Unexpected weight name: {name}")
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights, tp_rank)
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index e0001bf27d29e..de25029d9eec4 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -80,7 +80,6 @@ class GPTNeoXAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
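With the pieces above in place, a BLOOM checkpoint can be exercised end to end through vLLM's existing Python API. A minimal smoke-test sketch, assuming the `LLM`/`SamplingParams` interface already shipped by the package; the checkpoint choice and prompt are arbitrary examples, not part of this change:

```python
from vllm import LLM, SamplingParams

# Any BloomForCausalLM checkpoint should work; bloom-560m keeps the test small.
# The ALiBi slopes are computed internally by PagedAttentionWithALiBi, so no
# extra configuration is needed on the user side.
llm = LLM(model="bigscience/bloom-560m")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```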