【Question title】: What's the correct way to use tf.while_loop in a custom loss function?
【Posted】: 2021-07-18 19:28:26
【Question】:

I intend to use the following function as my training loss:

import tensorflow as tf

def wrap(dist): 
    return tf.while_loop(
        cond=lambda X: tf.math.abs(X) > 0.5,
        body=lambda X: tf.math.subtract(X, 1.0),
        loop_vars=(dist))


# PBC-aware MSE, period = 1.0 ([0, 1.0])
def custom_loss(y_true, y_pred):
    diff = tf.math.abs(y_true - y_pred)
    diff = tf.nest.flatten(diff)
    diff = tf.vectorized_map(wrap, diff)
    return tf.math.reduce_mean(tf.math.square(diff))

# ...other code for loading data and defining the model

model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.1),
              loss=custom_loss)

However, I get a long series of error messages. Because the log is too long to include here, I put it in a gist: https://gist.github.com/HanatoK/f75fddd82372f499c37279f1128cad7a

The equivalent NumPy version of the code above should be:

import numpy as np

def wrap_diff2(x, y, period=1.0):
    diff = np.abs(x - y)
    while diff > 0.5 * period:
        diff -= period
    return diff * diff

def custom_loss_numpy(y_true, y_pred):
    diff2 = np.vectorize(wrap_diff2)(y_true, y_pred)
    return np.mean(diff2)

Any ideas? A complete code example is shared on Google Colab: https://colab.research.google.com/drive/1ExVHgyKHQfGcpXvo5ZsuBBmzmHzxUekC?usp=sharing

【Comments】:

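  • After some debugging, I found that tf.nest.flatten just reshapes the tensor to (1, batch_size); as the documentation says, it does nothing to the tensor itself. I should probably use tf.reshape here, but I still don't know how to implement the function.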
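
For readers following the comment above, here is a minimal sketch of what tf.nest.flatten does with a tensor. The values in `diff` are made up for illustration. The function flattens nested Python containers rather than tensor shapes, so a lone tensor is returned unchanged inside a one-element list. Converting that list back to a tensor adds a leading dimension of 1, which is most likely the (1, batch_size) shape the comment describes.

```python
import tensorflow as tf

# Illustrative values only; not taken from the question's data.
diff = tf.constant([[0.2, 1.3], [2.6, 0.5]])

# A lone tensor is treated as a single leaf, so it comes back
# unchanged inside a one-element Python list.
leaves = tf.nest.flatten(diff)
print(type(leaves), len(leaves), leaves[0].shape)

# Nested containers (dicts, lists, tuples) are flattened into one
# flat list of their leaves; the tensors themselves keep their shape.
nested_leaves = tf.nest.flatten({"a": diff, "b": [diff, diff]})
print(len(nested_leaves))
```

To change the shape of the tensor itself, tf.reshape is the appropriate tool.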

Tags: python tensorflow loss-function


【Answer 1】:

Try this:

import tensorflow as tf
import numpy as np

def wrap(dist): 
    return tf.while_loop(
        cond=lambda X: tf.math.abs(X) > 0.5,
        body=lambda X: tf.math.subtract(X, 1.0),
        loop_vars=(dist))

def custom_loss(y_true, y_pred):
    diff = tf.math.abs(y_true - y_pred)
    diff = tf.reshape(diff, [-1])
    diff = tf.vectorized_map(wrap, [diff])
    return tf.math.reduce_mean(tf.math.square(diff))

y_true = np.array([[0., 1., 1.0], [0., 0., 0.]])
y_pred = np.array([[1., 1., 1.0], [1., 0., 1.]])
custom_loss(y_true, y_pred).numpy()
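
As an alternative sketch, not part of the original answer: the loop in wrap only ever subtracts whole periods, so the same squared error can probably be computed in closed form without tf.while_loop or tf.vectorized_map. The function name `pbc_mse` and the `period` parameter are hypothetical, and the cast to float32 is an assumption about the model's dtype.

```python
import numpy as np
import tensorflow as tf

def pbc_mse(y_true, y_pred, period=1.0):
    # Hypothetical loop-free variant of the loss; float32 is assumed.
    diff = tf.cast(y_true, tf.float32) - tf.cast(y_pred, tf.float32)
    # Subtract the nearest whole number of periods so that every
    # element ends up within half a period of zero.
    wrapped = diff - period * tf.round(diff / period)
    return tf.reduce_mean(tf.square(wrapped))

y_true = np.array([[0., 1., 1.0], [0., 0., 0.]])
y_pred = np.array([[1., 1., 1.0], [1., 0., 1.]])
loss = pbc_mse(y_true, y_pred).numpy()

# A second check: 0.9 and 0.1 are 0.2 apart once wrapped.
small = pbc_mse(np.array([0.9]), np.array([0.1])).numpy()
```

The wrapped value may differ in sign from what the loop produces, but it is squared afterwards, so the loss is the same. Because this version uses only elementwise operations, it also works on batches of any shape.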

【Comments】:

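  • Thanks, your code snippet works, but when I try it in training I still get a series of errors. I have uploaded the example code to Google Colab: colab.research.google.com/drive/…
  • The error you are facing now does not come from this loss function but from custom_error. In your custom error function, change the code to diff = tf.reshape(diff, [-1]); diff = tf.vectorized_map(wrap, [diff])
  • Thanks a lot! I'm curious: could you elaborate on how tf.reshape and tf.vectorized_map(wrap, [diff]) work?
  • Check the documentation; I can't say more than that. If you don't understand, please post a new question.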
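
For the question raised in the comments, a minimal sketch of the two calls. The sample values and the doubling lambda are illustrative assumptions, not taken from the thread. tf.reshape with [-1] flattens a tensor to one dimension, and tf.vectorized_map applies a function to each slice along the first axis and stacks the results. When the elements are passed as a one-element list, as in [diff], the function receives a list holding one slice, which matches the list form that tf.while_loop accepts for loop_vars.

```python
import tensorflow as tf

# Illustrative values only.
diff = tf.constant([[0.2, 1.3], [2.6, 0.5]])

# -1 asks tf.reshape to infer the size, giving a 1-D tensor of 4 elements.
flat = tf.reshape(diff, [-1])

# The function is applied to each scalar slice along axis 0 and the
# per-slice results are stacked back into a tensor of shape (4,).
doubled = tf.vectorized_map(lambda x: x * 2.0, flat)

# With a list as elems, the function receives a list of slices and the
# result keeps the same list structure.
doubled_list = tf.vectorized_map(lambda xs: [xs[0] * 2.0], [flat])
```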