diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py
index bbab7d97ae11d..15aa8dd17e64b 100644
--- a/examples/offline_inference/basic/basic.py
+++ b/examples/offline_inference/basic/basic.py
@@ -30,10 +30,15 @@ sampling_params = SamplingParams(**param_kwargs)
 
 def main():
     # Create an LLM.
-    llm = LLM(model="facebook/opt-125m",
+    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
               enforce_eager=False,
               compilation_config=2,
-              enable_microbatching=True,)
+              enable_microbatching=True,
+              trust_remote_code=True,
+              tensor_parallel_size=4,
+              max_model_len=1024,
+              # load_format="dummy",
+              )
     # Generate texts from the prompts.
     # The output is a list of RequestOutput objects
     # that contain the prompt, generated text, and other information.
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 0c1381a565c16..ae2372b2a6e6b 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -740,7 +740,7 @@ class PiecewiseBackend:
             # manage the memory during cuda graph capture
             return output
 
-        if self.is_debugging_mode:
+        if self.is_debugging_mode or envs.VLLM_CUDAGRAPH_SANITIZER:
             # check if the input addresses are the same
             new_input_addresses = [
                 x.data_ptr() for x in args if isinstance(x, torch.Tensor)
diff --git a/vllm/envs.py b/vllm/envs.py
index dc23c8ea5314d..9d226b298cefb 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -117,6 +117,7 @@ if TYPE_CHECKING:
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
     VLLM_ALL2ALL_BACKEND: str = "naive"
+    VLLM_CUDAGRAPH_SANITIZER: bool = False
 
 
 def get_default_cache_root():
@@ -811,6 +812,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # all2all backend for vllm's expert parallel communication
     "VLLM_ALL2ALL_BACKEND":
     lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
+
+    # check that the cudagraph input addresses are correct before replaying
+    "VLLM_CUDAGRAPH_SANITIZER":
+    lambda: os.getenv("VLLM_CUDAGRAPH_SANITIZER", "0") == "1",
 }
 
 # end-env-vars-definition
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 83e1811165777..4f4a6354a51f8 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -210,6 +210,8 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
+from vllm.v1.attention.backends.utils import slice_query_start_locs
+
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
     is_vllm_fa = True
@@ -432,14 +434,6 @@ class MLACommonMetadataBuilder(Generic[M]):
                 input_batch.swap_states(prefills[i - 1], decode_idx)
                 modified_batch = True
 
-        # Save for next `build` call
-        # TODO(lucas): this is a bit of a hack, we should probably have a
-        # better way of doing this
-        self._num_decodes = num_decodes
-        self._num_prefills = num_prefills
-        self._num_decode_tokens = num_decode_tokens
-        self._num_prefill_tokens = num_prefill_tokens
-
         return modified_batch
 
     def _build_decode(self, block_table_tensor: torch.Tensor,
@@ -448,37 +442,74 @@ class MLACommonMetadataBuilder(Generic[M]):
             block_table=block_table_tensor,
             seq_lens=seq_lens,
         )
-
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata) -> M:
-        assert self._num_decodes + self._num_prefills == num_reqs
+
+    def _split_decodes_and_prefills(self, max_query_len: int, num_reqs: int, num_tokens: int, query_start_loc: torch.Tensor):
+        """Return a 4-tuple of ints:
+        - num_decodes: number of decode requests
+        - num_prefills: number of prefill requests
+        - num_decode_tokens: number of decode tokens
+        - num_prefill_tokens: number of prefill tokens
+        Decode requests (query_len == 1) must precede all prefills.
+        """
+        if max_query_len == 1:
+            # Pure decode
+            return num_reqs, 0, num_tokens, 0
+        else:
+            query_lens = query_start_loc[1:] - query_start_loc[:-1]
+            first_prefill = (query_lens > 1).int().argmax(dim=-1).item()
+            assert torch.all(query_lens[first_prefill:] > 1)
+            num_decodes = first_prefill
+            num_prefills = num_reqs - num_decodes
+            num_decode_tokens = first_prefill
+            num_prefill_tokens = num_tokens - query_start_loc[first_prefill].item()
+            return (
+                num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens
+            )
+
+    def build_slice(self, req_slice: slice,
+                    token_slice: slice,
+                    max_query_len: int,
+                    common_prefix_len: int,
+                    common_attn_metadata: CommonAttentionMetadata,
+                    ) -> M:
+        num_reqs = req_slice.stop - req_slice.start
+        num_tokens = token_slice.stop - token_slice.start
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
         block_table = self.block_table
-        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table_tensor = block_table.get_device_tensor()[req_slice]
+        slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
             device, non_blocking=True).long()
 
-        query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+        query_start_loc = slice_query_start_locs(
+            common_attn_metadata.query_start_loc, req_slice)
+        seq_lens = common_attn_metadata.seq_lens[req_slice]
+
+        num_computed_tokens = self.runner.input_batch.\
+            num_computed_tokens_cpu_tensor[req_slice]
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+            self._split_decodes_and_prefills(
+                max_query_len, num_reqs, num_tokens, query_start_loc)
+
+        assert num_decodes + num_prefills == num_reqs
+        assert num_decode_tokens + num_prefill_tokens == num_tokens
 
         prefill_metadata = None
-        if self._num_prefills > 0:
-            reqs_start = self._num_decodes  # prefill_start
+        if num_prefills > 0:
+            reqs_start = num_decodes  # prefill_start
 
-            context_lens_cpu = self.runner.input_batch.\
-                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
+            context_lens_cpu = num_computed_tokens[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = query_start_loc[
                 reqs_start:] - query_start_loc[reqs_start]
 
             chunked_context_metadata = None
-            if self.chunked_prefill_enabled and self._num_prefills > 0 \
+            if self.chunked_prefill_enabled and num_prefills > 0 \
                 and max_context_len_cpu > 0:
                 # NOTE: it is recommend you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
@@ -509,14 +540,14 @@ class MLACommonMetadataBuilder(Generic[M]):
                 # of `to_list`.
                 chunk_starts = \
                     torch.arange(num_chunks, dtype=torch.int32) \
-                    .unsqueeze(1).expand(-1, self._num_prefills) \
+                    .unsqueeze(1).expand(-1, num_prefills) \
                     * max_context_chunk
                 chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                        chunk_starts + max_context_chunk)
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
                 cu_seq_lens_cpu = torch.zeros(num_chunks,
-                                              self._num_prefills + 1,
+                                              num_prefills + 1,
                                               dtype=torch.int32,
                                               pin_memory=True)
                 torch.cumsum(chunk_seq_lens,
@@ -544,25 +575,36 @@ class MLACommonMetadataBuilder(Generic[M]):
             )
 
         decode_metadata = None
-        if self._num_decodes > 0:
+        if num_decodes > 0:
             decode_metadata = self._build_decode(
-                block_table_tensor=block_table_tensor[:self._num_decodes, ...],
-                seq_lens=seq_lens[:self._num_decodes],
+                block_table_tensor=block_table_tensor[:num_decodes, ...],
+                seq_lens=seq_lens[:num_decodes],
             )
 
         return self.metadata_cls(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
-            num_prefills=self._num_prefills,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
+            num_prefills=num_prefills,
             prefill=prefill_metadata,
             decode=decode_metadata,
         )
 
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
+        return self.build_slice(
+            req_slice=slice(0, num_reqs),
+            token_slice=slice(0, num_actual_tokens),
+            max_query_len=max_query_len,
+            common_prefix_len=common_prefix_len,
+            common_attn_metadata=common_attn_metadata,
+        )
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return False