[TPU] Use mark_dynamic only for dummy run (#7634)
parent ce143353c6
commit 0c2fa50b84
@@ -144,7 +144,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         model = model.eval()
         xm.wait_device_ops()
-        self.model = CompiledModelWrapper(model)
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
 
     def _dummy_run(
         self,
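
For context, here is a minimal standalone sketch of the compile step that this hunk now inlines into the runner: wrap the model, then compile the wrapper with the openxla backend, fullgraph=True, and dynamic=False. The TinyModel class, device setup, and variable names below are illustrative assumptions, not vLLM code; the sketch assumes torch_xla is installed and an XLA device (e.g. a TPU VM) is reachable.

# Illustrative sketch (not vLLM code): compiling an eval-mode module the same
# way this hunk compiles ModelWrapper.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

class TinyModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) * 2.0

device = xm.xla_device()
model = TinyModel().to(device).eval()
xm.wait_device_ops()  # let pending device work (e.g. weight transfers) finish

# dynamic=False: Dynamo does not guess dynamic dims on its own; dims are
# instead marked dynamic explicitly before the dummy run (next hunk).
compiled_model = torch.compile(model,
                               backend="openxla",
                               fullgraph=True,
                               dynamic=False)

Compiling the wrapper directly in the runner, rather than hiding torch.compile inside a CompiledModelWrapper class, is what lets the runner decide where to apply mark_dynamic, which the rest of this diff exploits.
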
@@ -206,9 +210,31 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
-
-        # Dummy run.
         num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
+
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph will still require static shapes and needs to
+        # be re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs in the disk (VLLM_XLA_CACHE_PATH).
+        if is_prompt:
+            # Prefill
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(input_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+            torch._dynamo.mark_dynamic(t, 0)
+            torch._dynamo.mark_dynamic(p, 0)
+        # Dummy run.
         self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
                   num_samples, kv_caches)
 
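
As the mark_dynamic calls above show, the prefill inputs are marked dynamic along dim 1 (the sequence dimension), while the decode inputs are marked dynamic along dim 0 (the batch dimension). The following standalone sketch shows the same pattern on CPU with the default torch.compile backend, just to illustrate why marking a dim dynamic before the first (dummy) call avoids later Dynamo re-traces; the function and shapes are illustrative, not vLLM's model signature.

# Standalone sketch of the mark_dynamic pattern (CPU, default backend).
import torch

@torch.compile(fullgraph=True)
def scale(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return x * t.unsqueeze(-1)

for batch_size in (8, 16, 32):
    x = torch.ones(batch_size, 128)
    t = torch.ones(batch_size)
    # Mirror the decode branch above: the batch dim (0) is dynamic.
    torch._dynamo.mark_dynamic(x, 0)
    torch._dynamo.mark_dynamic(t, 0)
    # The first call traces a graph with a symbolic batch dim; the later
    # batch sizes reuse the same FX graph instead of re-tracing.
    scale(x, t)

Note that this only removes the Dynamo re-trace. As the NOTE in the hunk says, the underlying XLA graph is still shape-specialized, so the first run still pays XLA compile time per shape; subsequent runs pick the graphs up from the on-disk cache (VLLM_XLA_CACHE_PATH).
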
@@ -682,52 +708,6 @@ class ModelWrapper(nn.Module):
         return next_token_ids
 
 
-class CompiledModelWrapper:
-
-    def __init__(self, model: nn.Module):
-        model = ModelWrapper(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
-
-    def __call__(
-        self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        input_lens: torch.Tensor,
-        t: torch.Tensor,
-        p: torch.Tensor,
-        num_samples: int,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
-    ) -> torch.Tensor:
-        # NOTE(woosuk): There are two stages of compilation: torch.compile and
-        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
-        # overhead by reusing the FX graph for different shapes.
-        # However, the XLA graph will still require static shapes and needs to
-        # be re-compiled for every different shapes. This overhead is inevitable
-        # in the first run, but can be skipped afterwards as we cache the XLA
-        # graphs in the disk (VLLM_XLA_CACHE_PATH).
-        if attn_metadata.num_prefills > 0:
-            # Prefll
-            torch._dynamo.mark_dynamic(token_ids, 1)
-            torch._dynamo.mark_dynamic(position_ids, 1)
-            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
-        else:
-            # Decode
-            torch._dynamo.mark_dynamic(token_ids, 0)
-            torch._dynamo.mark_dynamic(position_ids, 0)
-            torch._dynamo.mark_dynamic(input_lens, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
-            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
-            torch._dynamo.mark_dynamic(t, 0)
-            torch._dynamo.mark_dynamic(p, 0)
-        return self.model(token_ids, position_ids, attn_metadata, input_lens,
-                          t, p, num_samples, kv_caches)
-
-
 def _get_padded_prefill_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
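
The diff display is cut off inside _get_padded_prefill_len, so its body is not shown. Purely as an illustration of the stated constraint (the Pallas FlashAttention kernel needs the sequence length to be a multiple of 16), a hypothetical padding helper could round the prompt length up to the next power of two with a floor of 16; this is an assumption, not necessarily vLLM's actual implementation.

# Hypothetical sketch only: the real _get_padded_prefill_len is truncated in
# this diff. This version guarantees a multiple of 16 by rounding the prompt
# length up to the next power of two (minimum 16).
def padded_prefill_len_sketch(x: int) -> int:
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

assert padded_prefill_len_sketch(1) == 16
assert padded_prefill_len_sketch(17) == 32
assert padded_prefill_len_sketch(1000) == 1024
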