【发布时间】:2020-08-25 12:00:53
【问题描述】:
我知道这似乎是一个常见问题,但我无法找到解决方案。我正在运行一个多标签分类模型并且遇到张量大小问题。
我的完整代码如下所示:
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch
# Instantiating tokenizer and model
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
# Instantiating quantized model
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
# Forming data tensors
input_ids = torch.tensor(tokenizer.encode(x_train[0], add_special_tokens=True)).unsqueeze(0)
labels = torch.tensor(Y[0]).unsqueeze(0)
# Train model
outputs = quantized_model(input_ids, labels=labels)
loss, logits = outputs[:2]
产生错误:
ValueError: Expected input batch_size (1) to match target batch_size (11)
Input_ids 看起来像:
tensor([[ 101, 789, 160, 1766, 1616, 1110, 170, 1205, 7727, 1113,
170, 2463, 1128, 1336, 1309, 1138, 112, 119, 11882, 11545,
119, 108, 15710, 108, 3645, 108, 3994, 102]])
有形状:
torch.Size([1, 28])
标签看起来像:
tensor([[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]])
有形状:
torch.Size([1, 11])
input_ids 的大小会随着要编码的字符串大小的变化而变化。
我还注意到,当输入 5 个 Y 值以生成 5 个标签时,会产生错误:
ValueError: Expected input batch_size (1) to match target batch_size (55).
带标签形状:
torch.Size([1, 5, 11])
(请注意,我没有输入 5 个 input_id,这可能是输入大小保持不变的原因)
我已经尝试了几种不同的方法来让这些工作,但我目前不知所措。我真的很感激一些指导。谢谢!
【问题讨论】:
-
你为什么要解压第一个暗淡?那应该是
batch size -
我可能应该说我从一个拥抱脸变压器示例中提取了大部分内容作为示例,其中包含以下内容:
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1和labels = torch.tensor([1]).unsqueeze(0) # Batch size 1所以我保留了这些内容 -
当我将标签 unsqueeze 值更改为 11 并将 input_id unsqueeze 值更改为 -1 时,我收到以下错误:
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 11)@umbreon29 -
@umbreon29 当我将 input_id unsqueeze 值更改为 -1 并将标签 unsqueeze 值更改为 0 时,我收到错误:
ValueError: Expected input batch_size (28) to match target batch_size (11) -
我真的很困惑如何让形状对齐
标签: python tensorflow pytorch classification tensor