【Question title】: PyTorch: transform tensor to one-hot
【Posted】: 2020-09-26 10:06:30
【Question】:

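What is the easiest way to convert a tensor of shape (batch_size, height, width) whose entries are class indices from 0 to n-1 into a one-hot tensor of shape (batch_size, n, height, width)? I wrote the solution below, but it seems like there should be a simpler and faster way to do it.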


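import torch


def batch_tensor_to_onehot(tnsr, classes):
    # (batch, height, width) -> (batch, 1, height, width)
    tnsr = tnsr.unsqueeze(1)
    res = []
    for cls in range(classes):
        # one mask per class: 1 where the value equals cls, else 0
        res.append((tnsr == cls).long())
    # concatenate the masks along dim 1 -> (batch, classes, height, width)
    return torch.cat(res, dim=1)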
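The per-class comparison above can also be done in one step by broadcasting the input against a vector of class indices, which removes the Python loop and the `torch.cat`. This is a sketch, not part of the original post; the name `batch_tensor_to_onehot_broadcast` is hypothetical.

```python
import torch


def batch_tensor_to_onehot_broadcast(tnsr, classes):
    # Hypothetical helper: compare every element against all class ids at once.
    # tnsr.unsqueeze(1): (batch, 1, height, width)
    # class_ids:         (1, classes, 1, 1), broadcast against the input
    class_ids = torch.arange(classes, device=tnsr.device).view(1, classes, 1, 1)
    return (tnsr.unsqueeze(1) == class_ids).long()


x = torch.tensor([[[0, 2],
                   [1, 0]]])  # shape (1, 2, 2), values in [0, 3)
out = batch_tensor_to_onehot_broadcast(x, 3)
print(out.shape)              # torch.Size([1, 3, 2, 2])
print(out[0, 2].tolist())     # [[0, 1], [0, 0]]: ones where x == 2
```

Each output channel `out[:, k]` is simply the mask `x == k`, which is exactly what the loop builds one class at a time.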

【Question discussion】:

    Tags: python pytorch tensor one-hot-encoding


    【Solution 1】:

    You can also use Tensor.scatter_, which avoids the .permute but is arguably harder to understand than the straightforward approach proposed by @Alpha.

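    import torch


    def batch_tensor_to_onehot(tnsr, classes):
        # all-zero (batch, classes, height, width) output
        result = torch.zeros(tnsr.shape[0], classes, *tnsr.shape[1:], dtype=torch.long, device=tnsr.device)
        # for every (b, h, w): result[b, tnsr[b, h, w], h, w] = 1 (tnsr must be int64)
        result.scatter_(1, tnsr.unsqueeze(1), 1)
        return result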
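To make the `scatter_` call concrete: along dim 1, each entry `index[b, 0, h, w]` picks the channel in which a 1 is written at position (b, h, w). A small sanity check against the `one_hot` + `permute` approach, on an arbitrary random input, might look like this (the index must be an int64 tensor with the same number of dimensions as `result`, which is why the input is unsqueezed):

```python
import torch


def batch_tensor_to_onehot(tnsr, classes):
    # zeros of shape (batch, classes, height, width)
    result = torch.zeros(tnsr.shape[0], classes, *tnsr.shape[1:],
                         dtype=torch.long, device=tnsr.device)
    # result[b, tnsr[b, h, w], h, w] = 1 for every (b, h, w)
    result.scatter_(1, tnsr.unsqueeze(1), 1)
    return result


torch.manual_seed(0)
x = torch.randint(5, (2, 4, 3))  # int64 by default, values in [0, 5)
a = batch_tensor_to_onehot(x, 5)
b = torch.nn.functional.one_hot(x, num_classes=5).permute(0, 3, 1, 2)
print(a.shape, torch.equal(a, b))  # torch.Size([2, 5, 4, 3]) True
```

Taking `argmax` over dim 1 of the result recovers the original indices, which is a quick way to check any of these encodings.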
    

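    Benchmark results

    I was curious and decided to benchmark the three approaches. The relative differences between the proposed methods did not seem to change significantly with batch size, width, or height; the number of classes was the main distinguishing factor. Of course, as with any benchmark, your mileage may vary.

    The benchmarks use random indices with batch size, height, and width all equal to 100. Each experiment was repeated 20 times and the average is reported. The num_classes=100 experiment is run once as a warm-up before profiling.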
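One plausible contributor to the class count dominating (my own back-of-the-envelope addition, not part of the original measurements): every method materializes a (batch, classes, height, width) int64 output, so the amount of memory written grows linearly with num_classes. For the largest setting used here:

```python
import torch

# Largest configuration in the benchmark: batch = height = width = 100, 100 classes.
batch, classes, height, width = 100, 100, 100, 100

numel = batch * classes * height * width                          # elements in the output
bytes_per_elem = torch.empty(0, dtype=torch.long).element_size()  # int64 -> 8 bytes
total_mb = numel * bytes_per_elem / 1e6

print(numel, bytes_per_elem, total_mb)  # 100000000 8 800.0
```

So each call writes roughly 800 MB; how each method gets there (one comparison per class plus a concatenation, `one_hot`, or zero-fill plus `scatter_`) is what separates them.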

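    The CPU results suggest that the original method may be best for num_classes below about 30, while on the GPU the scatter_ method appears to be the fastest.

    Tests were run on Ubuntu 18.04 with an NVIDIA 2060 Super and an i7-9700K.

    The code used for benchmarking: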

    import time

    import matplotlib.pyplot as plt
    import torch
    from tqdm import tqdm


    def batch_tensor_to_onehot_slavka(tnsr, classes):
        tnsr = tnsr.unsqueeze(1)
        res = []
        for cls in range(classes):
            res.append((tnsr == cls).long())
        return torch.cat(res, dim=1)


    def batch_tensor_to_onehot_alpha(tnsr, classes):
        result = torch.nn.functional.one_hot(tnsr, num_classes=classes)
        return result.permute(0, 3, 1, 2)


    def batch_tensor_to_onehot_jodag(tnsr, classes):
        result = torch.zeros(tnsr.shape[0], classes, *tnsr.shape[1:], dtype=torch.long, device=tnsr.device)
        result.scatter_(1, tnsr.unsqueeze(1), 1)
        return result


    def sync(device):
        # CUDA kernels run asynchronously, so wait for them before reading the clock.
        # Only call synchronize on CUDA: it fails on machines without a GPU.
        if device == 'cuda':
            torch.cuda.synchronize()


    def time_method(fn, device, c, bs, height, width):
        # average wall-clock time of fn, one run per entry in bs
        total = 0.0
        for b in bs:
            tnsr = torch.randint(c, (b, height, width)).to(device=device)
            sync(device)  # make sure the input is ready before starting the timer
            t0 = time.perf_counter()
            fn(tnsr, c)
            sync(device)
            total += time.perf_counter() - t0
        return total / len(bs)


    def main():
        num_classes = [2, 10, 25, 50, 100]
        height = 100
        width = 100
        bs = [100] * 20
        methods = {
            'Slavka-cat': batch_tensor_to_onehot_slavka,
            'Alpha-one_hot': batch_tensor_to_onehot_alpha,
            'jodag-scatter_': batch_tensor_to_onehot_jodag,
        }

        devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
        for d in devices:
            times = {label: [] for label in methods}
            # the first pass (num_classes[-1]) is a warm-up and is not recorded
            warmup = True
            for c in tqdm([num_classes[-1]] + num_classes, ncols=0):
                for label, fn in methods.items():
                    t = time_method(fn, d, c, bs, height, width)
                    if not warmup:
                        times[label].append(t)
                warmup = False

            fig = plt.figure()
            ax = fig.subplots()
            for label, ts in times.items():
                ax.plot(num_classes, ts, label=label)
            ax.set_xlabel('num_classes')
            ax.set_ylabel('time (s)')
            ax.set_title(f'{d} benchmark')
            ax.legend()
            plt.savefig(f'{d}.png')
            plt.show()


    if __name__ == "__main__":
        main()
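Newer PyTorch releases also ship `torch.utils.benchmark`, whose `Timer` takes care of warm-up and of synchronizing CUDA before reading the clock. A minimal sketch of timing two of the methods with it, assuming that module is available in your version; the small input size and `classes = 10` are arbitrary choices to keep it quick:

```python
import torch
import torch.utils.benchmark as benchmark


def batch_tensor_to_onehot_alpha(tnsr, classes):
    result = torch.nn.functional.one_hot(tnsr, num_classes=classes)
    return result.permute(0, 3, 1, 2)


def batch_tensor_to_onehot_jodag(tnsr, classes):
    result = torch.zeros(tnsr.shape[0], classes, *tnsr.shape[1:],
                         dtype=torch.long, device=tnsr.device)
    result.scatter_(1, tnsr.unsqueeze(1), 1)
    return result


device = 'cuda' if torch.cuda.is_available() else 'cpu'
classes = 10
# Much smaller input than the benchmark above, just for illustration.
tnsr = torch.randint(classes, (8, 64, 64), device=device)

results = {}
for name, fn in [('one_hot', batch_tensor_to_onehot_alpha),
                 ('scatter_', batch_tensor_to_onehot_jodag)]:
    timer = benchmark.Timer(
        stmt='fn(tnsr, classes)',
        globals={'fn': fn, 'tnsr': tnsr, 'classes': classes},
    )
    # timeit(n) runs the statement n times and returns a Measurement (times in seconds)
    results[name] = timer.timeit(20)
    print(name, device, f'{results[name].mean * 1e6:.1f} us')
```

Reported numbers will of course depend on hardware and on the input sizes chosen.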
    

    【Discussion】:

      【Solution 2】:

      You can use torch.nn.functional.one_hot.

      For your case:

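      # one_hot puts the class dimension last: (batch, height, width, classes)
      a = torch.nn.functional.one_hot(tnsr, num_classes=classes)
      # move it to dim 1: (batch, classes, height, width)
      out = a.permute(0, 3, 1, 2)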
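A small illustration of the shapes involved, on a made-up 2×2 input. `one_hot` expects an int64 input and returns int64, and `permute` only changes strides, so the result is a non-contiguous view; `.contiguous()` copies it if a later op needs contiguous memory. Passing `num_classes` explicitly also matters: by default it is inferred as the largest value plus one, so a batch that happens to lack the highest class would get too few channels.

```python
import torch
import torch.nn.functional as F

tnsr = torch.tensor([[[0, 1],
                      [2, 1]]])       # (batch=1, height=2, width=2), int64
a = F.one_hot(tnsr, num_classes=3)    # class dim appended last: (1, 2, 2, 3)
out = a.permute(0, 3, 1, 2)           # moved to dim 1: (1, 3, 2, 2)

print(a.shape, out.shape, out.dtype)  # torch.Size([1, 2, 2, 3]) torch.Size([1, 3, 2, 2]) torch.int64
print(out.is_contiguous())            # False: permute returns a view with reordered strides
out_c = out.contiguous()              # explicit copy in (1, 3, 2, 2) memory order

# Without num_classes, the class count is inferred from the data:
print(F.one_hot(torch.tensor([0, 1])).shape)  # torch.Size([2, 2]), not 3 channels
```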
      

      【Discussion】:
