[mypy][5/N] Support all typing on model executor (#4427)
This commit is contained in:
parent 03dd7d52bf
commit df29793dc7

.github/workflows/mypy.yaml (vendored, 2 changes)
@@ -43,8 +43,8 @@ jobs:
         mypy vllm/worker --config-file pyproject.toml
         mypy vllm/spec_decode --config-file pyproject.toml
         mypy vllm/lora --config-file pyproject.toml
+        mypy vllm/model_executor --config-file pyproject.toml
 
         # TODO(sang): Fix nested dir
-        mypy vllm/model_executor/*.py --config-file pyproject.toml
         mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
 
@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/engine --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
 mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/model_executor/*.py --config-file pyproject.toml
+mypy vllm/model_executor --config-file pyproject.toml
 mypy vllm/lora --config-file pyproject.toml
 
 
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
         return schema
     if isinstance(schema, BaseModel):
         return schema.model_json_schema()
+    raise AssertionError(f"Unsupported schema type {schema}")
 
 
 @lru_cache
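The added `raise` gives the dispatch over the `Union` an explicit failure path. Without it, an input matching none of the `isinstance` checks would fall off the end of the function and implicitly return `None`, which conflicts with the declared `-> dict`. A minimal standalone sketch of the pattern (the `describe` function is invented for illustration, not vLLM code):

```python
from typing import Union


def describe(value: Union[str, dict]) -> dict:
    """Normalize a value to a dict; every path now returns or raises."""
    if isinstance(value, str):
        return {"text": value}
    if isinstance(value, dict):
        return value
    # Without this raise, the implicit `return None` at the end of the
    # function body would conflict with the declared `-> dict` return type.
    raise AssertionError(f"Unsupported type {value!r}")


print(describe("hi"))       # {'text': 'hi'}
print(describe({"a": 1}))   # {'a': 1}
```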
@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
         if quant_config is None:
-            self.quant_method = UnquantizedLinearMethod()
+            self.quant_method: Optional[
+                QuantizeMethodBase] = UnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self)
 
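mypy infers an attribute's type from its first assignment, so when one branch assigns a concrete `UnquantizedLinearMethod` and the other assigns the `Optional[QuantizeMethodBase]` now returned by `get_quant_method`, the inferred types clash. Annotating the first assignment with the common supertype resolves it. A self-contained sketch (the class names below are invented for illustration):

```python
from typing import Optional


class Method:
    """Common base type for the attribute."""


class DefaultMethod(Method):
    pass


def lookup_method(name: str) -> Optional[Method]:
    """May return None when nothing is registered for `name`."""
    return Method() if name == "known" else None


class Layer:
    def __init__(self, name: Optional[str]) -> None:
        if name is None:
            # Annotating the first assignment pins the attribute's type.
            # Without it, mypy infers `DefaultMethod` here and then rejects
            # the `Optional[Method]` assignment in the else branch.
            self.method: Optional[Method] = DefaultMethod()
        else:
            self.method = lookup_method(name)
```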
@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
 
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self, self.input_size,
                                          [self.output_size], self.input_size,
                                          self.output_size, self.params_dtype)
@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
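The repeated `assert self.quant_method is not None` lines throughout the linear layers are mypy type narrowing: after the assert, the attribute's type narrows from `Optional[QuantizeMethodBase]` to `QuantizeMethodBase`, so the `.apply(...)` and `.create_weights(...)` calls type-check. A small illustration with hypothetical names:

```python
from typing import Optional


class Engine:
    def start(self) -> str:
        return "started"


class Car:
    def __init__(self, engine: Optional[Engine]) -> None:
        self.engine: Optional[Engine] = engine

    def drive(self) -> str:
        # Narrows Optional[Engine] to Engine for mypy; at runtime it is a
        # cheap invariant check that fails loudly if the engine is missing.
        assert self.engine is not None
        return self.engine.start()


print(Car(Engine()).drive())  # started
```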
@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
         self.output_size_per_partition = divide(output_size, tp_size)
         if output_sizes is None:
             output_sizes = [output_size]
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self,
                                          self.input_size,
                                          [x // tp_size for x in output_sizes],
@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
+        # All the linear layer supports quant method.
+        assert self.quant_method is not None
         self.quant_method.create_weights(self,
                                          self.input_size_per_partition,
                                          [self.output_size],
@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
             input_parallel = splitted_input[tp_rank].contiguous()
 
         # Matrix multiply.
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Dict, Type
 
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 
-QUANTIZATION_METHODS = {
+QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "fp8": Fp8Config,
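Annotating the registry as `Dict[str, Type[QuantizationConfig]]` declares that every value is a subclass of the common config base, so code that looks up a name and calls classmethods on the result type-checks uniformly. A toy registry in the same spirit (the `Config` classes below are made up for the example):

```python
from typing import Dict, Type


class Config:
    @classmethod
    def get_name(cls) -> str:
        return cls.__name__.lower()


class AConfig(Config):
    pass


class BConfig(Config):
    pass


# The explicit annotation documents the contract and keeps the value type
# stable as entries are added, instead of relying on mypy's inferred join
# of the concrete classes.
METHODS: Dict[str, Type[Config]] = {
    "a": AConfig,
    "b": BConfig,
}


def get_config_class(name: str) -> Type[Config]:
    if name not in METHODS:
        raise ValueError(f"Unknown method: {name}")
    return METHODS[name]


print(get_config_class("a").get_name())  # aconfig
```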
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch import nn
@@ -76,8 +76,16 @@ class QuantizationConfig(ABC):
                          "quantization config.")
 
     @abstractmethod
-    def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
-        """Get the quantize method to use for the quantized layer."""
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+        """Get the quantize method to use for the quantized layer.
+
+        Args:
+            layer: The layer for the quant method.
+        Returns:
+            The quantize method. None if the given layer doesn't support quant
+            method.
+        """
         raise NotImplementedError
 
     @abstractmethod
@@ -52,11 +52,10 @@ class SqueezeLLMConfig(QuantizationConfig):
         return cls(weight_bits)
 
     def get_quant_method(
-            self,
-            layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return SqueezeLLMLinearMethod(self)
-        return
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
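Two things happen in this hunk: the override's return annotation is widened to match the base class's `Optional[QuantizeMethodBase]` signature, and the bare `return` becomes an explicit `return None`. A hedged sketch of why the widened override matters (names invented for illustration):

```python
from abc import ABC, abstractmethod
from typing import Optional


class Method:
    pass


class Base(ABC):
    @abstractmethod
    def get_method(self, layer: object) -> Optional[Method]:
        """Return the method for `layer`, or None if unsupported."""
        raise NotImplementedError


class Impl(Base):
    # The override keeps the base signature, so callers written against
    # `Base` type-check the same way for every concrete subclass.
    def get_method(self, layer: object) -> Optional[Method]:
        if isinstance(layer, str):
            return Method()
        return None  # explicit None reads clearer than a bare `return`


print(Impl().get_method("linear"))
print(Impl().get_method(42))  # None
```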
@@ -431,8 +431,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
                    torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
-            idx.device)
+        self.long_short_cos_sin_cache: torch.Tensor = (
+            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
 
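Here the attribute is created dynamically elsewhere (for example through a buffer-registration mechanism), so mypy has no static type for it at this point; annotating the reassignment as `torch.Tensor` gives it one, letting the `torch.index_select` call below type-check. A torch-free sketch of the same idea, assuming a dynamically-set attribute (all names hypothetical):

```python
class Cache:
    def to(self, device: str) -> "Cache":
        return self


class Model:
    def __init__(self) -> None:
        # Set dynamically (as PyTorch's register_buffer does), so mypy has
        # no static type for the attribute yet.
        setattr(self, "cache", Cache())

    def forward(self) -> "Cache":
        # First static declaration of the attribute: the annotation tells
        # mypy it is a Cache, so the .to(...) call type-checks.
        self.cache: Cache = getattr(self, "cache").to("cpu")
        return self.cache


print(Model().forward())
```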
@@ -13,6 +13,9 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceGroupOutput, SequenceOutput)
 
+# (num_token_ids, num_parent_ids) per sequence group.
+SampleResultType = List[Tuple[List[int], List[int]]]
+
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
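`SampleResultType` is a plain type alias: every signature below that previously spelled out `List[Tuple[List[int], List[int]]]` now uses the alias, so the shape of a sample result is documented once. The alias is fully interchangeable with the spelled-out type:

```python
from typing import List, Tuple

# (token_ids, parent_ids) per sequence group, as in the hunk above.
SampleResultType = List[Tuple[List[int], List[int]]]


def make_result() -> SampleResultType:
    return [([7, 9], [0, 0]), ([], [])]


# Annotating with the expanded form is equally valid; mypy treats the two
# spellings as the same type.
result: List[Tuple[List[int], List[int]]] = make_result()
print(result)
```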
@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
         have not been generated yet
     """
     # list of indices in logits that will be set to -inf
-    logits_to_penalize = []
+    logits_to_penalize: List[Tuple[int, int]] = []
     logits_applied = 0
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
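An empty literal gives mypy nothing to infer an element type from, so strict settings require annotations like the one above. Beyond silencing the "need type annotation" error, the annotation also makes later appends checkable:

```python
from typing import List, Tuple

# `pairs = []` alone gives mypy no element type; under strict settings this
# is reported as "Need type annotation". The annotation fixes that and also
# catches wrong appends.
pairs: List[Tuple[int, int]] = []
pairs.append((0, 3))
# pairs.append("oops")  # mypy: incompatible type "str"; expected "Tuple[int, int]"
print(pairs)
```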
@@ -269,7 +272,7 @@ def _apply_min_p(
 def _greedy_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run greedy sampling on a given samples.
 
     Args:
@@ -284,7 +287,7 @@ def _greedy_sample(
     """
     samples = samples.tolist()
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))
@@ -304,7 +307,7 @@ def _greedy_sample(
 def _random_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     random_samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run random sampling on a given samples.
 
     Args:
@@ -320,7 +323,7 @@ def _random_sample(
     # Find the maximum best_of value of the prompt phase requests.
     random_samples = random_samples.cpu()
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))
@@ -348,7 +351,7 @@ def _random_sample(
 def _beam_search_sample(
     selected_seq_groups: List[SequenceGroupToSample],
     logprobs: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
     """Run beam sampling on a given samples.
 
     Args:
@@ -370,7 +373,7 @@ def _beam_search_sample(
     # NOTE: Beam search is not vectorized, so its speed can be slower than
     # other sampling methods.
     sample_idx = 0
-    results = []
+    results: SampleResultType = []
     for seq_group in selected_seq_groups:
         if not seq_group.do_sample:
             results.append(([], []))
@@ -391,16 +394,16 @@ def _beam_search_sample(
             next_token_ids = next_token_ids.tolist()
         else:
             # Generation phase.
-            cumulative_logprobs = [
+            cumulative_logprobs: List[int] = [
                 seq_group.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids
             ]
-            cumulative_logprobs = torch.tensor(
+            cumulative_logprobs_tensor = torch.tensor(
                 cumulative_logprobs,
                 dtype=torch.float,
                 device=seq_group_logprobs.device)
             seq_group_logprobs = (seq_group_logprobs +
-                                  cumulative_logprobs.unsqueeze(dim=1))
+                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
             _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                      2 * beam_width)
             topk_ids = topk_ids.tolist()
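mypy does not let a variable change type mid-function, so re-binding `cumulative_logprobs` from a Python list to a `torch.Tensor` is an "incompatible types in assignment" error; the fix is simply a second name. A torch-free sketch of the same rename:

```python
from typing import List

values: List[float] = [0.1, 0.25, 0.4]

# Re-binding `values` to a different type (e.g. a joined string) would fail
# under mypy with "Incompatible types in assignment". A new name keeps both
# bindings well-typed.
values_repr: str = ", ".join(str(v) for v in values)
print(values_repr)
```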
@@ -452,8 +455,10 @@ def _sample_with_torch(
     sampling_metadata: SamplingMetadata,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
+    categorized_seq_group_ids: Dict[SamplingType,
+                                    List[int]] = {t: []
+                                                  for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
         sampling_params = seq_group.sampling_params
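The dict comprehension `{t: [] for t in SamplingType}` has empty-list values, so mypy cannot infer the value type; the `Dict[SamplingType, List[int]]` annotation supplies it (the awkward line-wrapping is just the project's auto-formatter). A compact illustration with a stand-in enum:

```python
import enum
from typing import Dict, List


class SamplingKind(enum.Enum):  # stand-in for vLLM's SamplingType
    GREEDY = 0
    RANDOM = 1


# The empty lists give mypy nothing to infer, so the value type must be
# declared; afterwards, appends are checked against List[int].
groups: Dict[SamplingKind, List[int]] = {t: [] for t in SamplingKind}
groups[SamplingKind.GREEDY].append(3)
print(groups)
```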
@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> SampleResultType:
+    categorized_seq_group_ids: Dict[SamplingType,
+                                    List[int]] = {t: []
+                                                  for t in SamplingType}
     categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
         sampling_params = seq_group.sampling_params
@@ -632,7 +639,7 @@ def _sample(
     probs: torch.Tensor, logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
         probs: (num_query_tokens_in_batch, num_vocab)
@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-    sample_results: List[Tuple[List[int], List[int]]],
+    sample_results: SampleResultType,
 ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
     """Return sample lobprobs and prompt logprobs.
 
@@ -751,8 +758,8 @@ def _get_logprobs(
     assert len(next_token_ids) == len(query_indices)
 
     if len(query_indices) == 0:
-        empty_sampled_logprob = []
-        empty_prompt_logprob = None
+        empty_sampled_logprob: SampleLogprobs = []
+        empty_prompt_logprob: Optional[PromptLogprobs] = None
         return [empty_prompt_logprob], [empty_sampled_logprob]
 
     query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
 
 
 def _build_sampler_output(
-    sample_results: List[Tuple[List[int], List[int]]],
+    sample_results: SampleResultType,
     sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
@@ -1009,7 +1016,7 @@ def _build_sampler_output(
     )
 
 
-def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
+def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
     """Get a list of next prompt tokens to compute logprob from a
     given sequence group.
 
@@ -64,7 +64,7 @@ class TensorizerConfig:
             "s3_secret_access_key": self.s3_secret_access_key,
             "s3_endpoint": self.s3_endpoint,
         }
-        return TensorizerArgs(**tensorizer_args)
+        return TensorizerArgs(**tensorizer_args)  # type: ignore
 
     def verify_with_parallel_config(
         self,
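`TensorizerArgs(**tensorizer_args)` unpacks a dict whose value type mypy infers as a broad join of the field types, which it cannot match against the constructor's specific keyword parameters; the commit silences this one call with a trailing `# type: ignore` rather than restructuring it. A sketch of the trade-off (the `Args` class and `build` function are hypothetical):

```python
from typing import Dict


class Args:
    def __init__(self, host: str, port: int) -> None:
        self.host, self.port = host, port


def build() -> Args:
    # The dict's value type is inferred as a join (here `object`), which
    # mypy cannot check against the keyword parameter types, so the call
    # gets a targeted ignore; type checking stays on for the rest of
    # the file.
    config: Dict[str, object] = {"host": "localhost", "port": 8000}
    return Args(**config)  # type: ignore


print(build().host)
```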
@@ -270,8 +270,10 @@ class TensorizerAgent:
         self.model = self._init_model()
 
     def _init_model(self):
+        assert self.tensorizer_config.hf_config is not None
         model_args = self.tensorizer_config.hf_config
         model_args.torch_dtype = self.tensorizer_config.dtype
+        assert self.tensorizer_config.model_class is not None
         with no_init_or_tensor():
             return self.tensorizer_config.model_class(
                 config=model_args,
|||||||
Loading…
x
Reference in New Issue
Block a user