Fix config for Falcon (#1164)
commit 9f6be8692e (parent f187877945)
@@ -135,7 +135,8 @@ class ModelConfig:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
-    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU worker."""
         # For GPTBigCode & Falcon:
         # Note: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -147,11 +148,15 @@ class ModelConfig:
         if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                    "multi_query", False):
             # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
             return 1
         # For Falcon:
         if getattr(self.hf_config, "n_head_kv", None) is not None:
             return (self.hf_config.n_head_kv //
                     parallel_config.tensor_parallel_size)
+        if getattr(self.hf_config, "num_kv_heads", None) is not None:
+            return (self.hf_config.num_kv_heads //
+                    parallel_config.tensor_parallel_size)
         # For LLaMA-2:
         if getattr(self.hf_config, "num_key_value_heads", None) is not None:
             return (self.hf_config.num_key_value_heads //
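For reference, here is a minimal standalone sketch of the KV-head resolution that the renamed get_num_kv_heads performs, covering the config attributes visible in the hunks above (multi_query, n_head_kv, num_kv_heads, num_key_value_heads). The kv_heads_per_worker function, the SimpleNamespace config, and the new_decoder_architecture lookup are illustrative assumptions, not vLLM's actual API.

# Minimal sketch (not the real ModelConfig): resolve the per-worker KV head
# count from the HF config attributes handled in the diff above.
from types import SimpleNamespace


def kv_heads_per_worker(hf_config, tensor_parallel_size: int = 1) -> int:
    # Assumed approximation: Falcon's new decoder architecture ignores
    # multi_query and reports its KV head count explicitly.
    new_decoder_arch_falcon = getattr(hf_config, "new_decoder_architecture", False)
    if not new_decoder_arch_falcon and getattr(hf_config, "multi_query", False):
        # Multi-query attention: a single KV head shared by all query heads.
        return 1
    if getattr(hf_config, "n_head_kv", None) is not None:  # older Falcon configs
        return hf_config.n_head_kv // tensor_parallel_size
    if getattr(hf_config, "num_kv_heads", None) is not None:  # newer Falcon configs
        return hf_config.num_kv_heads // tensor_parallel_size
    if getattr(hf_config, "num_key_value_heads", None) is not None:  # LLaMA-2 style
        return hf_config.num_key_value_heads // tensor_parallel_size
    # Fall back to standard multi-head attention: one KV head per query head.
    return hf_config.num_attention_heads // tensor_parallel_size


# Example with an illustrative Falcon-style config exposing num_kv_heads.
falcon_cfg = SimpleNamespace(new_decoder_architecture=True, num_kv_heads=8,
                             num_attention_heads=128)
print(kv_heads_per_worker(falcon_cfg, tensor_parallel_size=2))  # -> 4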
@@ -33,7 +33,7 @@ class CacheEngine:
 
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.num_heads = model_config.get_num_heads(parallel_config)
+        self.num_heads = model_config.get_num_kv_heads(parallel_config)
         self.dtype = model_config.dtype
 
         self.block_size = cache_config.block_size
@@ -146,7 +146,7 @@ class CacheEngine:
         parallel_config: ParallelConfig,
     ) -> int:
         head_size = model_config.get_head_size()
-        num_heads = model_config.get_num_heads(parallel_config)
+        num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
 
         key_cache_block = block_size * num_heads * head_size
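To see what the renamed call feeds into, below is a rough sketch of per-block KV cache sizing built around the key_cache_block line shown in the last hunk. The value-cache term, the per-layer accumulation, and the dtype_bytes scaling are assumptions for illustration, not the exact CacheEngine implementation.

# Illustrative sketch of KV cache block sizing (not CacheEngine itself).
def cache_block_bytes(block_size: int, num_kv_heads: int, head_size: int,
                      num_layers: int, dtype_bytes: int = 2) -> int:
    # As in the hunk above: one key-cache block stores block_size tokens,
    # each holding num_kv_heads * head_size elements.
    key_cache_block = block_size * num_kv_heads * head_size
    # Assumption: the value cache mirrors the key cache, and the total is
    # accumulated over all layers and scaled by the element size in bytes.
    value_cache_block = key_cache_block
    return dtype_bytes * num_layers * (key_cache_block + value_cache_block)


# Example: 16-token blocks, 8 KV heads, head size 64, 32 layers, fp16.
print(cache_block_bytes(16, 8, 64, 32, dtype_bytes=2))  # -> 1048576 bytes (1 MiB)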