diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 1b6bdabc7a539..1d84b6082fbd4 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -1,7 +1,7 @@
 # default base image
 ARG REMOTE_VLLM="0"
 ARG COMMON_WORKDIR=/app
-ARG BASE_IMAGE=rocm/vllm-dev:base
+ARG BASE_IMAGE=rocm/vllm-dev:base_triton_test_dockerfile_update_20251219_tuned_20251219
 
 FROM ${BASE_IMAGE} AS base
 
@@ -130,6 +130,7 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
     && uv pip install --system *.whl
 
 ARG COMMON_WORKDIR
+ARG BASE_IMAGE
 
 # Copy over the benchmark scripts as well
 COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
@@ -144,4 +145,9 @@ ENV SAFETENSORS_FAST_GPU=1
 # Performance environment variable.
 ENV HIP_FORCE_DEV_KERNARG=1
 
+# Workaround for ROCm profiler limits
+RUN echo "ROCTRACER_MAX_EVENTS=10000000" > /app/libkineto.conf
+ENV KINETO_CONFIG="/app/libkineto.conf"
+RUN echo "VLLM_BASE_IMAGE=${BASE_IMAGE}" >> /app/versions.txt
+
 CMD ["/bin/bash"]
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index a3e5d02abdc1d..86a0669ed81f7 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -117,7 +117,7 @@ class PassConfig:
     """Fuse the custom SiluMul + quant ops."""
     fuse_attn_quant: bool = Field(default=None)
    """Fuse the custom attention + quant ops."""
-    eliminate_noops: bool = Field(default=None)
+    eliminate_noops: bool = Field(default=True)
     """Eliminate no-op ops."""
     enable_sp: bool = Field(default=None)
     """Enable sequence parallelism."""
@@ -1000,6 +1000,14 @@ class CompilationConfig:
             op in self.splitting_ops for op in self._attention_ops
         )
 
+    def add_missing_attention_splitting_ops(self):
+        if self.splitting_ops is None:
+            self.splitting_ops = list(self._attention_ops)
+            return
+        for op in self._attention_ops:
+            if op not in self.splitting_ops:
+                self.splitting_ops.append(op)
+
     def is_attention_compiled_piecewise(self) -> bool:
         if not self.splitting_ops_contain_attention():
             return False
diff --git a/vllm/envs.py b/vllm/envs.py
index d0f2798096263..47cd5ebc6a85c 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
     VLLM_NCCL_SO_PATH: str | None = None
     LD_LIBRARY_PATH: str | None = None
     VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE: int = 256
-    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = True
     VLLM_FLASH_ATTN_VERSION: int | None = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: str | None = None
@@ -110,13 +110,13 @@ if TYPE_CHECKING:
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_DISABLE_PYNCCL: bool = False
-    VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER: bool = True
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
-    VLLM_ROCM_USE_AITER_MHA: bool = True
+    VLLM_ROCM_USE_AITER_MHA: bool = False
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
     VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
@@ -294,6 +294,15 @@ def use_aot_compile() -> bool:
     )
 
 
+def use_aiter() -> bool:
+    from vllm.platforms.rocm import on_mi3xx
+
+    return on_mi3xx() and os.environ.get("VLLM_ROCM_USE_AITER", "1").lower() in (
+        "1",
+        "true",
+    )
+
+
 def env_with_choices(
     env_name: str,
     default: str | None,
@@ -558,7 +567,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Use separate prefill and decode kernels for V1 attention instead of
     # the unified triton kernel.
     "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: (
-        os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
+        os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "True").lower()
         in ("true", "1")
     ),
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
@@ -926,9 +935,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Disable aiter ops unless specifically enabled.
     # Acts as a parent switch to enable the rest of the other operations.
-    "VLLM_ROCM_USE_AITER": lambda: (
-        os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
-    ),
+    "VLLM_ROCM_USE_AITER": use_aiter,
     # Whether to use aiter paged attention.
     # By default is disabled.
     "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: (
@@ -957,7 +964,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to use aiter mha ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MHA": lambda: (
-        os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")
+        os.getenv("VLLM_ROCM_USE_AITER_MHA", "False").lower() in ("true", "1")
     ),
     # Whether to use aiter fp4 gemm asm.
     # By default is disabled.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ca525326006ff..6ad17836470c4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4804,15 +4804,15 @@ class GPUModelRunner(
                     f"{min_cg_support})"
                 )
                 if min_cg_support == AttentionCGSupport.NEVER:
-                    # if not supported any full cudagraphs, just raise it.
-                    msg += (
-                        "; please try cudagraph_mode=PIECEWISE, and "
-                        "make sure compilation mode is VLLM_COMPILE"
+                    msg += "; setting cudagraph_mode=PIECEWISE"
+                    cudagraph_mode = self.compilation_config.cudagraph_mode = (
+                        CUDAGraphMode.PIECEWISE
                     )
-                    raise ValueError(msg)
+                    if not self.compilation_config.splitting_ops_contain_attention():
+                        msg += "; adding attention ops to splitting ops"
-
-                # attempt to resolve the full cudagraph related mode
-                if self.compilation_config.splitting_ops_contain_attention():
+                        self.compilation_config.add_missing_attention_splitting_ops()
+                elif self.compilation_config.splitting_ops_contain_attention():
                     msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
                     cudagraph_mode = self.compilation_config.cudagraph_mode = (
                         CUDAGraphMode.FULL_AND_PIECEWISE