mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 16:14:37 +08:00
hfrunner.classify should return list[list[float]] not list[str] (#29671)
Signed-off-by: Chukwuma Nwaugha <nwaughac@gmail.com>
This commit is contained in:
parent
f4341f45d3
commit
ad7f714d62
@ -459,14 +459,17 @@ class HfRunner:
|
||||
embeddings.append(embedding)
|
||||
return embeddings
|
||||
|
||||
def classify(self, prompts: list[str]) -> list[str]:
|
||||
def classify(self, prompts: list[str]) -> list[list[float]]:
|
||||
# output is final logits
|
||||
all_inputs = self.get_inputs(prompts)
|
||||
outputs = []
|
||||
outputs: list[list[float]] = []
|
||||
problem_type = getattr(self.config, "problem_type", "")
|
||||
|
||||
for inputs in all_inputs:
|
||||
output = self.model(**self.wrap_device(inputs))
|
||||
|
||||
assert isinstance(output.logits, torch.Tensor)
|
||||
|
||||
if problem_type == "regression":
|
||||
logits = output.logits[0].tolist()
|
||||
elif problem_type == "multi_label_classification":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user