support medusa

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
Xingyu Liu 2025-12-09 20:47:53 -08:00
parent f72949b288
commit aab35fc31c
2 changed files with 12 additions and 0 deletions

View File

@@ -401,6 +401,9 @@ class SpeculativeConfig:
model_type="eagle",
)
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"

View File

@@ -276,6 +276,14 @@ class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase):
return 0
class MedusaModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Model-arch-config convertor for the ``medusa`` model type.

    Both getters overridden here report a constant zero.
    NOTE(review): presumably because Medusa draft heads have no attention
    layers (and hence no KV cache) of their own -- confirm with the model
    implementation.
    """

    def get_head_size(self) -> int:
        """Return the attention head size, which is always 0 for Medusa."""
        head_size = 0
        return head_size

    def get_total_num_kv_heads(self) -> int:
        """Return the total number of KV heads, which is always 0 for Medusa."""
        total_kv_heads = 0
        return total_kv_heads
class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return getattr(self.hf_text_config, "attention_head_dim", 0)
@@ -367,6 +375,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"mamba": MambaModelArchConfigConvertor,
"falcon_mamba": MambaModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor,
"medusa": MedusaModelArchConfigConvertor,
"zamba2": Zamba2ModelArchConfigConvertor,
"mpt": MPTModelArchConfigConvertor,
"dbrx": DbrxModelArchConfigConvertor,