diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 66674e3ebe91f..0133f26a0b1bc 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -93,6 +93,8 @@ class RocmPlatform(Platform): elif vllm_config.speculative_config: parallel_config.worker_cls = \ "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker"