Fix config for Falcon (#1164)

This commit is contained in:
Woosuk Kwon 2023-09-23 17:38:43 -07:00 committed by GitHub
parent f187877945
commit 9f6be8692e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@ -135,7 +135,8 @@ class ModelConfig:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU worker."""
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
@ -147,11 +148,15 @@ class ModelConfig:
if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
if getattr(self.hf_config, "num_kv_heads", None) is not None:
return (self.hf_config.num_kv_heads //
parallel_config.tensor_parallel_size)
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //

View File

@ -33,7 +33,7 @@ class CacheEngine:
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_heads(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.dtype = model_config.dtype
self.block_size = cache_config.block_size
@ -146,7 +146,7 @@ class CacheEngine:
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_heads(parallel_config)
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size