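【Posted】: 2021-08-29 12:35:03
【Question】:
I am working on a text classification problem with BERT in PyTorch. Below is the PyTorch Dataset I am using, but I get an error when I try to access the inputs from the dataset.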
PyTorch Dataset
The dataset returns a dictionary with the keys ids, mask, token_type_ids and targets.
import torch
from torch.utils.data import Dataset

import config  # assumed: the asker's own module defining BERT_TOKENIZER and MAX_LEN


class JigsawDataset(Dataset):
    def __init__(self, df, train_transforms=None):
        self.comment_text = df["comment_text"].values
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
        self.max_len = config.MAX_LEN
        self.langs = df["lang"].values
        self.train_transforms = train_transforms

    def __len__(self):
        return len(self.comment_text)

    def __getitem__(self, item):
        # `item` is the integer position of a row (0 .. len(self) - 1)
        comment_text = str(self.comment_text[item])
        comment_text = " ".join(comment_text.split())
        lang = self.langs[item]
        if self.train_transforms:
            comment_text, _ = self.train_transforms(data=(comment_text, lang))["data"]
        inputs = self.tokenizer.encode_plus(
            comment_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",  # replaces the deprecated pad_to_max_length=True
            truncation=True,
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        token_type_ids = inputs["token_type_ids"]
        # One sample, returned as a dict of tensors
        data_loader_dict = {}
        data_loader_dict["ids"] = torch.tensor(ids, dtype=torch.long)
        data_loader_dict["mask"] = torch.tensor(mask, dtype=torch.long)
        data_loader_dict["token_type_ids"] = torch.tensor(token_type_ids, dtype=torch.long)
        data_loader_dict["targets"] = torch.tensor(self.target[item], dtype=torch.float)
        return data_loader_dict
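The following is a stripped-down sketch of the same Dataset pattern. It makes two assumptions: a hypothetical toy tokenizer `toy_encode` stands in for `config.BERT_TOKENIZER`, and plain Python lists replace tensors. It shows that `__getitem__` is meant to be called with an integer row position and returns that one sample's dictionary:

```python
# Hypothetical stand-in for the BERT tokenizer: maps each word to its length,
# then truncates/pads to max_len. NOT the real encode_plus output.
def toy_encode(text, max_len):
    ids = [len(word) for word in text.split()][:max_len]
    mask = [1] * len(ids)
    pad = max_len - len(ids)
    return {"input_ids": ids + [0] * pad, "attention_mask": mask + [0] * pad}


class ToyJigsawDataset:
    """Hypothetical simplified version of JigsawDataset (no torch, no tokenizer)."""

    def __init__(self, texts, targets, max_len=8):
        self.texts = list(texts)
        self.targets = list(targets)
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        # `item` is an integer position, e.g. 0 for the first row
        encoded = toy_encode(self.texts[item], self.max_len)
        return {
            "ids": encoded["input_ids"],
            "mask": encoded["attention_mask"],
            "targets": float(self.targets[item]),
        }


dataset = ToyJigsawDataset(["you are great", "hello"], [0, 1], max_len=4)
sample = dataset[0]      # integer index -> one sample's dict
print(sorted(sample))    # ['ids', 'mask', 'targets']
print(sample["ids"])     # [3, 3, 5, 0]
```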
Code that produces the error
Here I load only one sample and wrap it in the PyTorch Dataset:
import pandas as pd
df = pd.read_csv("dataset.csv")
df = df.head(1) # Trying with only 1 Sample
dataset = JigsawDataset(df)
ids = dataset["ids"] # Error occurs at this line
mask = dataset["mask"]
token_type_ids = ["token_type_ids"]
Error
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-78-4608dd623cac> in <module>
3 dataset = JigsawDataset(df)
4
----> 5 ids = dataset["ids"] # Error occurs at this line
6 mask = dataset["mask"]
7 token_type_ids = ["token_type_ids"]
<ipython-input-40-121d8aa71516> in __getitem__(self, item)
13
14 def __getitem__(self, item):
---> 15 comment_text = str(self.comment_text[item])
16 comment_text = " ".join(comment_text.split())
17 lang = self.langs[item]
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
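The traceback shows that `dataset["ids"]` passes the string `"ids"` straight into `__getitem__` as `item`. That string then reaches `self.comment_text[item]`, where `self.comment_text` is the NumPy array returned by `df["comment_text"].values`. NumPy only accepts integer, slice, ellipsis, `None` and integer/boolean-array indices there. A minimal sketch that depends only on NumPy reproduces the same IndexError:

```python
import numpy as np

# Same kind of array that df["comment_text"].values produces (object dtype)
comment_text = np.array(["some comment"], dtype=object)

print(comment_text[0])  # an integer index works

message = None
try:
    comment_text["ids"]  # a string key is not a valid index for a plain ndarray
except IndexError as err:
    message = str(err)
    print("IndexError:", message)
```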
How can I fix this?
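`__getitem__` expects a sample position rather than a dictionary key. So one way to read the fix is to index by position first and by key second. Looking at several samples at once then means grouping the per-sample dicts by key. That is roughly what `torch.utils.data.DataLoader`'s default collation does, except that it also stacks the tensors. The following pure-Python sketch assumes a hypothetical list `samples` with made-up token ids in place of `JigsawDataset(df)`:

```python
# Hypothetical stand-in for JigsawDataset(df): each element is one sample's dict.
# Token ids are made up for illustration.
samples = [
    {"ids": [1, 2, 3], "mask": [1, 1, 1], "token_type_ids": [0, 0, 0], "targets": 0.0},
    {"ids": [1, 4, 3], "mask": [1, 1, 1], "token_type_ids": [0, 0, 0], "targets": 1.0},
]

# Index by position first, then by key
first = samples[0]
ids = first["ids"]
mask = first["mask"]
token_type_ids = first["token_type_ids"]


def collate(batch):
    """Group per-sample dicts into one dict of per-key lists
    (similar in spirit to DataLoader's default collation, minus tensor stacking)."""
    return {key: [sample[key] for sample in batch] for key in batch[0]}


batch = collate(samples)
print(batch["targets"])  # [0.0, 1.0]
```

With the objects in the question, this would roughly correspond to `dataset[0]["ids"]`. The batched equivalent would be iterating over `DataLoader(dataset, batch_size=...)`, whose batches are dicts with the same keys holding stacked tensors.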
【Discussion】:
-
Please add a Minimal Reproducible Example so that we can help you further.
Tags: python pandas numpy pytorch huggingface-transformers