mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 19:57:10 +08:00
add support for cutlass mla full cudagraphs
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
6d76bd034a
commit
090f485aa1
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -12,11 +12,22 @@ from vllm.attention.backends.abstract import (AttentionType,
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
MLACommonMetadata)
|
MLACommonMetadata,
|
||||||
|
MLACommonMetadataBuilder)
|
||||||
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||||
|
# enable full CUDA Graph support for decode-only capture
|
||||||
|
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
|
||||||
|
|
||||||
|
def can_run_in_cudagraph(
|
||||||
|
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||||
|
return common_attn_metadata.max_query_len == 1
|
||||||
|
|
||||||
|
|
||||||
class CutlassMLABackend(MLACommonBackend):
|
class CutlassMLABackend(MLACommonBackend):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -27,6 +38,10 @@ class CutlassMLABackend(MLACommonBackend):
|
|||||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||||
return CutlassMLAImpl
|
return CutlassMLAImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||||
|
return CutlassMLAMetadataBuilder
|
||||||
|
|
||||||
|
|
||||||
class SM100Workspace:
|
class SM100Workspace:
|
||||||
|
|
||||||
|
|||||||
@ -1920,21 +1920,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
elif num_scheduled_tokens in self.cudagraphs \
|
elif num_scheduled_tokens in self.cudagraphs \
|
||||||
and not skip_cuda_graphs:
|
and not skip_cuda_graphs:
|
||||||
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
|
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
||||||
cudagraph_metadata.cudagraph.replay()
|
cudagraph_metadata.cudagraph.replay()
|
||||||
return cudagraph_metadata.outputs
|
return cudagraph_metadata.outputs
|
||||||
else:
|
else:
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
|
logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
|
||||||
return self._run_ubatches(ubatch_metadata, self.model)
|
return self._run_ubatches(ubatch_metadata, self.model)
|
||||||
# run normal batch
|
# run normal batch
|
||||||
else:
|
else:
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
self.model_inputs(slice(0, num_scheduled_tokens),
|
self.model_inputs(slice(0, num_scheduled_tokens),
|
||||||
scheduler_output, is_dummy_run)
|
scheduler_output, is_dummy_run)
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
|
logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
|
||||||
skip_cuda_graphs = self.parallel_config.enable_microbatching
|
skip_cuda_graphs = self.parallel_config.enable_microbatching
|
||||||
with set_forward_context(attn_metadata,
|
with set_forward_context(attn_metadata,
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user