【问题标题】:Getting nans for gradient获取渐变的 nans
【发布时间】:2021-09-24 21:52:41
【问题描述】:

我正在尝试创建一个搜索相关性模型,在该模型中获取查询向量和结果文档之间的点积。我在顶部添加了一个位置偏差项,以考虑到位置 1 更有可能被点击的事实。最终(非归一化)对数似然计算如下:

        query = self.query_model(query_input_ids, query_attention_mask)
        docs = self.doc_model(doc_input_ids, doc_attention_mask)
        positional_bias = self.position_model()
        
        if optimizer_idx is not None:
            if optimizer_idx == 0:
                docs = docs.detach()
                positional_bias = positional_bias.clone().detach()
            elif optimizer_idx == 1:
                query = query.detach()
                positional_bias = positional_bias.clone().detach()
            else:
                query = query.detach()
                docs = docs.detach()
                
        similarity = (docs @ query.unsqueeze(-1)).squeeze()

        click_log_lik = (similarity + positional_bias)\
                .reshape(doc_mask.shape)\
                .masked_fill_((1 - doc_mask).bool(), float("-inf"))

查询和文档模型只是一个 distilbert 模型,在 CLS 令牌之上有一个投影层。模型可以在这里看到:https://pastebin.com/g21g9MG3

在检查第一个梯度下降步骤时,它有nans,但仅适用于查询模型而不是文档模型。 我的假设是规范化 doc 和查询模型 (return F.normalize(out, dim=-1)) 的返回值在某种程度上会影响梯度。

有谁知道1.如果我的假设是正确的,更重要的是 2.我怎样才能纠正 nan 梯度?

附加信息:

  • 没有任何损失是 inf 或 nan。
  • 查询是 BS x 768
  • 文档是 BS x DOC_RESULTS x 768
  • positional_bias 是 DOC_RESULTS
  • 在我的情况下,DOC_RESULTS 是 10。
  • 最后一行中的masked_fill 是因为有时我的查询数据点少于 10 个。

更新1

以下更改对 nans 没有影响:

  • masked_fill-inf 更改为 1e5
  • 将投影从 F.normalize(out, dim=-1) 更改为 out / 100
  • 完全消除了位置偏差,再次失败。

【问题讨论】:

    标签: tensorflow deep-learning pytorch huggingface-transformers


    【解决方案1】:

    如果它对任何人有帮助,而你在使用 Transformers 时遇到了这个问题,我就是这样做的:

    所以最后这个错误是由于我掩盖了nan的事实。由于我有一些长度为零的文档,因此转换器的输出为 nan。我希望masked_fill 能解决这个问题,但事实并非如此。在我的情况下,解决方案是只通过转换器放置非零长度序列,然后附加零以填充批量大小。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2012-12-07
      • 2020-01-05
      • 1970-01-01
      • 2011-03-19
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多