[Distributed] [ROCM] Fix custom allreduce enable checks (#16010)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
parent 2386803f2a
commit ef608c37a7
@@ -1619,13 +1619,12 @@ class ParallelConfig:
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
 
-        device_capability = current_platform.get_device_capability()
-        if (current_platform.is_rocm() and device_capability is not None
-                and device_capability < (9, 4)):
+        if not current_platform.use_custom_allreduce():
             self.disable_custom_all_reduce = True
             logger.info(
                 "Disabled the custom all-reduce kernel because it is not "
-                "supported on AMD GPUs older than MI300X.")
+                "supported on current platform.")
 
         if self.ray_workers_use_nsight and not self.use_ray:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
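For context, this hunk swaps the inline ROCm device-capability comparison (capability below (9, 4)) for a single call to a platform hook. Below is a minimal, self-contained sketch of the resulting control flow; it uses mock stand-ins rather than vLLM's real ParallelConfig and current_platform.

```python
# Minimal, self-contained sketch (not vLLM's real classes) of the gating
# this hunk introduces: the config asks the platform whether custom
# allreduce is supported and disables it otherwise.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("allreduce-sketch")


class MockPlatform:
    """Stand-in for vllm.platforms.current_platform."""

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # Pretend we are on a platform without custom allreduce support.
        return False


class MockParallelConfig:
    """Stand-in for the relevant part of ParallelConfig."""

    def __init__(self) -> None:
        self.disable_custom_all_reduce = False

    def check_custom_allreduce(self, platform) -> None:
        # Mirrors the replacement logic in the hunk above.
        if not platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.info(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform.")


config = MockParallelConfig()
config.check_custom_allreduce(MockPlatform)
assert config.disable_custom_all_reduce is True
```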
@@ -308,6 +308,10 @@ class CudaPlatformBase(Platform):
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         return True
 
+    @classmethod
+    def use_custom_allreduce(cls) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -379,6 +379,13 @@ class Platform:
         """
         return False
 
+    @classmethod
+    def use_custom_allreduce(cls) -> bool:
+        """
+        Returns if custom allreduce is supported on the current platform
+        """
+        return False
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
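The base-class default is deliberately conservative: any platform that does not override use_custom_allreduce() keeps custom allreduce disabled. A hypothetical sketch of that opt-in pattern follows; the class names are illustrative, not vLLM's real platform classes.

```python
# Hypothetical sketch of the opt-in pattern created by the base-class
# default. BasePlatform, TpuLikePlatform and CudaLikePlatform are
# illustrative names, not vLLM's real platform classes.


class BasePlatform:
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # Conservative default: custom allreduce stays off unless a
        # platform explicitly opts in.
        return False


class TpuLikePlatform(BasePlatform):
    # No override, so the inherited default (False) applies.
    pass


class CudaLikePlatform(BasePlatform):
    @classmethod
    def use_custom_allreduce(cls) -> bool:
        # Opts in, analogous to the CudaPlatformBase override above.
        return True


print(TpuLikePlatform.use_custom_allreduce())   # False
print(CudaLikePlatform.use_custom_allreduce())  # True
```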
@@ -302,3 +302,10 @@ class RocmPlatform(Platform):
     def supports_v1(cls, model_config: ModelConfig) -> bool:
         # V1 support on AMD gpus is experimental
         return True
+
+    @classmethod
+    def use_custom_allreduce(cls) -> bool:
+        # We only enable custom allreduce for MI300 series
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        supported_archs = ['gfx94']
+        return any(gfx in gcn_arch for gfx in supported_archs)
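The ROCm override keys off the gcnArchName string reported by the device and enables custom allreduce only when that name contains 'gfx94', the MI300-series architecture prefix. A substring check is used because ROCm appends feature flags to the architecture name. Below is a pure-Python sketch of the check using example architecture strings so it runs without a GPU; the sample gcnArchName values are illustrative.

```python
# Pure-Python sketch of the substring check in RocmPlatform.use_custom_allreduce,
# using example gcnArchName strings instead of a live GPU query. On a ROCm
# build of PyTorch the real value comes from
# torch.cuda.get_device_properties(0).gcnArchName.

supported_archs = ['gfx94']  # MI300-series architecture prefix

# Example arch names in the form ROCm typically reports them
# (feature suffixes such as ":sramecc+:xnack-" vary by device).
examples = {
    "gfx942:sramecc+:xnack-": True,   # MI300-series -> custom allreduce on
    "gfx90a:sramecc+:xnack-": False,  # MI200-series -> custom allreduce off
    "gfx908": False,                  # MI100 -> custom allreduce off
}

for gcn_arch, expected in examples.items():
    enabled = any(gfx in gcn_arch for gfx in supported_archs)
    assert enabled == expected
    print(f"{gcn_arch!r}: use_custom_allreduce -> {enabled}")
```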