[mypy][5/N] Support all typing on model executor (#4427)

SangBin Cho 2024-04-29 11:01:26 +09:00 committed by GitHub
parent 03dd7d52bf
commit df29793dc7
10 changed files with 61 additions and 34 deletions

View File

@@ -43,8 +43,8 @@ jobs:
 mypy vllm/worker --config-file pyproject.toml
 mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/lora --config-file pyproject.toml
+mypy vllm/model_executor --config-file pyproject.toml
 # TODO(sang): Fix nested dir
-mypy vllm/model_executor/*.py --config-file pyproject.toml
 mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml

View File

@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/engine --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
 mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/model_executor/*.py --config-file pyproject.toml
+mypy vllm/model_executor --config-file pyproject.toml
 mypy vllm/lora --config-file pyproject.toml

View File

@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
         return schema
     if isinstance(schema, BaseModel):
         return schema.model_json_schema()
+    raise AssertionError(f"Unsupported schema type {schema}")


 @lru_cache
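The added raise closes the isinstance() chain so the helper can never fall off the end and implicitly return None. A minimal self-contained sketch of the same pattern, with hypothetical names rather than the real guided-decoding helper:

import json
from typing import Union


def normalize_schema(schema: Union[str, dict]) -> dict:
    """Toy version of the same control flow."""
    if isinstance(schema, str):
        return json.loads(schema)
    if isinstance(schema, dict):
        return schema
    # The trailing raise makes the unsupported case explicit and keeps every
    # code path either returning a dict or raising.
    raise AssertionError(f"Unsupported schema type {schema}")


print(normalize_schema('{"type": "object"}'))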

View File

@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
         if quant_config is None:
-            self.quant_method = UnquantizedLinearMethod()
+            self.quant_method: Optional[
+                QuantizeMethodBase] = UnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self)

@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self, self.input_size,
                                          [self.output_size], self.input_size,
                                          self.output_size, self.params_dtype)

@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias

@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
         self.output_size_per_partition = divide(output_size, tp_size)
         if output_sizes is None:
             output_sizes = [output_size]
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self,
                                          self.input_size,
                                          [x // tp_size for x in output_sizes],

@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.

@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self,
                                          self.input_size_per_partition,
                                          [self.output_size],

@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
             input_parallel = splitted_input[tp_rank].contiguous()

         # Matrix multiply.
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
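All of the asserts above follow one pattern: the attribute is declared Optional once in the base class, and each call site narrows it back to a concrete type before use. A self-contained sketch with toy classes (not vLLM's API), assuming that reading of the hunks:

from typing import Optional


class QuantMethod:
    """Stand-in for QuantizeMethodBase."""

    def apply(self, x: float) -> float:
        return x * 2.0


class Layer:
    def __init__(self, use_quant: bool) -> None:
        # Declared Optional up front: a config may decline to provide one.
        self.quant_method: Optional[QuantMethod] = (QuantMethod()
                                                    if use_quant else None)

    def forward(self, x: float) -> float:
        # The assert documents the invariant at runtime and narrows
        # Optional[QuantMethod] to QuantMethod for the type checker.
        assert self.quant_method is not None
        return self.quant_method.apply(x)


print(Layer(use_quant=True).forward(3.0))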

View File

@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Dict, Type

 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig

@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

-QUANTIZATION_METHODS = {
+QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,
     "fp8": Fp8Config,

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import torch
 from torch import nn

@@ -76,8 +76,16 @@ class QuantizationConfig(ABC):
                          "quantization config.")

     @abstractmethod
-    def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
-        """Get the quantize method to use for the quantized layer."""
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+        """Get the quantize method to use for the quantized layer.
+
+        Args:
+            layer: The layer for the quant method.
+        Returns:
+            The quantize method. None if the given layer doesn't support quant
+            method.
+        """
         raise NotImplementedError

     @abstractmethod
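The signature change turns "no quant method for this layer" into a typed outcome rather than an implicit None. A sketch of a concrete override honoring that contract, using simplified stand-in classes rather than the real ones:

from typing import Optional


class QuantizeMethodBase:
    pass


class LinearBase:
    pass


class FakeLinearMethod(QuantizeMethodBase):
    def __init__(self, config: "FakeQuantConfig") -> None:
        self.config = config


class FakeQuantConfig:
    def get_quant_method(self,
                         layer: object) -> Optional[QuantizeMethodBase]:
        # Supported layer type: hand back a method object.
        if isinstance(layer, LinearBase):
            return FakeLinearMethod(self)
        # Anything else: this scheme does not quantize it.
        return None


config = FakeQuantConfig()
print(config.get_quant_method(LinearBase()))
print(config.get_quant_method(object()))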

View File

@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig):
         return cls(weight_bits)

     def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return SqueezeLLMLinearMethod(self)
-        return
+        return None

     def get_scaled_act_names(self) -> List[str]:
         return []

View File

@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
                               torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
-            idx.device)
+        self.long_short_cos_sin_cache: torch.Tensor = (
+            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
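The annotation pins the attribute's type at the point where it is rebound. A hedged guess at the motivation, sketched with toy names: attributes created through register_buffer are not plain typed attributes, so re-annotating at the reassignment tells the checker the value is a Tensor from then on.

import torch


class Cache(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("cos_sin_cache", torch.zeros(4, 2))

    def move(self, device: torch.device) -> torch.Tensor:
        # Explicit annotation: mypy now treats the attribute as a Tensor.
        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(device)
        return self.cos_sin_cache


print(Cache().move(torch.device("cpu")).shape)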

View File

@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceGroupOutput, SequenceOutput)

+# (num_token_ids, num_parent_ids) per sequence group.
+SampleResultType = List[Tuple[List[int], List[int]]]
+

 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
         have not been generated yet
     """
     # list of indices in logits that will be set to -inf
-    logits_to_penalize = []
+    logits_to_penalize: List[Tuple[int, int]] = []
     logits_applied = 0
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
@@ -269,7 +272,7 @@ def _apply_min_p(
 def _greedy_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run greedy sampling on a given samples.

     Args:

@@ -284,7 +287,7 @@ def _greedy_sample(
     """
     samples = samples.tolist()
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))

@@ -304,7 +307,7 @@
 def _random_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     random_samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run random sampling on a given samples.

     Args:

@@ -320,7 +323,7 @@
     # Find the maximum best_of value of the prompt phase requests.
     random_samples = random_samples.cpu()
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))

@@ -348,7 +351,7 @@
 def _beam_search_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     logprobs: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run beam sampling on a given samples.

     Args:

@@ -370,7 +373,7 @@
     # NOTE: Beam search is not vectorized, so its speed can be slower than
     # other sampling methods.
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))

@@ -391,16 +394,16 @@
             next_token_ids = next_token_ids.tolist()
         else:
             # Generation phase.
-            cumulative_logprobs = [
+            cumulative_logprobs: List[int] = [
                 seq_group.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids
             ]
-            cumulative_logprobs = torch.tensor(
+            cumulative_logprobs_tensor = torch.tensor(
                 cumulative_logprobs,
                 dtype=torch.float,
                 device=seq_group_logprobs.device)
             seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs.unsqueeze(dim=1))
+                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
             _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                      2 * beam_width)
             topk_ids = topk_ids.tolist()
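The beam-search hunk gives the tensor its own name instead of reusing the list's. A sketch of the general issue with made-up values: reassigning one variable from a list to a Tensor leaves the checker with an unstable type, while two names keep both precise.

from typing import List

import torch

cumulative_logprobs: List[float] = [-0.5, -1.2]
# A separate name for the tensor built from the list keeps both types exact.
cumulative_logprobs_tensor = torch.tensor(cumulative_logprobs,
                                          dtype=torch.float)
print(cumulative_logprobs_tensor + 1.0)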
@@ -452,8 +455,10 @@ def _sample_with_torch(
     sampling_metadata: SamplingMetadata,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
+    categorized_seq_group_ids: Dict[SamplingType,
+                                    List[int]] = {t: []
+                                                  for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
         sampling_params = seq_group.sampling_params
@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> SampleResultType:
+    categorized_seq_group_ids: Dict[SamplingType,
+                                    List[int]] = {t: []
+                                                  for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
         sampling_params = seq_group.sampling_params
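Both _sample_with_torch and _sample_with_triton_kernel annotate the same bookkeeping dict. A self-contained sketch of the shape, with a simplified enum: the values start out as empty lists, so the element type is stated up front rather than left for inference.

from enum import Enum
from typing import Dict, List


class SamplingType(Enum):
    GREEDY = 0
    RANDOM = 1


# Explicit annotation for a dict whose values begin as empty lists.
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
    t: []
    for t in SamplingType
}
categorized_seq_group_ids[SamplingType.GREEDY].append(3)
print(categorized_seq_group_ids)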
@@ -632,7 +639,7 @@ def _sample(
     probs: torch.Tensor, logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
         probs: (num_query_tokens_in_batch, num_vocab)

@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-    sample_results: List[Tuple[List[int], List[int]]],
+    sample_results: SampleResultType,
 ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
     """Return sample lobprobs and prompt logprobs.

@@ -751,8 +758,8 @@ def _get_logprobs(
     assert len(next_token_ids) == len(query_indices)

     if len(query_indices) == 0:
-        empty_sampled_logprob = []
-        empty_prompt_logprob = None
+        empty_sampled_logprob: SampleLogprobs = []
+        empty_prompt_logprob: Optional[PromptLogprobs] = None
         return [empty_prompt_logprob], [empty_sampled_logprob]

     query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)

@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 def _build_sampler_output(
-    sample_results: List[Tuple[List[int], List[int]]],
+    sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],

@@ -1009,7 +1016,7 @@ def _build_sampler_output(
     )


-def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
+def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
     """Get a list of next prompt tokens to compute logprob from a
     given sequence group.

View File

@@ -64,7 +64,7 @@ class TensorizerConfig:
             "s3_secret_access_key": self.s3_secret_access_key,
             "s3_endpoint": self.s3_endpoint,
         }
-        return TensorizerArgs(**tensorizer_args)
+        return TensorizerArgs(**tensorizer_args)  # type: ignore

     def verify_with_parallel_config(
         self,

@@ -270,8 +270,10 @@ class TensorizerAgent:
         self.model = self._init_model()

     def _init_model(self):
+        assert self.tensorizer_config.hf_config is not None
         model_args = self.tensorizer_config.hf_config
         model_args.torch_dtype = self.tensorizer_config.dtype
+        assert self.tensorizer_config.model_class is not None
         with no_init_or_tensor():
             return self.tensorizer_config.model_class(
                 config=model_args,
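The asserts make the "these fields are filled in before _init_model runs" invariant explicit, while the type: ignore above stays as the fallback for a call the checker presumably cannot verify from the expanded dict. A sketch of the assert half with a made-up config class, not the real TensorizerConfig:

from dataclasses import dataclass
from typing import Optional, Type


@dataclass
class LoaderConfig:
    # Legitimately None until the loader fills it in.
    model_class: Optional[Type[object]] = None


def init_model(config: LoaderConfig) -> object:
    # Narrow Optional[Type[object]] to Type[object] before the call below.
    assert config.model_class is not None
    return config.model_class()


print(init_model(LoaderConfig(model_class=dict)))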