[5/N][torch.compile] torch.jit.script --> torch.compile (#10406)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 7851b45196
parent 4186be8111
@@ -368,7 +368,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
 # Note that we always sample with replacement.
 # probs will be modified in place, but this is fine, as we pass
 # in a copy already.
-@torch.jit.script
+@torch.compile(dynamic=True)
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
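The hunk above only swaps the decorator; the point of dynamic=True is that a single compiled artifact can serve varying batch and sample counts without per-shape recompilation. Below is a minimal, hypothetical sketch of the same pattern, not vLLM's actual _multinomial; the function name and the Gumbel-max-style sampler body are illustrative only.

import torch


@torch.compile(dynamic=True)
def sample_with_replacement(probs: torch.Tensor) -> torch.Tensor:
    # Exponential-noise (Gumbel-max style) trick: dividing probabilities by
    # i.i.d. Exponential(1) noise and taking argmax draws one categorical
    # sample per row without the GPU<->CPU sync that torch.multinomial incurs.
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div(q).argmax(dim=-1)


if __name__ == "__main__":
    probs = torch.softmax(torch.randn(4, 32), dim=-1)
    print(sample_with_replacement(probs))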
@@ -133,13 +133,13 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.jit.script
+@torch.compile(dynamic=True)
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
         added_vocab_start_index: int,
         added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.jit.script will fuse all of the pointwise ops below
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
                                                           org_vocab_end_index)
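As the updated comment notes, the body of get_masked_input_and_mask is pure pointwise arithmetic, which torch.compile can fuse into a single kernel rather than launching one kernel per op. A hedged, self-contained sketch of that fusion idea; the helper name and vocab ranges are illustrative, not vLLM's exact function.

from typing import Tuple

import torch


@torch.compile(dynamic=True)
def mask_to_local_vocab(input_: torch.Tensor, vocab_start: int,
                        vocab_end: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Every op here is elementwise, so the compiler can fuse the whole body
    # into one kernel.
    mask = (input_ >= vocab_start) & (input_ < vocab_end)
    local = (input_ - vocab_start) * mask
    return local, ~mask


if __name__ == "__main__":
    ids = torch.randint(0, 1000, (8,))
    local_ids, pad_mask = mask_to_local_vocab(ids, 256, 512)
    print(local_ids, pad_mask)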
@@ -54,12 +54,12 @@ class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
         return load_column_parallel_weight(param, loaded_weight)
 
 
-@torch.jit.script
+@torch.compile(dynamic=True)
 def quick_gelu(x):
     return x * torch.sigmoid(1.702 * x)
 
 
-@torch.jit.script
+@torch.compile(dynamic=True)
 def gegelu(input, limit: Optional[float] = None):
     a_gelu, a_linear = input[..., ::2], input[..., 1::2]
     if limit is not None:
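The activation changes are mechanical decorator swaps over simple elementwise expressions. For reference, a self-contained version of quick_gelu under the new decorator, matching the body shown in the hunk:

import torch


@torch.compile(dynamic=True)
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # Sigmoid approximation of GELU; a single fused pointwise kernel under
    # torch.compile.
    return x * torch.sigmoid(1.702 * x)


if __name__ == "__main__":
    print(quick_gelu(torch.randn(4)))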
@@ -1769,7 +1769,7 @@ class CUDAGraphRunner(nn.Module):
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        # Note one iteration is not enough for torch.jit.script
+        # Note one iteration is not enough for torch.compile
         for _ in range(_NUM_WARMUP_ITERS):
             self.model(
                 input_ids=input_ids,
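The updated comment keeps the multi-iteration warmup because torch.compile may still be compiling (and Triton may still be autotuning) on the first calls, and those kernel launches must stay out of the captured CUDA graph. A hedged sketch of that ordering follows; the capture helper and the _NUM_WARMUP_ITERS value are illustrative, not vLLM's CUDAGraphRunner.

import torch

_NUM_WARMUP_ITERS = 2  # illustrative value


def capture(model: torch.nn.Module, x: torch.Tensor) -> torch.cuda.CUDAGraph:
    # Run a few eager iterations first so compilation and autotuning kernel
    # launches finish before capture begins.
    for _ in range(_NUM_WARMUP_ITERS):
        model(x)
    torch.cuda.synchronize()

    # Replay-able capture of the steady-state kernels only.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        model(x)
    return graph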