From aab35fc31c6a803a589d39a284afb057b39f8c0a Mon Sep 17 00:00:00 2001
From: Xingyu Liu
Date: Tue, 9 Dec 2025 20:47:53 -0800
Subject: [PATCH] support medusa

Signed-off-by: Xingyu Liu

---
 vllm/config/speculative.py                             | 3 +++
 vllm/transformers_utils/model_arch_config_convertor.py | 9 +++++++++
 2 files changed, 12 insertions(+)

diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index bf533bf14e55c..ad4057de834fb 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -401,6 +401,9 @@ class SpeculativeConfig:
                 model_type="eagle",
             )
             self.draft_model_config.hf_config = eagle_config
+            self.draft_model_config.model_arch_config = (
+                self.draft_model_config.get_model_arch_config()
+            )
 
         if self.num_speculative_tokens is not None and hasattr(
             self.draft_model_config.hf_config, "num_lookahead_tokens"
diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py
index 40cf438f4a804..ed6ba0adb5e20 100644
--- a/vllm/transformers_utils/model_arch_config_convertor.py
+++ b/vllm/transformers_utils/model_arch_config_convertor.py
@@ -276,6 +276,14 @@ class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase):
         return 0
 
 
+class MedusaModelArchConfigConvertor(ModelArchConfigConvertorBase):
+    def get_head_size(self) -> int:
+        return 0
+
+    def get_total_num_kv_heads(self) -> int:
+        return 0
+
+
 class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
     def get_head_size(self) -> int:
         return getattr(self.hf_text_config, "attention_head_dim", 0)
@@ -367,6 +375,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
     "mamba": MambaModelArchConfigConvertor,
     "falcon_mamba": MambaModelArchConfigConvertor,
     "timm_wrapper": TerratorchModelArchConfigConvertor,
+    "medusa": MedusaModelArchConfigConvertor,
     "zamba2": Zamba2ModelArchConfigConvertor,
     "mpt": MPTModelArchConfigConvertor,
     "dbrx": DbrxModelArchConfigConvertor,