diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py
index bd0e984627d11..4a8f17ba1d0d7 100644
--- a/examples/offline_inference/tpu.py
+++ b/examples/offline_inference/tpu.py
@@ -21,7 +21,9 @@ sampling_params = SamplingParams(temperature=0.7,
 
 # Set `enforce_eager=True` to avoid ahead-of-time compilation.
 # In real workloads, `enforace_eager` should be `False`.
-llm = LLM(model="google/gemma-2b", enforce_eager=True)
+llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
+          max_num_batched_tokens=64,
+          max_num_seqs=4)
 outputs = llm.generate(prompts, sampling_params)
 for output, answer in zip(outputs, answers):
     prompt = output.prompt
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 1897d859b71c9..00869467be341 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -401,6 +401,7 @@ class TPUModelRunner:
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
                   out=self.query_start_loc_np[1:num_reqs + 1])
+        self.query_start_loc_np[num_reqs + 1:] = 1
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -441,7 +442,10 @@ class TPUModelRunner:
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
+            num_reqs, self.max_num_reqs)
+        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
+        logits_indices = logits_indices.to(self.device)
         return attn_metadata, logits_indices
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -551,7 +555,6 @@ class TPUModelRunner:
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
@@ -579,12 +582,10 @@ class TPUModelRunner:
                 kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
-        logits_indices = logits_indices[:num_reqs]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
-        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        selected_token_ids = self.model.compute_logits(hidden_states,
+                                                       logits_indices, None)
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Then, let's update the cache state.
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
@@ -726,12 +727,31 @@ class TPUModelRunner:
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(
+            hidden_states = self.model(
                 input_ids=input_ids,
                 positions=position_ids,
                 kv_caches=kv_caches,
                 inputs_embeds=inputs_embeds,
             )
+        num_reqs = _get_padded_num_reqs_with_upper_limit(
+            64, self.max_num_reqs)
+        # NOTE(chengjiyao): In total, the compute_logits function utilizes a
+        # compilation cache size of token_bucket_num multiplied by
+        # req_bucket_num. This is acceptable, given the graph's relatively
+        # small size.
+        while True:
+            logits_indices = torch.zeros(
+                num_reqs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            torch._dynamo.mark_dynamic(hidden_states, 0)
+            torch._dynamo.mark_dynamic(logits_indices, 0)
+            self.model.compute_logits(hidden_states, logits_indices, None)
+            if num_reqs >= self.max_num_reqs:
+                break
+            num_reqs = _get_padded_num_reqs_with_upper_limit(
+                num_reqs + 1, self.max_num_reqs)
 
     def capture_model(self) -> None:
         """Compile the model."""
@@ -823,13 +843,17 @@ class ModelWrapperV1(nn.Module):
 
         return hidden_states
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
-        return logits
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
 
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
@@ -846,3 +870,8 @@ def _get_padded_token_len(x: int) -> int:
     if x <= 16:
         return 16
     return 1 << (x - 1).bit_length()
+
+
+def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
+    return min(res, upper_limit)
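
For context, here is a standalone sketch (not part of the patch) of the request-count bucketing the change relies on. `_get_padded_num_reqs_with_upper_limit` is copied from the diff above; `precompile_buckets` is a hypothetical helper that only mirrors the warm-up `while` loop, which traces `compute_logits` once per padded request count so that later calls hit the compilation cache.

# Sketch only: illustrates which bucket sizes the warm-up loop compiles for.
# `_get_padded_num_reqs_with_upper_limit` is taken from the patch;
# `precompile_buckets` is a hypothetical helper mirroring the while-loop.

def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
    # Round up to 64, then to the next power of two, capped at the limit.
    res = 64 if x <= 64 else 1 << (x - 1).bit_length()
    return min(res, upper_limit)


def precompile_buckets(max_num_reqs: int) -> list[int]:
    # Enumerate every padded request count the warm-up loop would trace
    # compute_logits for: 64, 128, 256, ..., capped at max_num_reqs.
    buckets = []
    num_reqs = _get_padded_num_reqs_with_upper_limit(64, max_num_reqs)
    while True:
        buckets.append(num_reqs)
        if num_reqs >= max_num_reqs:
            break
        num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs + 1, max_num_reqs)
    return buckets


if __name__ == "__main__":
    print(precompile_buckets(256))  # [64, 128, 256]
    print(precompile_buckets(100))  # [64, 100]

Because every live request count is rounded up to one of these buckets, `compute_logits` is only compiled for a small set of (token bucket, request bucket) shape pairs, which is the cache-size argument made in the NOTE(chengjiyao) comment.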