mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 00:22:19 +08:00
support MLA
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
37c9babaa0
commit
df8f889f37
@ -30,10 +30,15 @@ sampling_params = SamplingParams(**param_kwargs)
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="facebook/opt-125m",
|
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
|
||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
compilation_config=2,
|
compilation_config=2,
|
||||||
enable_microbatching=True,)
|
enable_microbatching=True,
|
||||||
|
trust_remote_code=True,
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
max_model_len=1024,
|
||||||
|
#load_format="dummy",
|
||||||
|
)
|
||||||
# Generate texts from the prompts.
|
# Generate texts from the prompts.
|
||||||
# The output is a list of RequestOutput objects
|
# The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
|
|||||||
@ -740,7 +740,7 @@ class PiecewiseBackend:
|
|||||||
# manage the memory during cuda graph capture
|
# manage the memory during cuda graph capture
|
||||||
return output
|
return output
|
||||||
|
|
||||||
if self.is_debugging_mode:
|
if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
|
||||||
# check if the input addresses are the same
|
# check if the input addresses are the same
|
||||||
new_input_addresses = [
|
new_input_addresses = [
|
||||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||||
|
|||||||
@ -117,6 +117,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
|
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
|
||||||
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
|
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
|
||||||
VLLM_ALL2ALL_BACKEND: str = "naive"
|
VLLM_ALL2ALL_BACKEND: str = "naive"
|
||||||
|
VLLM_CUDAGRAPH_SANITIZER: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -811,6 +812,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# all2all backend for vllm's expert parallel communication
|
# all2all backend for vllm's expert parallel communication
|
||||||
"VLLM_ALL2ALL_BACKEND":
|
"VLLM_ALL2ALL_BACKEND":
|
||||||
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
|
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
|
||||||
|
|
||||||
|
# check that the cudagraphs input addresses are correct before replaying
|
||||||
|
"VLLM_CUDAGRAPH_SANITIZER":
|
||||||
|
lambda: os.getenv("VLLM_CUDAGRAPH_SANITIZER", "0") == "1",
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@ -210,6 +210,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
|
from vllm.v1.attention.backends.utils import slice_query_start_locs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
is_vllm_fa = True
|
is_vllm_fa = True
|
||||||
@ -432,14 +434,6 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||||
modified_batch = True
|
modified_batch = True
|
||||||
|
|
||||||
# Save for next `build` call
|
|
||||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
|
||||||
# better way of doing this
|
|
||||||
self._num_decodes = num_decodes
|
|
||||||
self._num_prefills = num_prefills
|
|
||||||
self._num_decode_tokens = num_decode_tokens
|
|
||||||
self._num_prefill_tokens = num_prefill_tokens
|
|
||||||
|
|
||||||
return modified_batch
|
return modified_batch
|
||||||
|
|
||||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||||
@ -448,37 +442,74 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
block_table=block_table_tensor,
|
block_table=block_table_tensor,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, num_tokens: int, query_start_loc: torch.Tensor):
|
||||||
common_prefix_len: int,
|
"""
|
||||||
common_attn_metadata: CommonAttentionMetadata) -> M:
|
return
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
- num_decodes: number of decode requests
|
||||||
|
- num_prefills: number of prefill requests
|
||||||
|
- num_decode_tokens: number of decode tokens
|
||||||
|
- num_prefill_tokens: number of prefill tokens
|
||||||
|
"""
|
||||||
|
if max_query_len == 1:
|
||||||
|
# Pure decode
|
||||||
|
return num_reqs, 0, num_tokens, 0
|
||||||
|
else:
|
||||||
|
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||||
|
first_prefill = (query_lens > 1).int().argmax(dim=-1).item()
|
||||||
|
assert torch.all(query_lens[first_prefill:] > 1)
|
||||||
|
num_decodes = first_prefill
|
||||||
|
num_prefills = num_reqs - num_decodes
|
||||||
|
num_decode_tokens = first_prefill
|
||||||
|
num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
|
||||||
|
return (
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_slice(self, req_slice: slice,
|
||||||
|
token_slice: slice,
|
||||||
|
max_query_len: int,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
) -> M:
|
||||||
|
num_reqs = req_slice.stop - req_slice.start
|
||||||
|
num_tokens = token_slice.stop - token_slice.start
|
||||||
|
|
||||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||||
# it blocks on all previous kernels.
|
# it blocks on all previous kernels.
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
block_table = self.block_table
|
block_table = self.block_table
|
||||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
block_table_tensor = block_table.get_device_tensor()[req_slice]
|
||||||
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
|
||||||
device, non_blocking=True).long()
|
device, non_blocking=True).long()
|
||||||
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = slice_query_start_locs(
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
common_attn_metadata.query_start_loc, req_slice)
|
||||||
|
seq_lens = common_attn_metadata.seq_lens[req_slice]
|
||||||
|
|
||||||
|
num_computed_tokens = self.runner.input_batch.\
|
||||||
|
num_computed_tokens_cpu_tensor[req_slice]
|
||||||
|
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||||
|
self._split_decodes_and_prefills(
|
||||||
|
max_query_len, num_reqs, num_tokens, query_start_loc)
|
||||||
|
|
||||||
|
assert num_decodes + num_prefills == num_reqs
|
||||||
|
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
if self._num_prefills > 0:
|
if num_prefills > 0:
|
||||||
reqs_start = self._num_decodes # prefill_start
|
reqs_start = num_decodes # prefill_start
|
||||||
|
|
||||||
context_lens_cpu = self.runner.input_batch.\
|
context_lens_cpu = num_computed_tokens[reqs_start:num_reqs]
|
||||||
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
|
||||||
max_context_len_cpu = context_lens_cpu.max().item()
|
max_context_len_cpu = context_lens_cpu.max().item()
|
||||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||||
prefill_query_start_loc = query_start_loc[
|
prefill_query_start_loc = query_start_loc[
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
reqs_start:] - query_start_loc[reqs_start]
|
||||||
|
|
||||||
chunked_context_metadata = None
|
chunked_context_metadata = None
|
||||||
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
if self.chunked_prefill_enabled and num_prefills > 0 \
|
||||||
and max_context_len_cpu > 0:
|
and max_context_len_cpu > 0:
|
||||||
# NOTE: it is recommend you read the `Chunked Prefill` section
|
# NOTE: it is recommend you read the `Chunked Prefill` section
|
||||||
# in the comment at the top of the file before trying to
|
# in the comment at the top of the file before trying to
|
||||||
@ -509,14 +540,14 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
# of `to_list`.
|
# of `to_list`.
|
||||||
chunk_starts = \
|
chunk_starts = \
|
||||||
torch.arange(num_chunks, dtype=torch.int32) \
|
torch.arange(num_chunks, dtype=torch.int32) \
|
||||||
.unsqueeze(1).expand(-1, self._num_prefills) \
|
.unsqueeze(1).expand(-1, num_prefills) \
|
||||||
* max_context_chunk
|
* max_context_chunk
|
||||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||||
chunk_starts + max_context_chunk)
|
chunk_starts + max_context_chunk)
|
||||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||||
|
|
||||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||||
self._num_prefills + 1,
|
num_prefills + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
torch.cumsum(chunk_seq_lens,
|
torch.cumsum(chunk_seq_lens,
|
||||||
@ -544,25 +575,36 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if self._num_decodes > 0:
|
if num_decodes > 0:
|
||||||
decode_metadata = self._build_decode(
|
decode_metadata = self._build_decode(
|
||||||
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
|
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||||
seq_lens=seq_lens[:self._num_decodes],
|
seq_lens=seq_lens[:num_decodes],
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.metadata_cls(
|
return self.metadata_cls(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
head_dim=self.runner.model_config.get_head_size(),
|
head_dim=self.runner.model_config.get_head_size(),
|
||||||
# MLACommonMetadata Chunk prefill specific
|
# MLACommonMetadata Chunk prefill specific
|
||||||
num_decodes=self._num_decodes,
|
num_decodes=num_decodes,
|
||||||
num_decode_tokens=self._num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
num_prefills=self._num_prefills,
|
num_prefills=num_prefills,
|
||||||
prefill=prefill_metadata,
|
prefill=prefill_metadata,
|
||||||
decode=decode_metadata,
|
decode=decode_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata):
|
||||||
|
return self.build_slice(
|
||||||
|
req_slice=slice(0, num_reqs),
|
||||||
|
token_slice=slice(0, num_actual_tokens),
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
common_prefix_len=common_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user