[mypy][5/N] Support all typing on model executor (#4427)
This commit is contained in:
parent 03dd7d52bf
commit df29793dc7
2  .github/workflows/mypy.yaml (vendored)
@@ -43,8 +43,8 @@ jobs:
         mypy vllm/worker --config-file pyproject.toml
         mypy vllm/spec_decode --config-file pyproject.toml
         mypy vllm/lora --config-file pyproject.toml
+        mypy vllm/model_executor --config-file pyproject.toml
         # TODO(sang): Fix nested dir
-        mypy vllm/model_executor/*.py --config-file pyproject.toml
         mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/engine --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
 mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/model_executor/*.py --config-file pyproject.toml
+mypy vllm/model_executor --config-file pyproject.toml
 mypy vllm/lora --config-file pyproject.toml
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
         return schema
     if isinstance(schema, BaseModel):
         return schema.model_json_schema()
+    raise AssertionError(f"Unsupported schema type {schema}")


 @lru_cache
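The added raise follows a common way of satisfying mypy when a function's declared return type is not Optional: every branch must end in a return or an exception, so mypy never infers an implicit `return None`. A minimal, self-contained sketch of the pattern (illustrative names, not code from this commit):

import json
from typing import Union


def normalize(schema: Union[str, dict]) -> dict:
    # Every branch either returns a dict or raises, so mypy does not
    # flag a missing return at the end of the function.
    if isinstance(schema, str):
        return json.loads(schema)
    if isinstance(schema, dict):
        return schema
    raise AssertionError(f"Unsupported schema type {schema}")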
@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
         if quant_config is None:
-            self.quant_method = UnquantizedLinearMethod()
+            self.quant_method: Optional[
+                QuantizeMethodBase] = UnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self)
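The annotation on the first assignment tells mypy that `self.quant_method` may later hold None, since the `else` branch calls a method whose return type is widened to Optional later in this diff. A minimal sketch of the pattern with placeholder classes rather than the vLLM ones:

from typing import Optional


class Method:
    pass


class Layer:
    def __init__(self, configured: bool) -> None:
        if configured:
            # Annotating at the first assignment makes mypy type the
            # attribute as Optional[Method] instead of the narrower Method.
            self.method: Optional[Method] = Method()
        else:
            self.method = None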
@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)

+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self, self.input_size,
                                          [self.output_size], self.input_size,
                                          self.output_size, self.params_dtype)
@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
         self.output_size_per_partition = divide(output_size, tp_size)
         if output_sizes is None:
             output_sizes = [output_size]
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self,
                                          self.input_size,
                                          [x // tp_size for x in output_sizes],
@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self,
                                          self.input_size_per_partition,
                                          [self.output_size],
@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
             input_parallel = splitted_input[tp_rank].contiguous()

         # Matrix multiply.
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
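Because `self.quant_method` is now Optional, each call site gains an `assert self.quant_method is not None`; the assert documents the invariant and lets mypy narrow the type before the attribute is used. Roughly, the pattern is (illustrative names only):

from typing import Optional


def apply(method: Optional[str], x: float) -> float:
    # After the assert, mypy narrows `method` from Optional[str] to str,
    # so the use below type-checks without an ignore.
    assert method is not None
    return x * len(method)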
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Dict, Type

 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

-QUANTIZATION_METHODS = {
+QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,
     "fp8": Fp8Config,
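Annotating the registry as Dict[str, Type[QuantizationConfig]] lets mypy check both the literal above and later lookups against the common base class. A small standalone sketch of the same idea (placeholder names):

from typing import Dict, Type


class BaseConfig:
    pass


class AwqConfig(BaseConfig):
    pass


# Values are classes (not instances), hence Type[BaseConfig].
REGISTRY: Dict[str, Type[BaseConfig]] = {
    "awq": AwqConfig,
}

config_cls = REGISTRY["awq"]  # mypy infers Type[BaseConfig]
config = config_cls()         # ... and this is a BaseConfig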
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import torch
 from torch import nn
@@ -76,8 +76,16 @@ class QuantizationConfig(ABC):
                          "quantization config.")

     @abstractmethod
-    def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
-        """Get the quantize method to use for the quantized layer."""
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+        """Get the quantize method to use for the quantized layer.
+
+        Args:
+            layer: The layer for the quant method.
+        Returns:
+            The quantize method. None if the given layer doesn't support quant
+            method.
+        """
         raise NotImplementedError

     @abstractmethod
@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig):
         return cls(weight_bits)

     def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return SqueezeLLMLinearMethod(self)
-        return
+        return None

     def get_scaled_act_names(self) -> List[str]:
         return []
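Widening the abstract signature to Optional[QuantizeMethodBase] lets implementations such as SqueezeLLMConfig return None for unsupported layers while keeping the override compatible under mypy's method-override checks. A hedged sketch of the shape of that API, not the vLLM classes themselves:

from abc import ABC, abstractmethod
from typing import Optional


class QuantMethod:
    pass


class Config(ABC):
    @abstractmethod
    def get_quant_method(self, layer: object) -> Optional[QuantMethod]:
        """Return a quant method, or None if `layer` is unsupported."""
        raise NotImplementedError


class MyConfig(Config):
    def get_quant_method(self, layer: object) -> Optional[QuantMethod]:
        # Declaring the same Optional return type as the base class keeps
        # this override compatible; unsupported layers yield None.
        if isinstance(layer, int):  # stand-in for isinstance(layer, LinearBase)
            return QuantMethod()
        return None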
@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
                                torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
-            idx.device)
+        self.long_short_cos_sin_cache: torch.Tensor = (
+            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
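Here the attribute is re-assigned with an explicit `: torch.Tensor` annotation so mypy knows its precise type at this point (the result of `.to(...)`), without changing runtime behaviour. This is useful when an attribute is created in a way mypy cannot see precisely, for example via `register_buffer`. A minimal sketch under that assumption:

import torch
import torch.nn as nn


class Cache(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("table", torch.zeros(4))

    def move(self, device: torch.device) -> None:
        # register_buffer leaves the attribute's static type fuzzy for mypy,
        # so the explicit annotation pins it to torch.Tensor here; at runtime
        # this is an ordinary assignment.
        self.table: torch.Tensor = self.table.to(device)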
@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceGroupOutput, SequenceOutput)

+# (num_token_ids, num_parent_ids) per sequence group.
+SampleResultType = List[Tuple[List[int], List[int]]]
+

 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
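The new SampleResultType alias replaces the List[Tuple[List[int], List[int]]] spelled out repeatedly across the sampler's signatures; mypy treats the alias and the expanded form as the same type. A tiny illustration:

from typing import List, Tuple

# (token_ids, parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]


def pick_first(results: SampleResultType) -> Tuple[List[int], List[int]]:
    # The alias is purely a naming convenience; this signature is identical
    # to one written with the full nested type.
    return results[0]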
@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
             have not been generated yet
     """
     # list of indices in logits that will be set to -inf
-    logits_to_penalize = []
+    logits_to_penalize: List[Tuple[int, int]] = []
     logits_applied = 0
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
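An empty literal gives mypy nothing to infer from, so the accumulator is annotated up front; the same fix recurs for the `results` lists and the empty-logprob locals later in this diff. In isolation:

from typing import List, Tuple

# Without the annotation mypy reports "Need type annotation" for the empty
# list; with it, every append is checked against Tuple[int, int].
pairs: List[Tuple[int, int]] = []
pairs.append((3, 7))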
@@ -269,7 +272,7 @@ def _apply_min_p(
 def _greedy_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run greedy sampling on a given samples.

     Args:
@@ -284,7 +287,7 @@ def _greedy_sample(
     """
     samples = samples.tolist()
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))
@@ -304,7 +307,7 @@ def _random_sample(
 def _random_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     random_samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run random sampling on a given samples.

     Args:
@@ -320,7 +323,7 @@ def _random_sample(
     # Find the maximum best_of value of the prompt phase requests.
     random_samples = random_samples.cpu()
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))
@@ -348,7 +351,7 @@ def _random_sample(
 def _beam_search_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     logprobs: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run beam sampling on a given samples.

     Args:
@@ -370,7 +373,7 @@ def _beam_search_sample(
     # NOTE: Beam search is not vectorized, so its speed can be slower than
     # other sampling methods.
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))
@@ -391,16 +394,16 @@ def _beam_search_sample(
             next_token_ids = next_token_ids.tolist()
         else:
             # Generation phase.
-            cumulative_logprobs = [
+            cumulative_logprobs: List[int] = [
                 seq_group.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids
             ]
-            cumulative_logprobs = torch.tensor(
+            cumulative_logprobs_tensor = torch.tensor(
                 cumulative_logprobs,
                 dtype=torch.float,
                 device=seq_group_logprobs.device)
             seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs.unsqueeze(dim=1))
+                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
             _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                      2 * beam_width)
             topk_ids = topk_ids.tolist()
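Reusing `cumulative_logprobs` first as a Python list and then as a tensor trips mypy's redefinition check, so the tensor gets its own name. The general pattern, in a minimal form (assuming torch is available):

import torch

scores = [0.5, 1.25, 2.0]  # inferred as List[float]
# Re-assigning a Tensor to `scores` would change its type and fail mypy,
# so the converted value gets a new name instead.
scores_tensor = torch.tensor(scores, dtype=torch.float)
print(scores_tensor.shape)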
@@ -452,8 +455,10 @@ def _sample_with_torch(
     sampling_metadata: SamplingMetadata,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
+    categorized_seq_group_ids: Dict[SamplingType,
+                                    List[int]] = {t: []
+                                                  for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
         sampling_params = seq_group.sampling_params
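The comprehension's buckets start empty, so their value type is again invisible to mypy, hence the Dict[SamplingType, List[int]] annotation. A compact stand-alone version with a placeholder enum:

from enum import Enum
from typing import Dict, List


class SamplingKind(Enum):
    GREEDY = 0
    RANDOM = 1


# Each bucket starts empty, so the value type must be spelled out for mypy.
buckets: Dict[SamplingKind, List[int]] = {k: [] for k in SamplingKind}
buckets[SamplingKind.GREEDY].append(42)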
@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> SampleResultType:
+    categorized_seq_group_ids: Dict[SamplingType,
+                                    List[int]] = {t: []
+                                                  for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
         sampling_params = seq_group.sampling_params
@@ -632,7 +639,7 @@ def _sample(
     probs: torch.Tensor, logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
         probs: (num_query_tokens_in_batch, num_vocab)
@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-    sample_results: List[Tuple[List[int], List[int]]],
+    sample_results: SampleResultType,
 ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
     """Return sample lobprobs and prompt logprobs.

@@ -751,8 +758,8 @@ def _get_logprobs(
     assert len(next_token_ids) == len(query_indices)

     if len(query_indices) == 0:
-        empty_sampled_logprob = []
-        empty_prompt_logprob = None
+        empty_sampled_logprob: SampleLogprobs = []
+        empty_prompt_logprob: Optional[PromptLogprobs] = None
         return [empty_prompt_logprob], [empty_sampled_logprob]

     query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,


 def _build_sampler_output(
-    sample_results: List[Tuple[List[int], List[int]]],
+    sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
@@ -1009,7 +1016,7 @@ def _build_sampler_output(
     )


-def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
+def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
     """Get a list of next prompt tokens to compute logprob from a
     given sequence group.

@@ -64,7 +64,7 @@ class TensorizerConfig:
             "s3_secret_access_key": self.s3_secret_access_key,
             "s3_endpoint": self.s3_endpoint,
         }
-        return TensorizerArgs(**tensorizer_args)
+        return TensorizerArgs(**tensorizer_args)  # type: ignore

     def verify_with_parallel_config(
         self,
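The `# type: ignore` silences mypy for the `**tensorizer_args` expansion, whose per-key types mypy cannot match against the constructor's parameters; it changes nothing at runtime. A generic illustration of when such an escape hatch is used, with hypothetical names:

from dataclasses import dataclass
from typing import Dict


@dataclass
class Args:
    path: str
    verbose: bool = False


def build(raw: Dict[str, object]) -> Args:
    # The dict's declared value type (object) cannot be matched against the
    # dataclass fields, so mypy flags the call; the ignore acknowledges the
    # deliberately unchecked unpacking.
    return Args(**raw)  # type: ignore


print(build({"path": "/tmp/x", "verbose": True}))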
@@ -270,8 +270,10 @@ class TensorizerAgent:
         self.model = self._init_model()

     def _init_model(self):
+        assert self.tensorizer_config.hf_config is not None
         model_args = self.tensorizer_config.hf_config
         model_args.torch_dtype = self.tensorizer_config.dtype
+        assert self.tensorizer_config.model_class is not None
         with no_init_or_tensor():
             return self.tensorizer_config.model_class(
                 config=model_args,