【发布时间】:2019-05-28 11:48:30
【问题描述】:
我正在将 python 3 与 anaconda 一起使用,并尝试将 tf.contrib 损失函数与 Keras 模型一起使用。
代码如下
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.contrib.losses import metric_learning
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(50, activation="relu"))
model.compile(loss=metric_learning.triplet_semihard_loss, optimizer=Adam())
我收到以下错误:
文件 "/home/user/.local/lib/python3.6/site-packages/keras/engine/training_utils.py", 第 404 行,加权 score_array = fn(y_true, y_pred) 文件“/home/user/anaconda3/envs/siamese/lib/python3.6/site-packages/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py”, 第 179 行,在 Triplet_semihard_loss 中 断言 lshape.shape == 1 断言错误
当我使用带有 keras 损失函数的同一个网络时,它工作正常,我尝试将 tf 损失函数包装在这样的函数中
def func(y_true, y_pred):
import tensorflow as tf
return tf.contrib.losses.metric_learning.triplet_semihard_loss(y_true, y_pred)
仍然出现同样的错误
我在这里做错了什么?
更新: 当更改 func 以返回以下内容时
return K.categorical_crossentropy(y_true, y_pred)
一切正常! 但我不能让它与特定的 tf 损失函数一起工作......
当我进入 tf.contrib.losses.metric_learning.triplet_semihard_loss 并删除这行代码时:assert lshape.shape == 1 它运行良好
谢谢
【问题讨论】:
-
仍然不清楚到底在哪里你的错误会弹出;是在
fit期间吗?在compile?发布完整的错误跟踪是个好主意... -
@desertnaut 错误在编译函数中。当我进入 tf.contrib.losses.metric_learning.triplet_semihard_loss 并删除这行代码时: assert lshape.shape == 1 它运行良好
-
您好,我也有同样的问题,但解决方案变得如此简单。您只需替换参数。首先设置标签,然后设置嵌入。
-
@thebeancounter 嘿!你能解决吗?我也遇到了同样的问题,不知道怎么办?
标签: python tensorflow machine-learning keras deep-learning