mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 12:37:13 +08:00
support medusa
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
parent
f72949b288
commit
aab35fc31c
@ -401,6 +401,9 @@ class SpeculativeConfig:
|
||||
model_type="eagle",
|
||||
)
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
|
||||
@ -276,6 +276,14 @@ class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase):
|
||||
return 0
|
||||
|
||||
|
||||
class MedusaModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Arch-config convertor used for the ``"medusa"`` model type.

    Both attention-geometry queries below report ``0``.
    """

    # NOTE(review): presumably the zeros reflect that Medusa draft heads
    # have no attention layers / KV cache of their own -- confirm against
    # the Medusa model implementation.

    def get_head_size(self) -> int:
        """Return the attention head size, which is always ``0`` here."""
        return 0

    def get_total_num_kv_heads(self) -> int:
        """Return the total number of KV heads, which is always ``0`` here."""
        return 0
|
||||
|
||||
|
||||
class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
|
||||
def get_head_size(self) -> int:
|
||||
return getattr(self.hf_text_config, "attention_head_dim", 0)
|
||||
@ -367,6 +375,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
|
||||
"mamba": MambaModelArchConfigConvertor,
|
||||
"falcon_mamba": MambaModelArchConfigConvertor,
|
||||
"timm_wrapper": TerratorchModelArchConfigConvertor,
|
||||
"medusa": MedusaModelArchConfigConvertor,
|
||||
"zamba2": Zamba2ModelArchConfigConvertor,
|
||||
"mpt": MPTModelArchConfigConvertor,
|
||||
"dbrx": DbrxModelArchConfigConvertor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user