support MLA

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-20 03:29:32 +00:00
parent 37c9babaa0
commit df8f889f37
4 changed files with 86 additions and 34 deletions

View File

@ -30,10 +30,15 @@ sampling_params = SamplingParams(**param_kwargs)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m",
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
enforce_eager=False,
compilation_config=2,
enable_microbatching=True,)
enable_microbatching=True,
trust_remote_code=True,
tensor_parallel_size=4,
max_model_len=1024,
#load_format="dummy",
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

View File

@ -740,7 +740,7 @@ class PiecewiseBackend:
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)

View File

@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_CUDAGRAPH_SANITIZER: bool = False
def get_default_cache_root():
@ -811,6 +812,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# all2all backend for vllm's expert parallel communication
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# check that the cudagraphs input addresses are correct before replaying
"VLLM_CUDAGRAPH_SANITIZER":
lambda: os.getenv("VLLM_CUDAGRAPH_SANITIZER", "0") == "1",
}
# end-env-vars-definition

View File

@ -210,6 +210,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.utils import slice_query_start_locs
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
@ -432,14 +434,6 @@ class MLACommonMetadataBuilder(Generic[M]):
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def _build_decode(self, block_table_tensor: torch.Tensor,
@ -448,37 +442,74 @@ class MLACommonMetadataBuilder(Generic[M]):
block_table=block_table_tensor,
seq_lens=seq_lens,
)
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
assert self._num_decodes + self._num_prefills == num_reqs
def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, num_tokens: int, query_start_loc: torch.Tensor):
"""
return
- num_decodes: number of decode requests
- num_prefills: number of prefill requests
- num_decode_tokens: number of decode tokens
- num_prefill_tokens: number of prefill tokens
"""
if max_query_len == 1:
# Pure decode
return num_reqs, 0, num_tokens, 0
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
first_prefill = (query_lens > 1).int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > 1)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = first_prefill
num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
return (
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
)
def build_slice(self, req_slice: slice,
token_slice: slice,
max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
) -> M:
num_reqs = req_slice.stop - req_slice.start
num_tokens = token_slice.stop - token_slice.start
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
block_table_tensor = block_table.get_device_tensor()[req_slice]
slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
device, non_blocking=True).long()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
query_start_loc = slice_query_start_locs(
common_attn_metadata.query_start_loc, req_slice)
seq_lens = common_attn_metadata.seq_lens[req_slice]
num_computed_tokens = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[req_slice]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
self._split_decodes_and_prefills(
max_query_len, num_reqs, num_tokens, query_start_loc)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
context_lens_cpu = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
context_lens_cpu = num_computed_tokens[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
chunked_context_metadata = None
if self.chunked_prefill_enabled and self._num_prefills > 0 \
if self.chunked_prefill_enabled and num_prefills > 0 \
and max_context_len_cpu > 0:
# NOTE: it is recommend you read the `Chunked Prefill` section
# in the comment at the top of the file before trying to
@ -509,14 +540,14 @@ class MLACommonMetadataBuilder(Generic[M]):
# of `to_list`.
chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) \
.unsqueeze(1).expand(-1, num_prefills) \
* max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
@ -544,25 +575,36 @@ class MLACommonMetadataBuilder(Generic[M]):
)
decode_metadata = None
if self._num_decodes > 0:
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
)
return self.metadata_cls(
num_actual_tokens=num_actual_tokens,
num_actual_tokens=num_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
# MLACommonMetadata Chunk prefill specific
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
prefill=prefill_metadata,
decode=decode_metadata,
)
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
return self.build_slice(
req_slice=slice(0, num_reqs),
token_slice=slice(0, num_actual_tokens),
max_query_len=max_query_len,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
)
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False