【问题标题】:List of tensors and just tensors张量和张量列表
【发布时间】:2020-08-30 14:09:35
【问题描述】:

我正在将代码从 tensorflow 1.x 更新到 2.1.0。

我更改了 tensorflow 1.x 代码

labels = tf.cast(labels, tf.int64)
predict = tf.argmax(input=logits, axis=1)
tf.metrics.accuracy(labels=labels, predictions=predict)

到 tensorflow 2.1.0 代码。

labels = tf.cast(labels, tf.int64)
predict = tf.argmax(input=logits, axis=1)
tf.keras.metrics.Accuracy.update_state(labels, predict) #updated code

但是,当我运行更新后的代码时,出现以下错误。

TypeError: update_state() missing 1 required positional argument: 'y_pred'

所以,我检查了 tensorflow 2.1.0 文档,tf.keras.metrics.Accuracy.update_state() 的参数似乎是一个列表(以 [ , , , ] 的形式)。然后,我搜索了一种将张量转换为列表的方法,即

labels = tf.make_tensor_proto(labels)
labels = tf.make_ndarray(labels)

运行此代码后,出现以下错误。

TypeError: List of Tensors when single Tensor expected

所以,我尝试将张量列表转换为张量

labels = tf.stack(labels)
#or
labels = torch.stack(labels)

tf.stack() 不起作用,因为它给出了相同的初始 TypeError,说更新的代码中缺少“y_pred”。

torch.stack(),然而,给出了以下错误。

TypeError: stack() : argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

所以,我猜torch.stack() 只接受一个元组,不是一个列表。 但是,tf.stack() 似乎接受了一个列表,但它并没有把它变成一个张量?

我的标签和预测是否首先是张量列表?如果是这样,为什么 tf.stack() 不把它们变成张量?如何正确转换标签并进行预测,以便将它们传递到tf.keras.metrics.Accuracy.update_state()

除非绝对必要,否则如果不使用compat.v1.,我将不胜感激。

【问题讨论】:

    标签: python tensorflow keras tensor torch


    【解决方案1】:

    这样试试:

    labels = [0,1]
    logits = np.asarray([[0.9,0.1],[0.1,0.9]])
    
    labels = tf.cast(labels, tf.int64)
    predict = tf.argmax(input=logits, axis=1)
    acc = tf.keras.metrics.Accuracy()
    acc = acc.update_state(y_true=labels, y_pred=predict)
    acc
    

    【讨论】:

    • 上面的代码没有给出任何错误,但是当我运行acc时,它给出了以下结果:<tf.Variable 'UnreadVariable' shape=() dtype=float32, numpy=2.0>
    • 实际上,您的代码有所帮助!在我的原始代码中,我忘记在tf.keras.metrics.Accuracy 之后添加 ()。谢谢!
    • 太棒了!不要忘记投票并接受它作为答案
    • 快速提问。就像我在前面的评论中所说,如果我运行你的代码,acc 会给出<tf.Variable 'UnreadVariable' shape=() dtype=float32, numpy=2.0>。但是,在我的原始代码中,如果我运行相同的代码,它会给出<tf.Variable 'AssignAddVariableOp_1' shape=() dtype=float32>。而且我无法在我的原始代码中得到acc.result().numpy()。你知道为什么acc有区别吗?
    • 在您的原始代码中,您没有更新任何内容...您必须初始化 acc 方法,然后更新,然后在必要时调用
    猜你喜欢
    • 1970-01-01
    • 2016-04-03
    • 2020-08-05
    • 1970-01-01
    • 2019-07-29
    • 1970-01-01
    • 1970-01-01
    • 2020-12-04
    相关资源
    最近更新 更多