[TPU] Fix the test_sampler (#17820)

This commit is contained in:
Jevin Jiang 2025-05-08 02:51:33 -07:00 committed by GitHub
parent ca04b97c93
commit a463555dee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -26,7 +26,7 @@ def test_sampler_different(model_name: str):
enforce_eager=False, enforce_eager=False,
max_num_seqs=1, max_num_seqs=1,
max_model_len=512, max_model_len=512,
max_num_batched_tokens=512) max_num_batched_tokens=256)
prompts = [ prompts = [
"Write a short story about a robot that dreams for the first time." "Write a short story about a robot that dreams for the first time."
] ]

View File

@ -95,7 +95,7 @@ class PallasMetadata:
block_tables: torch.Tensor block_tables: torch.Tensor
context_lens: torch.Tensor context_lens: torch.Tensor
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
num_seqs: int num_seqs: torch.Tensor
class PallasAttentionBackendImpl(AttentionImpl): class PallasAttentionBackendImpl(AttentionImpl):