mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 18:57:05 +08:00

support MLA

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

This commit is contained in:
parent 37c9babaa0
commit df8f889f37
@@ -30,10 +30,15 @@ sampling_params = SamplingParams(**param_kwargs)

 def main():
     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m",
+    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
               enforce_eager=False,
               compilation_config=2,
-              enable_microbatching=True,)
+              enable_microbatching=True,
+              trust_remote_code=True,
+              tensor_parallel_size=4,
+              max_model_len=1024,
+              #load_format="dummy",
+              )
     # Generate texts from the prompts.
     # The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
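For orientation, a minimal sketch of how the updated example runs end to end. The prompts and sampling parameters are placeholders (the real ones live elsewhere in the example file), and enable_microbatching / compilation_config=2 are options introduced on this branch, so they are omitted here:

# Hedged sketch of the example's driver loop; prompts/params are placeholders.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]                       # placeholder prompt
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          trust_remote_code=True,    # DeepSeek-V2 ships custom model code
          tensor_parallel_size=4,    # as in the diff; requires 4 GPUs
          max_model_len=1024)
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out.prompt, "->", out.outputs[0].text)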
@@ -740,7 +740,7 @@ class PiecewiseBackend:
             # manage the memory during cuda graph capture
             return output

-        if self.is_debugging_mode:
+        if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
             # check if the input addresses are the same
             new_input_addresses = [
                 x.data_ptr() for x in args if isinstance(x, torch.Tensor)
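The guard above widens an existing debug-only check. The shape of that check is roughly the following (a simplified sketch of the logic visible in the hunk, not the full vLLM implementation): CUDA graph replay reuses the exact buffers captured at record time, so inputs that arrive at different addresses would be silently ignored.

import torch

def check_input_addresses(captured_addresses: list[int], args) -> None:
    # Compare the data pointers of the current tensor args against the
    # addresses recorded when the CUDA graph was captured.
    new_input_addresses = [
        x.data_ptr() for x in args if isinstance(x, torch.Tensor)
    ]
    assert new_input_addresses == captured_addresses, (
        "Input addresses changed between CUDA graph capture and replay: "
        f"{captured_addresses} vs {new_input_addresses}")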
@@ -117,6 +117,7 @@ if TYPE_CHECKING:
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
     VLLM_ALL2ALL_BACKEND: str = "naive"
+    VLLM_CUDAGRAPH_SANITIZER: bool = False


 def get_default_cache_root():
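For context, vllm/envs.py resolves these names lazily: the TYPE_CHECKING block above only supplies type hints, while a module-level __getattr__ evaluates a zero-arg lambda per variable on each access. A condensed sketch of that pattern:

# Condensed sketch of the vllm/envs.py lookup pattern (PEP 562 module
# __getattr__); each entry is a callable so the environment is re-read lazily.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_CUDAGRAPH_SANITIZER":
    lambda: os.getenv("VLLM_CUDAGRAPH_SANITIZER", "0") == "1",
}

def __getattr__(name: str) -> Any:
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")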
@@ -811,6 +812,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # all2all backend for vllm's expert parallel communication
     "VLLM_ALL2ALL_BACKEND":
     lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
+
+    # check that the cudagraphs input addresses are correct before replaying
+    "VLLM_CUDAGRAPH_SANITIZER":
+    lambda: os.getenv("VLLM_CUDAGRAPH_SANITIZER", "0") == "1",
 }

 # end-env-vars-definition
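A quick way to exercise the new flag from Python rather than the shell (a hypothetical smoke test; because each lookup re-reads os.environ through the lambda above, setting the variable before the first access is sufficient):

import os
os.environ["VLLM_CUDAGRAPH_SANITIZER"] = "1"   # anything but "1" leaves it off

import vllm.envs as envs
assert envs.VLLM_CUDAGRAPH_SANITIZER is True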
@@ -210,6 +210,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

+from vllm.v1.attention.backends.utils import slice_query_start_locs
+
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
     is_vllm_fa = True
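slice_query_start_locs is imported above but its body is not part of this diff. A plausible reading (an assumption from how build_slice uses it below, not the actual implementation) is that it slices the cumulative query_start_loc for a request range and rebases it to zero, so the sliced batch looks self-contained:

import torch

def slice_query_start_locs_sketch(query_start_loc: torch.Tensor,
                                  req_slice: slice) -> torch.Tensor:
    # Assumes concrete start/stop bounds; query_start_loc has num_reqs + 1
    # entries, so a range of requests needs stop + 1 boundary values.
    sliced = query_start_loc[req_slice.start:req_slice.stop + 1]
    return sliced - sliced[0]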
@@ -432,14 +434,6 @@ class MLACommonMetadataBuilder(Generic[M]):
             input_batch.swap_states(prefills[i - 1], decode_idx)
             modified_batch = True

-        # Save for next `build` call
-        # TODO(lucas): this is a bit of a hack, we should probably have a
-        # better way of doing this
-        self._num_decodes = num_decodes
-        self._num_prefills = num_prefills
-        self._num_decode_tokens = num_decode_tokens
-        self._num_prefill_tokens = num_prefill_tokens
-
         return modified_batch

     def _build_decode(self, block_table_tensor: torch.Tensor,
@@ -448,37 +442,74 @@ class MLACommonMetadataBuilder(Generic[M]):
             block_table=block_table_tensor,
             seq_lens=seq_lens,
         )

-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata) -> M:
-        assert self._num_decodes + self._num_prefills == num_reqs
+    def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int,
+                                    num_tokens: int,
+                                    query_start_loc: torch.Tensor):
+        """
+        Returns:
+        - num_decodes: number of decode requests
+        - num_prefills: number of prefill requests
+        - num_decode_tokens: number of decode tokens
+        - num_prefill_tokens: number of prefill tokens
+        """
+        if max_query_len == 1:
+            # Pure decode
+            return num_reqs, 0, num_tokens, 0
+        else:
+            query_lens = query_start_loc[1:] - query_start_loc[:-1]
+            first_prefill = (query_lens > 1).int().argmax(dim=-1).item()
+            assert torch.all(query_lens[first_prefill:] > 1)
+            num_decodes = first_prefill
+            num_prefills = num_reqs - num_decodes
+            num_decode_tokens = first_prefill
+            num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
+            return (num_decodes, num_prefills, num_decode_tokens,
+                    num_prefill_tokens)
+
+    def build_slice(self, req_slice: slice,
+                    token_slice: slice,
+                    max_query_len: int,
+                    common_prefix_len: int,
+                    common_attn_metadata: CommonAttentionMetadata,
+                    ) -> M:
+        num_reqs = req_slice.stop - req_slice.start
+        num_tokens = token_slice.stop - token_slice.start

         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table_tensor = block_table.get_device_tensor()[req_slice]
+        slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
             device, non_blocking=True).long()

-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+        query_start_loc = slice_query_start_locs(
+            common_attn_metadata.query_start_loc, req_slice)
+        seq_lens = common_attn_metadata.seq_lens[req_slice]
+
+        num_computed_tokens = self.runner.input_batch.\
+            num_computed_tokens_cpu_tensor[req_slice]
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+            self._split_decodes_and_prefills(
+                max_query_len, num_reqs, num_tokens, query_start_loc)
+
+        assert num_decodes + num_prefills == num_reqs
+        assert num_decode_tokens + num_prefill_tokens == num_tokens

         prefill_metadata = None
-        if self._num_prefills > 0:
-            reqs_start = self._num_decodes  # prefill_start
+        if num_prefills > 0:
+            reqs_start = num_decodes  # prefill_start

-            context_lens_cpu = self.runner.input_batch.\
-                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
+            context_lens_cpu = num_computed_tokens[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = query_start_loc[
                 reqs_start:] - query_start_loc[reqs_start]

             chunked_context_metadata = None
-            if self.chunked_prefill_enabled and self._num_prefills > 0 \
+            if self.chunked_prefill_enabled and num_prefills > 0 \
                     and max_context_len_cpu > 0:
                 # NOTE: it is recommended you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
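A worked instance of the _split_decodes_and_prefills contract above (values chosen for illustration; the reorder pass earlier in the file guarantees decodes are sorted before prefills):

import torch

query_start_loc = torch.tensor([0, 1, 2, 5, 9])  # 4 requests, 9 tokens total
query_lens = query_start_loc[1:] - query_start_loc[:-1]       # [1, 1, 3, 4]
first_prefill = (query_lens > 1).int().argmax(dim=-1).item()  # -> 2
# So: num_decodes = 2, num_prefills = 2, num_decode_tokens = 2,
# and num_prefill_tokens = 9 - query_start_loc[2] = 7.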
@@ -509,14 +540,14 @@ class MLACommonMetadataBuilder(Generic[M]):
                 # of `to_list`.
                 chunk_starts = \
                     torch.arange(num_chunks, dtype=torch.int32) \
-                    .unsqueeze(1).expand(-1, self._num_prefills) \
+                    .unsqueeze(1).expand(-1, num_prefills) \
                     * max_context_chunk
                 chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                        chunk_starts + max_context_chunk)
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                 cu_seq_lens_cpu = torch.zeros(num_chunks,
-                                              self._num_prefills + 1,
+                                              num_prefills + 1,
                                               dtype=torch.int32,
                                               pin_memory=True)
                 torch.cumsum(chunk_seq_lens,
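Concrete values for the chunking math above: two chunks of up to 4 context tokens each, over two prefills with context lengths [6, 3].

import torch

num_chunks, max_context_chunk = 2, 4
context_lens_cpu = torch.tensor([6, 3], dtype=torch.int32)
chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, 2) * max_context_chunk)
# [[0, 0], [4, 4]]
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
# [[4, 3], [6, 3]]
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# [[4, 3], [2, 0]] -- chunk 1 has no context left for the second request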
@@ -544,25 +575,36 @@ class MLACommonMetadataBuilder(Generic[M]):
             )

         decode_metadata = None
-        if self._num_decodes > 0:
+        if num_decodes > 0:
             decode_metadata = self._build_decode(
-                block_table_tensor=block_table_tensor[:self._num_decodes, ...],
-                seq_lens=seq_lens[:self._num_decodes],
+                block_table_tensor=block_table_tensor[:num_decodes, ...],
+                seq_lens=seq_lens[:num_decodes],
             )

         return self.metadata_cls(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
-            num_prefills=self._num_prefills,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            num_prefills=num_prefills,
             prefill=prefill_metadata,
             decode=decode_metadata,
         )

+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
+        return self.build_slice(
+            req_slice=slice(0, num_reqs),
+            token_slice=slice(0, num_actual_tokens),
+            max_query_len=max_query_len,
+            common_prefix_len=common_prefix_len,
+            common_attn_metadata=common_attn_metadata,
+        )
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return False
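How a caller might use build_slice to build per-microbatch metadata. This is illustrative only: builder stands in for an MLACommonMetadataBuilder instance, the split points are made up (the real microbatch boundaries come from the scheduler), and token slices must align with request boundaries in query_start_loc. Here requests 0-3 are decodes (1 token each) and requests 4-7 are prefills (16 tokens), as the reorder pass guarantees:

md0 = builder.build_slice(req_slice=slice(0, 4), token_slice=slice(0, 4),
                          max_query_len=1, common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)
md1 = builder.build_slice(req_slice=slice(4, 8), token_slice=slice(4, 20),
                          max_query_len=8, common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)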