【问题标题】:How to calculate one hot encoding of a floating point value in a Keras graph如何计算 Keras 图中浮点值的一种热编码
【发布时间】:2021-12-28 17:34:20
【问题描述】:

直截了当,这是我的 Keras 模型:

from tensorflow.keras import layers, models, Input, backend as K
import tensorflow as tf

input_layer = Input(shape=(1,), name="input")

x = layers.Dense(128, activation='relu', name="dense_1")(input_layer)
x = layers.Dense(1024, activation='softmax', name="dense_2")(x)

model = models.Model(input_layer, x)
model.summary()

数据集的每个输入都是一个浮点数以及目标。但是该模型会生成一个分类值作为输出,即目标的光栅化。例如,假设目标是 10 到 20 之间的数字。我想通过将范围划分为 6 个类别来使其成为分类值:

[10, 12) -> 0 -> [1, 0, 0, 0, 0, 0]
[12, 14) -> 1 -> [0, 1, 0, 0, 0, 0]
[14, 16) -> 2 -> [0, 0, 1, 0, 0, 0]
[16, 18) -> 3 -> [0, 0, 0, 1, 0, 0]
[18, 20) -> 4 -> [0, 0, 0, 0, 1, 0]
[20, 20] -> 5 -> [0, 0, 0, 0, 0, 1]

我知道我可以预处理数据集并将所有目标更改为 one-hot 向量,然后将它们输入到我的模型中。但出于教育目的,我想将我的目标保持为浮点数,并将它们转换为模型管道中的 one-hot 向量。

为此,我编写了一个自定义损失函数:

def one_hot_loss(x, y):
    min = 10
    max = 20
    steps = 2
    num_classes = (max - min) / steps
    transformed = (y - min) / (max - min) * num_classes
    transformed_int = K.cast(transformed, "uint8")
    one_hot = tf.one_hot(transformed_int, depth=int(num_classes))
    return tf.keras.losses.sparse_categorical_crossentropy(x, one_hot)

model.compile(optimizer='rmsprop', loss=one_hot_loss)

为了测试这一点,我尝试调用fit 方法:

model.fit([[1.0]], [[15.0]])

这面临以下错误:

ValueError: No gradients provided for any variable:
  (['dense_1/kernel:0', 'dense_1/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0'],).
  Provided `grads_and_vars` is
    ((None, <tf.Variable 'dense_1/kernel:0' shape=(1, 128) dtype=float32>),
    (None, <tf.Variable 'dense_1/bias:0' shape=(128,) dtype=float32>),
    (None, <tf.Variable 'dense_2/kernel:0' shape=(128, 1024) dtype=float32>),
    (None, <tf.Variable 'dense_2/bias:0' shape=(1024,) dtype=float32>)).

我的问题是,为什么没有提供渐变?我错过了什么?

【问题讨论】:

    标签: tensorflow keras


    【解决方案1】:

    我认为问题是tf.cast(或K.cast)不可区分。您可以尝试找到一些可微分的圆形函数,如下所示:https://stackoverflow.com/questions/46596636/differentiable-round-function-in-tensorflow."

    【讨论】:

    • 谢谢。您很可能是对的,但问题是 tf.not_hot 只接受 int!
    猜你喜欢
    • 2019-11-03
    • 2020-01-16
    • 1970-01-01
    • 2020-08-28
    • 1970-01-01
    • 2021-04-28
    • 2023-04-10
    • 2021-12-01
    • 2018-12-17
    相关资源
    最近更新 更多