【问题标题】:Getting an error when using tf.keras.metrics.Mean in functional Keras API在功能性 Keras API 中使用 tf.keras.metrics.Mean 时出错
【发布时间】:2021-07-12 22:09:01
【问题描述】:

我正在尝试向 Keras 函数模型 (Tensorflow 2.5) 添加平均指标,但出现以下错误:

ValueError: Expected a symbolic Tensor for the metric value, received: tf.Tensor(0.0, shape=(), dtype=float32)

代码如下:

x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [5 + i * 3 for i in x]
a = Input(shape=(1,))
output = Dense(1)(a)
model = Model(inputs=a,outputs=output)
model.add_metric(tf.keras.metrics.Mean()(output))
model.compile(loss='mse')
model.fit(x=x, y=y, epochs=100)

如果我删除以下行(引发异常的行):

model.add_metric(tf.keras.metrics.Mean()(output))

代码按预期工作。

我尝试禁用 Eager Execution,但我收到以下错误:

ValueError: Using the result of calling a `Metric` object when calling `add_metric` on a Functional Model is not supported. Please pass the Tensor to monitor directly.

上述用法几乎是从tf.keras.metrics.Mean 文档中复制而来的(请参阅使用 compile() API

【问题讨论】:

    标签: python tensorflow keras deep-learning metrics


    【解决方案1】:

    我找到了一种绕过问题的方法,完全避免使用model.add_metric,并将Metric 对象传递给compile() 方法。
    但是,当传递tf.keras.metrics.Mean 的实例时如下:

    model.compile(loss='mse', metrics=tf.keras.metrics.Mean())
    

    我从compile() 方法得到以下错误:

    TypeError: update_state() got multiple values for argument 'sample_weight'
    

    为了解决这个问题,我不得不扩展 tf.keras.metrics.Mean 并更改 update_state 的签名以匹配预期的签名。
    这是最终(工作)代码:

    class FixedMean(tf.keras.metrics.Mean):
        def update_state(self, y_true, y_pred, sample_weight=None):
            super().update_state(y_pred, sample_weight=sample_weight)
    
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    y = [5 + i * 3 for i in x]
    a = Input(shape=(1,))
    output = Dense(1)(a)
    model = Model(inputs=a,outputs=output)
    model.compile(loss='mse', metrics=FixedMean())
    model.fit(x=x, y=y, epochs=100)
    

    【讨论】:

      猜你喜欢
      • 2020-02-29
      • 2021-01-11
      • 1970-01-01
      • 2021-02-22
      • 2022-01-23
      • 2019-05-07
      • 2018-10-18
      • 1970-01-01
      • 2019-11-03
      相关资源
      最近更新 更多