diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 35f2fd0ba9e22..c8c373b8adc9b 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -138,9 +138,11 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
         super().__init__(cpu_group)
         self.handle_cache = Cache()
 
-        # This is the DeepEP default. Stick to it till we can establish
-        # reasonable defaults based on profiling.
-        self.num_sms = 20
+        # Use all SMs for all2all communication
+        # This will need to be adjusted for dual-batch overlap
+        device = self.dp_group.device
+        props = torch.cuda.get_device_properties(device)
+        self.num_sms = props.multi_processor_count
 
     def get_handle(self, kwargs):
         raise NotImplementedError