diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bb44c5ad84cc1..9f3c34b15e2a8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3460,6 +3460,10 @@ class GPUModelRunner( scope="local", ) prepare_communication_buffer_for_model(self.model) + if (drafter := getattr(self, "drafter", None)) and ( + drafter_model := getattr(drafter, "model", None) + ): + prepare_communication_buffer_for_model(drafter_model) mm_config = self.model_config.multimodal_config self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model())