diff --git a/vllm/config.py b/vllm/config.py
index 508e09174cc8..1e9d119ebf8e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4722,7 +4722,6 @@ class VllmConfig:
         # calculate the default `batch_size_capture_list`
         if not envs.VLLM_USE_V1:
             batch_size_capture_list = []
-            max_batchsize_to_capture = 0
             if self.scheduler_config is not None and \
                 self.model_config is not None and \
                     not self.model_config.enforce_eager:
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index 913cb0895bb9..91f7bdb731b1 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -96,7 +96,8 @@ class CpuPlatform(Platform):
         from vllm.utils import GiB_bytes
 
         model_config = vllm_config.model_config
-        model_config.disable_cascade_attn = True
+        if model_config is not None:
+            model_config.disable_cascade_attn = True
 
         cache_config = vllm_config.cache_config
 
@@ -123,7 +124,7 @@ class CpuPlatform(Platform):
                 "CPU backend doesn't support fp8_e4m3 KV cache type, "
                 "cast to fp8_e5m2.")
 
-        if (cache_config.cache_dtype != "auto"
+        if (cache_config.cache_dtype != "auto" and model_config is not None
                 and model_config.dtype == torch.half):
             logger.warning("FP8 KV cache on the CPU backend only does not"
                            " support fp16 for now, cast to bf16.")
@@ -229,7 +230,7 @@ class CpuPlatform(Platform):
             os.environ["LOCAL_WORLD_SIZE"] = str(
                 vllm_config.parallel_config.tensor_parallel_size)
 
-        if vllm_config.model_config and vllm_config.model_config.use_mla:
+        if model_config is not None and model_config.use_mla:
             logger.info(
                 "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index b53d7e71a03e..35a2b48c7d01 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -166,17 +166,19 @@ class CudaPlatformBase(Platform):
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")
 
+        compilation_config = vllm_config.compilation_config
         if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                 and parallel_config.data_parallel_size > 1
-                and vllm_config.compilation_config.use_cudagraph):
+                and compilation_config.use_cudagraph):
             logger.info(
                 "Data Parallel: Forcing enforce eager to be True since DP "
                 "with DeepEP high-throughput kernels are not CUDA Graph "
                 "compatible. The DeepEP low-latency kernels are CUDA Graph "
                 "compatible. Set the all_to_all backend to deepep_low_latency "
                 "to use those kernels instead.")
-            vllm_config.compilation_config.use_cudagraph = False
-            vllm_config.model_config.enforce_eager = True
+            compilation_config.use_cudagraph = False
+            if model_config is not None:
+                model_config.enforce_eager = True
             # TODO (varun): Turning this ON gives incorrect results for the
             # Deepseek-V2-lite model.
             vllm_config.compilation_config.use_inductor = False
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 10a7f7c60ee2..5ec3be908e7d 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -116,11 +116,13 @@ class TpuPlatform(Platform):
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"
 
-        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
+        model_config = vllm_config.model_config
+        if model_config is not None and model_config.dtype in (torch.float16,
+                                                               torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
-                "Using bfloat16 instead.", vllm_config.model_config.dtype)
-            vllm_config.model_config.dtype = torch.bfloat16
+                "Using bfloat16 instead.", model_config.dtype)
+            model_config.dtype = torch.bfloat16
 
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
         cache_config.block_size = PallasAttentionBackend.get_page_size(
@@ -146,7 +148,7 @@ class TpuPlatform(Platform):
                 "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True
 
-        if vllm_config.model_config and vllm_config.model_config.use_mla:
+        if model_config and model_config.use_mla:
             logger.info(
                 "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 3196f3059e19..c4530c1dfaa3 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -85,14 +85,14 @@ class XPUPlatform(Platform):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
         # in V1(or with ipex chunked prefill) block_size is 64
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64
 
         # FIXME: Temporarily forcing eager mode
         # remove after t.compile support stabilizes.
-
-        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
+        if (envs.VLLM_USE_V1 and model_config is not None
                 and not vllm_config.model_config.enforce_eager):
             from vllm.config import CompilationLevel
             vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
@@ -100,8 +100,7 @@ class XPUPlatform(Platform):
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
        # potential null exceptions check and update model config.
-        if vllm_config.model_config is not None:
-            model_config = vllm_config.model_config
+        if model_config is not None:
             if model_config.dtype == torch.bfloat16:
                 bf16_supported = cls.device_support_bf16()
                 if not bf16_supported:
@@ -139,7 +138,7 @@ class XPUPlatform(Platform):
                     parallel_config.distributed_executor_backend)
                 parallel_config.distributed_executor_backend = "ray"
 
-        if vllm_config.model_config and vllm_config.model_config.use_mla:
+        if model_config and model_config.use_mla:
             logger.info(
                 "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")