mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 08:07:05 +08:00
[TPU] Use mark_dynamic only for dummy run (#7634)
This commit is contained in:
parent
ce143353c6
commit
0c2fa50b84
@ -144,7 +144,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
)
|
)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
self.model = CompiledModelWrapper(model)
|
model = ModelWrapper(model)
|
||||||
|
self.model = torch.compile(model,
|
||||||
|
backend="openxla",
|
||||||
|
fullgraph=True,
|
||||||
|
dynamic=False)
|
||||||
|
|
||||||
def _dummy_run(
|
def _dummy_run(
|
||||||
self,
|
self,
|
||||||
@ -206,9 +210,31 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
)
|
)
|
||||||
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||||
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||||
|
|
||||||
# Dummy run.
|
|
||||||
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||||
|
|
||||||
|
# NOTE(woosuk): There are two stages of compilation: torch.compile and
|
||||||
|
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
|
||||||
|
# overhead by reusing the FX graph for different shapes.
|
||||||
|
# However, the XLA graph will still require static shapes and needs to
|
||||||
|
# be re-compiled for every different shapes. This overhead is inevitable
|
||||||
|
# in the first run, but can be skipped afterwards as we cache the XLA
|
||||||
|
# graphs in the disk (VLLM_XLA_CACHE_PATH).
|
||||||
|
if is_prompt:
|
||||||
|
# Prefll
|
||||||
|
torch._dynamo.mark_dynamic(token_ids, 1)
|
||||||
|
torch._dynamo.mark_dynamic(position_ids, 1)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
|
||||||
|
else:
|
||||||
|
# Decode
|
||||||
|
torch._dynamo.mark_dynamic(token_ids, 0)
|
||||||
|
torch._dynamo.mark_dynamic(position_ids, 0)
|
||||||
|
torch._dynamo.mark_dynamic(input_lens, 0)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||||
|
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||||
|
torch._dynamo.mark_dynamic(t, 0)
|
||||||
|
torch._dynamo.mark_dynamic(p, 0)
|
||||||
|
# Dummy run.
|
||||||
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
|
self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
|
||||||
num_samples, kv_caches)
|
num_samples, kv_caches)
|
||||||
|
|
||||||
@ -682,52 +708,6 @@ class ModelWrapper(nn.Module):
|
|||||||
return next_token_ids
|
return next_token_ids
|
||||||
|
|
||||||
|
|
||||||
class CompiledModelWrapper:
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module):
|
|
||||||
model = ModelWrapper(model)
|
|
||||||
self.model = torch.compile(model,
|
|
||||||
backend="openxla",
|
|
||||||
fullgraph=True,
|
|
||||||
dynamic=False)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
token_ids: torch.Tensor,
|
|
||||||
position_ids: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
input_lens: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
p: torch.Tensor,
|
|
||||||
num_samples: int,
|
|
||||||
kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# NOTE(woosuk): There are two stages of compilation: torch.compile and
|
|
||||||
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
|
|
||||||
# overhead by reusing the FX graph for different shapes.
|
|
||||||
# However, the XLA graph will still require static shapes and needs to
|
|
||||||
# be re-compiled for every different shapes. This overhead is inevitable
|
|
||||||
# in the first run, but can be skipped afterwards as we cache the XLA
|
|
||||||
# graphs in the disk (VLLM_XLA_CACHE_PATH).
|
|
||||||
if attn_metadata.num_prefills > 0:
|
|
||||||
# Prefll
|
|
||||||
torch._dynamo.mark_dynamic(token_ids, 1)
|
|
||||||
torch._dynamo.mark_dynamic(position_ids, 1)
|
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
|
|
||||||
else:
|
|
||||||
# Decode
|
|
||||||
torch._dynamo.mark_dynamic(token_ids, 0)
|
|
||||||
torch._dynamo.mark_dynamic(position_ids, 0)
|
|
||||||
torch._dynamo.mark_dynamic(input_lens, 0)
|
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
|
||||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
|
||||||
torch._dynamo.mark_dynamic(t, 0)
|
|
||||||
torch._dynamo.mark_dynamic(p, 0)
|
|
||||||
return self.model(token_ids, position_ids, attn_metadata, input_lens,
|
|
||||||
t, p, num_samples, kv_caches)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_padded_prefill_len(x: int) -> int:
|
def _get_padded_prefill_len(x: int) -> int:
|
||||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||||
# length to be a multiple of 16. We pad the prompt length to the nearest
|
# length to be a multiple of 16. We pad the prompt length to the nearest
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user