mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 16:58:03 +08:00
Update some more deprecated type hinting (#17998)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
acee8f48aa
commit
9d7ea9dbbf
@ -79,7 +79,9 @@ exclude = [
|
|||||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/lora/**/*.py" = ["UP006", "UP035"]
|
"vllm/lora/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/model_executor/**/*.py" = ["UP006", "UP035"]
|
"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
|
||||||
|
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
|
||||||
|
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
|
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
|
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
|
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict, Type
|
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
@ -138,7 +136,7 @@ class CustomOp(nn.Module):
|
|||||||
# Examples:
|
# Examples:
|
||||||
# - MyOp.enabled()
|
# - MyOp.enabled()
|
||||||
# - op_registry["my_op"].enabled()
|
# - op_registry["my_op"].enabled()
|
||||||
op_registry: Dict[str, Type['CustomOp']] = {}
|
op_registry: dict[str, type['CustomOp']] = {}
|
||||||
|
|
||||||
# Decorator to register custom ops.
|
# Decorator to register custom ops.
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import os
|
import os
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
import llguidance
|
import llguidance
|
||||||
import llguidance.hf
|
import llguidance.hf
|
||||||
@ -62,7 +62,7 @@ class GuidanceLogitsProcessor:
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids: List[int],
|
input_ids: list[int],
|
||||||
scores: torch.Tensor,
|
scores: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# we initialize the guidance model here
|
# we initialize the guidance model here
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, TypedDict, Union
|
from typing import Optional, TypedDict, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
# These classes are deprecated, see SamplingParams
|
# These classes are deprecated, see SamplingParams
|
||||||
class LLMGuidedOptions(TypedDict, total=False):
|
class LLMGuidedOptions(TypedDict, total=False):
|
||||||
guided_json: Union[Dict, BaseModel, str]
|
guided_json: Union[dict, BaseModel, str]
|
||||||
guided_regex: str
|
guided_regex: str
|
||||||
guided_choice: List[str]
|
guided_choice: list[str]
|
||||||
guided_grammar: str
|
guided_grammar: str
|
||||||
guided_decoding_backend: str
|
guided_decoding_backend: str
|
||||||
guided_whitespace_pattern: str
|
guided_whitespace_pattern: str
|
||||||
@ -20,9 +20,9 @@ class LLMGuidedOptions(TypedDict, total=False):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class GuidedDecodingRequest:
|
class GuidedDecodingRequest:
|
||||||
"""One of the fields will be used to retrieve the logit processor."""
|
"""One of the fields will be used to retrieve the logit processor."""
|
||||||
guided_json: Optional[Union[Dict, BaseModel, str]] = None
|
guided_json: Optional[Union[dict, BaseModel, str]] = None
|
||||||
guided_regex: Optional[str] = None
|
guided_regex: Optional[str] = None
|
||||||
guided_choice: Optional[List[str]] = None
|
guided_choice: Optional[list[str]] = None
|
||||||
guided_grammar: Optional[str] = None
|
guided_grammar: Optional[str] = None
|
||||||
guided_decoding_backend: Optional[str] = None
|
guided_decoding_backend: Optional[str] = None
|
||||||
guided_whitespace_pattern: Optional[str] = None
|
guided_whitespace_pattern: Optional[str] = None
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import os
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from json import dumps as json_dumps
|
from json import dumps as json_dumps
|
||||||
from re import escape as regex_escape
|
from re import escape as regex_escape
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ def get_local_outlines_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
def _get_guide_and_mode(
|
def _get_guide_and_mode(
|
||||||
guided_params: GuidedDecodingParams
|
guided_params: GuidedDecodingParams
|
||||||
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
|
) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]:
|
||||||
if guided_params.json:
|
if guided_params.json:
|
||||||
if isinstance(guided_params.json, dict):
|
if isinstance(guided_params.json, dict):
|
||||||
# turn dict into hashable string
|
# turn dict into hashable string
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Callable, DefaultDict, Dict, List, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -53,10 +53,10 @@ class BaseLogitsProcessor:
|
|||||||
self._guide: Guide = guide
|
self._guide: Guide = guide
|
||||||
self._reasoner: Optional[ReasoningParser] = reasoner
|
self._reasoner: Optional[ReasoningParser] = reasoner
|
||||||
# CFGState is used for the FSM state for CFGGuide
|
# CFGState is used for the FSM state for CFGGuide
|
||||||
self._fsm_state: DefaultDict[int, Union[int,
|
self._fsm_state: defaultdict[int, Union[int,
|
||||||
CFGState]] = defaultdict(int)
|
CFGState]] = defaultdict(int)
|
||||||
|
|
||||||
def __call__(self, input_ids: List[int],
|
def __call__(self, input_ids: list[int],
|
||||||
scores: torch.Tensor) -> torch.Tensor:
|
scores: torch.Tensor) -> torch.Tensor:
|
||||||
"""Use the FSM to bias the logits before sampling the next token."""
|
"""Use the FSM to bias the logits before sampling the next token."""
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
|
|||||||
|
|
||||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||||
|
|
||||||
def __init__(self, schema: Union[str, Dict, BaseModel],
|
def __init__(self, schema: Union[str, dict, BaseModel],
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
whitespace_pattern: Union[str, None],
|
whitespace_pattern: Union[str, None],
|
||||||
reasoner: Optional[ReasoningParser]):
|
reasoner: Optional[ReasoningParser]):
|
||||||
@ -181,7 +181,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
|||||||
"""
|
"""
|
||||||
if isinstance(schema, type(BaseModel)):
|
if isinstance(schema, type(BaseModel)):
|
||||||
schema_str = json.dumps(schema.model_json_schema())
|
schema_str = json.dumps(schema.model_json_schema())
|
||||||
elif isinstance(schema, Dict):
|
elif isinstance(schema, dict):
|
||||||
schema_str = json.dumps(schema)
|
schema_str = json.dumps(schema)
|
||||||
elif isinstance(schema, str):
|
elif isinstance(schema, str):
|
||||||
schema_str = schema
|
schema_str = schema
|
||||||
@ -252,11 +252,11 @@ def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
|||||||
return string
|
return string
|
||||||
|
|
||||||
def change_decoder(
|
def change_decoder(
|
||||||
decoder: Callable[[List[int]],
|
decoder: Callable[[list[int]],
|
||||||
str]) -> Callable[[List[int]], List[str]]:
|
str]) -> Callable[[list[int]], list[str]]:
|
||||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||||
|
|
||||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
def new_decoder(inp_tokens: list[int]) -> list[str]:
|
||||||
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
|
||||||
and isinstance(inp_tokens[0], list)):
|
and isinstance(inp_tokens[0], list)):
|
||||||
inp_tokens = inp_tokens[0]
|
inp_tokens = inp_tokens[0]
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, List
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -273,7 +273,7 @@ class GrammarConfig:
|
|||||||
return re.sub(r'(["\\])', r'\\\1', s)
|
return re.sub(r'(["\\])', r'\\\1', s)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def choice_as_grammar(choice: List[str] | None) -> str:
|
def choice_as_grammar(choice: list[str] | None) -> str:
|
||||||
if choice is None:
|
if choice is None:
|
||||||
raise ValueError("Choice is not set")
|
raise ValueError("Choice is not set")
|
||||||
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
|
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -23,9 +23,9 @@ class PoolingMetadata:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
seq_groups: List[Tuple[List[int], PoolingParams]],
|
seq_groups: list[tuple[list[int], PoolingParams]],
|
||||||
seq_data: Dict[int, Any], # Specific data related to sequences
|
seq_data: dict[int, Any], # Specific data related to sequences
|
||||||
prompt_lens: List[int],
|
prompt_lens: list[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.seq_groups = seq_groups
|
self.seq_groups = seq_groups
|
||||||
self.seq_data = seq_data
|
self.seq_data = seq_data
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from array import array
|
from array import array
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -25,10 +25,10 @@ class SequenceGroupToSample:
|
|||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
# Sequence ids for the sequence group in a previous step.
|
# Sequence ids for the sequence group in a previous step.
|
||||||
seq_ids: List[int]
|
seq_ids: list[int]
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
# seq_id -> sequence data.
|
# seq_id -> sequence data.
|
||||||
seq_data: Dict[int, SequenceData]
|
seq_data: dict[int, SequenceData]
|
||||||
# The length of the sequence (all tokens seen in the past + new token to
|
# The length of the sequence (all tokens seen in the past + new token to
|
||||||
# compute attention) of the sequence group. None if it is in a decode
|
# compute attention) of the sequence group. None if it is in a decode
|
||||||
# stage.
|
# stage.
|
||||||
@ -44,9 +44,9 @@ class SequenceGroupToSample:
|
|||||||
is_prompt: bool
|
is_prompt: bool
|
||||||
# Query token indices from logits. to compute prompt logprob. Empty if
|
# Query token indices from logits. to compute prompt logprob. Empty if
|
||||||
# prompt logprob is not required.
|
# prompt logprob is not required.
|
||||||
prompt_logprob_indices: List[int]
|
prompt_logprob_indices: list[int]
|
||||||
# Sample token indices from logits. Empty if sampling is not required.
|
# Sample token indices from logits. Empty if sampling is not required.
|
||||||
sample_indices: List[int]
|
sample_indices: list[int]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def do_sample(self):
|
def do_sample(self):
|
||||||
@ -78,7 +78,7 @@ class SamplingMetadataCache:
|
|||||||
"""Used to cache SamplingMetadata objects between scheduler iterations"""
|
"""Used to cache SamplingMetadata objects between scheduler iterations"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
|
self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}
|
||||||
|
|
||||||
def get_cached_seq_group_to_sample(self, num_seqs):
|
def get_cached_seq_group_to_sample(self, num_seqs):
|
||||||
if num_seqs not in self._seq_group_to_sample_cache:
|
if num_seqs not in self._seq_group_to_sample_cache:
|
||||||
@ -130,9 +130,9 @@ class SamplingMetadata:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
seq_groups: List[SequenceGroupToSample],
|
seq_groups: list[SequenceGroupToSample],
|
||||||
selected_token_indices: torch.Tensor,
|
selected_token_indices: torch.Tensor,
|
||||||
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
|
categorized_sample_indices: dict[SamplingType, torch.Tensor],
|
||||||
num_prompts: int,
|
num_prompts: int,
|
||||||
skip_sampler_cpu_output: bool = False,
|
skip_sampler_cpu_output: bool = False,
|
||||||
reuse_sampling_tensors: bool = False,
|
reuse_sampling_tensors: bool = False,
|
||||||
@ -146,12 +146,12 @@ class SamplingMetadata:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare(
|
def prepare(
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: list[SequenceGroupMetadata],
|
||||||
seq_lens: List[int],
|
seq_lens: list[int],
|
||||||
query_lens: List[int],
|
query_lens: list[int],
|
||||||
device: str,
|
device: str,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
generators: Optional[Dict[str, torch.Generator]] = None,
|
generators: Optional[dict[str, torch.Generator]] = None,
|
||||||
cache: Optional[SamplingMetadataCache] = None,
|
cache: Optional[SamplingMetadataCache] = None,
|
||||||
) -> "SamplingMetadata":
|
) -> "SamplingMetadata":
|
||||||
(
|
(
|
||||||
@ -195,16 +195,16 @@ class SamplingMetadata:
|
|||||||
|
|
||||||
|
|
||||||
def _prepare_seq_groups(
|
def _prepare_seq_groups(
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: list[SequenceGroupMetadata],
|
||||||
seq_lens: List[int],
|
seq_lens: list[int],
|
||||||
query_lens: List[int],
|
query_lens: list[int],
|
||||||
device: str,
|
device: str,
|
||||||
generators: Optional[Dict[str, torch.Generator]] = None,
|
generators: Optional[dict[str, torch.Generator]] = None,
|
||||||
cache: Optional[SamplingMetadataCache] = None,
|
cache: Optional[SamplingMetadataCache] = None,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
List[SequenceGroupToSample],
|
list[SequenceGroupToSample],
|
||||||
List[int],
|
list[int],
|
||||||
Dict[SamplingType, List[int]],
|
dict[SamplingType, list[int]],
|
||||||
int,
|
int,
|
||||||
]:
|
]:
|
||||||
"""Prepare sequence groups and indices for sampling.
|
"""Prepare sequence groups and indices for sampling.
|
||||||
@ -227,17 +227,17 @@ def _prepare_seq_groups(
|
|||||||
num_prompts: Total number of prompts from `seq_group_metadata_list`.
|
num_prompts: Total number of prompts from `seq_group_metadata_list`.
|
||||||
"""
|
"""
|
||||||
# Batched sequence groups for the current model forward stsep.
|
# Batched sequence groups for the current model forward stsep.
|
||||||
seq_groups: List[SequenceGroupToSample] = []
|
seq_groups: list[SequenceGroupToSample] = []
|
||||||
# A list of token indices to sample/compute logprob. It is used to
|
# A list of token indices to sample/compute logprob. It is used to
|
||||||
# prune the outcome logits from the model for the performance.
|
# prune the outcome logits from the model for the performance.
|
||||||
selected_token_indices: List[int] = []
|
selected_token_indices: list[int] = []
|
||||||
# Used for selected_token_indices.
|
# Used for selected_token_indices.
|
||||||
model_output_idx = 0
|
model_output_idx = 0
|
||||||
|
|
||||||
# Sampling type -> (
|
# Sampling type -> (
|
||||||
# indices to sample/prompt logprob within pruned output logits,
|
# indices to sample/prompt logprob within pruned output logits,
|
||||||
# indices to sample within pruned logits)
|
# indices to sample within pruned logits)
|
||||||
categorized_sample_indices: Dict[SamplingType, List[int]] = {
|
categorized_sample_indices: dict[SamplingType, list[int]] = {
|
||||||
t: []
|
t: []
|
||||||
for t in SamplingType
|
for t in SamplingType
|
||||||
}
|
}
|
||||||
@ -265,9 +265,9 @@ def _prepare_seq_groups(
|
|||||||
# If the current seq group is in decode stage, it is None.
|
# If the current seq group is in decode stage, it is None.
|
||||||
seq_len: Optional[int] = None
|
seq_len: Optional[int] = None
|
||||||
query_len: Optional[int] = None
|
query_len: Optional[int] = None
|
||||||
prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
|
prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
|
||||||
if cache is not None else [])
|
if cache is not None else [])
|
||||||
sample_indices: List[int] = (sample_obj.sample_indices
|
sample_indices: list[int] = (sample_obj.sample_indices
|
||||||
if cache is not None else [])
|
if cache is not None else [])
|
||||||
do_sample = seq_group_metadata.do_sample
|
do_sample = seq_group_metadata.do_sample
|
||||||
|
|
||||||
@ -389,16 +389,16 @@ class SamplingTensors:
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> Tuple["SamplingTensors", bool, bool, bool]:
|
) -> tuple["SamplingTensors", bool, bool, bool]:
|
||||||
prompt_tokens: List[array] = []
|
prompt_tokens: list[array] = []
|
||||||
output_tokens: List[array] = []
|
output_tokens: list[array] = []
|
||||||
top_ks: List[int] = []
|
top_ks: list[int] = []
|
||||||
temperatures: List[float] = []
|
temperatures: list[float] = []
|
||||||
top_ps: List[float] = []
|
top_ps: list[float] = []
|
||||||
min_ps: List[float] = []
|
min_ps: list[float] = []
|
||||||
presence_penalties: List[float] = []
|
presence_penalties: list[float] = []
|
||||||
frequency_penalties: List[float] = []
|
frequency_penalties: list[float] = []
|
||||||
repetition_penalties: List[float] = []
|
repetition_penalties: list[float] = []
|
||||||
do_penalties = False
|
do_penalties = False
|
||||||
do_top_p_top_k = False
|
do_top_p_top_k = False
|
||||||
do_min_p = False
|
do_min_p = False
|
||||||
@ -496,15 +496,15 @@ class SamplingTensors:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_lists(
|
def from_lists(
|
||||||
cls,
|
cls,
|
||||||
temperatures: List[float],
|
temperatures: list[float],
|
||||||
top_ps: List[float],
|
top_ps: list[float],
|
||||||
top_ks: List[int],
|
top_ks: list[int],
|
||||||
min_ps: List[float],
|
min_ps: list[float],
|
||||||
presence_penalties: List[float],
|
presence_penalties: list[float],
|
||||||
frequency_penalties: List[float],
|
frequency_penalties: list[float],
|
||||||
repetition_penalties: List[float],
|
repetition_penalties: list[float],
|
||||||
prompt_tokens: List[array],
|
prompt_tokens: list[array],
|
||||||
output_tokens: List[array],
|
output_tokens: list[array],
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Utils for model executor."""
|
"""Utils for model executor."""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ def set_random_seed(seed: int) -> None:
|
|||||||
|
|
||||||
def set_weight_attrs(
|
def set_weight_attrs(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_attrs: Optional[Dict[str, Any]],
|
weight_attrs: Optional[dict[str, Any]],
|
||||||
):
|
):
|
||||||
"""Set attributes on a weight tensor.
|
"""Set attributes on a weight tensor.
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user