mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-14 01:07:25 +08:00
[AMD][FP8][BugFix] Remove V1 check in arg_utils.py for FP8 since it is not necessary (#17215)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
parent
8de2901fea
commit
68af5f6c5c
@ -1368,23 +1368,6 @@ class EngineArgs:
|
|||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
||||||
load_config = self.create_load_config()
|
|
||||||
quantization_config = VllmConfig.get_quantization_config(
|
|
||||||
model_config, load_config)
|
|
||||||
if isinstance(quantization_config, Fp8Config):
|
|
||||||
_raise_or_fallback(feature_name="fp8 for ROCm",
|
|
||||||
recommend_to_remove=False)
|
|
||||||
return False
|
|
||||||
from vllm.model_executor.layers.quantization.quark.quark import (
|
|
||||||
QuarkConfig)
|
|
||||||
|
|
||||||
if isinstance(quantization_config, QuarkConfig
|
|
||||||
) and quantization_config.has_fp8_layer_weights():
|
|
||||||
_raise_or_fallback(feature_name="Quark fp8 for ROCm",
|
|
||||||
recommend_to_remove=False)
|
|
||||||
|
|
||||||
# No Fp8 KV cache so far.
|
# No Fp8 KV cache so far.
|
||||||
if self.kv_cache_dtype != "auto":
|
if self.kv_cache_dtype != "auto":
|
||||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||||
|
|||||||
@ -307,18 +307,6 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
# If no matches, return None
|
# If no matches, return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def has_fp8_layer_weights(self):
|
|
||||||
layer_quant_config = self.quant_config.get("layer_quant_config")
|
|
||||||
to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
|
|
||||||
return any([
|
|
||||||
'fp8' in cast(
|
|
||||||
str,
|
|
||||||
to_dict(
|
|
||||||
to_dict(to_dict(layer_quant_config).get(layer_name)).get(
|
|
||||||
"weight")).get("dtype"))
|
|
||||||
for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
class QuarkLinearMethod(LinearMethodBase):
|
class QuarkLinearMethod(LinearMethodBase):
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user