【问题标题】:Python - PyTorch: IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indicesPython - PyTorch:IndexError:只有整数、切片(`:`)、省略号(`...`)、numpy.newaxis(`None`)和整数或布尔数组是有效的索引
【发布时间】:2021-08-29 12:35:03
【问题描述】:

我正在使用 PyTorch 处理 BERT 的文本分类问题。这是我正在使用的 PyTorch 数据集格式,但是当我尝试从数据集访问输入时出现错误。

PyTorch 数据集

数据集返回一个字典,其中包含:idsmasktoken_type_idstargets

class JigsawDataset:
    def __init__(self, df, train_transforms = None):
        self.comment_text = df["comment_text"].values
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
        self.max_len = config.MAX_LEN
        self.langs = df["lang"].values
        self.train_transforms = train_transforms

    def __len__(self):
        return len(self.comment_text)

    def __getitem__(self, item):
        comment_text = str(self.comment_text[item])
        comment_text = " ".join(comment_text.split())
        lang = self.langs[item]
        
        if self.train_transforms:
            comment_text, _ = self.train_transforms(data=(comment_text, lang))['data']

        inputs = self.tokenizer.encode_plus(
            comment_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
        )

        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        token_type_ids = inputs["token_type_ids"]

        data_loader_dict = {}
        data_loader_dict["ids"] = torch.tensor(ids, dtype=torch.long)
        data_loader_dict["mask"] = torch.tensor(mask, dtype=torch.long)
        data_loader_dict["token_type_ids"] = torch.tensor(token_type_ids, dtype=torch.long)
        data_loader_dict["targets"] = torch.tensor(self.target[item], dtype=torch.float)
        
        
        return data_loader_dict

产生错误的相关代码

在这种情况下,我尝试仅加载 1 个样本并将其设置为 PyTorch 数据集的格式

df = pd.read_csv("dataset.csv")
df = df.head(1) # Trying with only 1 Sample
dataset = JigsawDataset(df)

ids = dataset["ids"]    # Error occurs at this line
mask = dataset["mask"]
token_type_ids = ["token_type_ids"]

错误

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-78-4608dd623cac> in <module>
      3 dataset = JigsawDataset(df)
      4 
----> 5 ids = dataset["ids"]    # Error occurs at this line
      6 mask = dataset["mask"]
      7 token_type_ids = ["token_type_ids"]

<ipython-input-40-121d8aa71516> in __getitem__(self, item)
     13 
     14     def __getitem__(self, item):
---> 15         comment_text = str(self.comment_text[item])
     16         comment_text = " ".join(comment_text.split())
     17         lang = self.langs[item]

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

如何解决这个问题?

【问题讨论】:

标签: python pandas numpy pytorch huggingface-transformers


【解决方案1】:

根据the docs,panda 的 DataFrame 对象的 'values' 方法返回一个 numpy 数组。
在您的代码中,您将属性 'self.comment_text' 设置为 'df["comment_text"].values' 返回的 numpy 数组(代码框 1 中的第 3 行)。
Numpy 数组不接受字符串作为索引。
很难给你一个答案,我敢肯定不测试它就可以工作,但我会从将属性“self.comment_text”设置为数据框或它的副本开始,而不仅仅是它所拥有的值。

我会改变这个:

class JigsawDataset:
    def __init__(self, df, train_transforms = None):
        self.comment_text = df["comment_text"].values
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
.
.
.

到这里:

class JigsawDataset:
    def __init__(self, df, train_transforms = None):
        self.comment_text = df["comment_text"]
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
.
.
.

【讨论】:

  • 我不太确定我应该根据您的建议做出哪些改变。能否请您详细说明一下或演示一个小例子?
【解决方案2】:

我发现了问题。

代码不正确

ids = dataset["ids"]    
mask = dataset["mask"]
token_type_ids = ["token_type_ids"]

正确代码

ids = dataset[0]["ids"]    
mask = dataset[0]["mask"]
token_type_ids = [0]["token_type_ids"]

问题是“ids”、“mask”和“token_type_ids”是字典键。 JigsawDataset 为每个样本返回一个字典。因此,为了访问样本,我们需要在指定键之前指定索引 ([0])。

【讨论】:

    猜你喜欢
    • 2016-04-29
    • 1970-01-01
    • 2019-09-26
    • 2018-03-11
    • 1970-01-01
    • 2017-12-08
    • 2019-03-13
    • 1970-01-01
    • 2017-11-24
    相关资源
    最近更新 更多