【问题标题】:Pytorch Operation to detect NaNs检测 NaN 的 Pytorch 操作
【发布时间】:2018-06-17 21:18:15
【问题描述】:

是否有 Pytorch 内部程序来检测张量中的NaNs? Tensorflow 有 tf.is_nantf.check_numerics 操作...... Pytorch 有类似的东西吗?我在文档中找不到类似的东西......

我正在专门寻找 Pytorch 内部例程,因为我希望这在 GPU 和 CPU 上都发生。这不包括基于 numpy 的解决方案(如 np.isnan(sometensor.numpy()).any())...

【问题讨论】:

  • 这可能会有所帮助:x.isnan().any()

标签: pytorch


【解决方案1】:

你总是可以利用nan != nan:

>>> x = torch.tensor([1, 2, np.nan])
tensor([  1.,   2., nan.])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)

在 pytorch 0.4 中还有 torch.isnan:

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)

【讨论】:

  • 我可以确认它也可以在 GPU 上运行。 .any() 然后将其简化为 Python 布尔值。谢谢:-)
  • 哇!我不知道nan != nan。谢谢!
【解决方案2】:

从 PyTorch 0.4.1 开始,有 detect_anomaly 上下文管理器,它会在反向传播的所有步骤之间自动插入等同于 assert not torch.isnan(grad).any() 的断言。在向后传递过程中出现问题时,它非常有用。

【讨论】:

    【解决方案3】:

    正如@cleros 在对@nemo 答案的评论中所建议的那样,您可以使用any() 运算符将其作为布尔值:

    torch.isnan(your_tensor).any()
    

    【讨论】:

      【解决方案4】:

      如果任何值为 nan,则为真:

      torch.any(tensor.isnan())
      

      如果都是 nan,则为真:

      torch.all(tensor.isnan())
      

      【讨论】:

        【解决方案5】:

        如果你想直接在张量上调用它:

        import torch
        
        x = torch.randn(5, 4)
        print(x.isnan().any())
        

        出来:

        import torch
        x = torch.randn(5, 4)
        print(x.isnan().any())
        tensor(False)
        

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2016-07-17
          • 1970-01-01
          • 2019-12-25
          • 1970-01-01
          • 1970-01-01
          • 2021-01-26
          • 1970-01-01
          • 2014-02-09
          相关资源
          最近更新 更多