[Platform] Move model arch check to platform (#11503)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2024-12-27 16:45:25 +08:00 committed by GitHub
parent 2339d59f92
commit 6c6f7fe8a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 37 deletions

View File

@ -187,31 +187,6 @@ _VLLM_MODELS = {
**_SPECULATIVE_DECODING_MODELS,
}
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
_ROCM_SWA_REASON,
"MistralForCausalLM":
_ROCM_SWA_REASON,
"MixtralForCausalLM":
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
@dataclass(frozen=True)
class _ModelInfo:
@ -297,17 +272,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
if current_platform.is_rocm():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
logger.warning(
"Model architecture '%s' is partially "
"supported by ROCm: %s", model_arch, msg)
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()
except Exception:

View File

@ -199,6 +199,18 @@ class Platform:
"""
pass
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
"""
Verify whether the current platform supports the specified model
architecture.
- This will raise an Error or Warning based on the model support on
the current platform.
- By default all models are considered supported.
"""
pass
@classmethod
def verify_quantization(cls, quant: str) -> None:
"""

View File

@ -1,6 +1,6 @@
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
@ -33,6 +33,31 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
" `spawn` instead.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
_ROCM_SWA_REASON,
"MistralForCausalLM":
_ROCM_SWA_REASON,
"MixtralForCausalLM":
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
@ -102,6 +127,18 @@ class RocmPlatform(Platform):
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(f"Model architecture '{model_arch}' is not "
"supported by ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
logger.warning(
"Model architecture '%s' is partially "
"supported by ROCm: %s", model_arch, msg)
@classmethod
def verify_quantization(cls, quant: str) -> None:
super().verify_quantization(quant)