【发布时间】:2022-06-21 03:34:00
【问题描述】:
当我在 PyTorch(1.6 版)的批量多类分类中实现 target 时,我遇到了以下问题。
我得到了一个变量D
最初的想法是生成一个torch.Size([16])的target张量,每个值都是唯一的,对应D中的行,从0到16为[0,1,2,...,15],用于in-批量多类分类。
这可以使用target = torch.LongTensor(torch.arange(16))来完成
但是D 中可能有重复的、非唯一的行,所以我希望D 中的相同的唯一行在target 中具有其唯一索引。比如D的row0、row1、row8相同的token_ids或者vector,其他的行都不同,那么target应该是[0,0,2,3,4,5,6,0,8,9,10,11,12,13,14,15]或者[0,0,1,2,3,4,5,0,6,7,8,9,10,11,12,13],前者还有索引0-15(但是没有 1 和 7),后者的索引在 0-13 之间。
我该如何实现?
【问题讨论】:
-
我没有了解网络的全部内容,但是如果您正在寻找一个唯一的 16 长度数组,其值从 0 到 15,那么您可以使用 randInt 来填充数组吗?
-
目标取决于 D 的行(在我的例子中,它是创建目标向量的源)。
标签: machine-learning pytorch tensor torch