【问题标题】:model.predict_classes is deprecated - What to use instead?不推荐使用 model.predict_classes - 改用什么?
【发布时间】:2021-08-13 18:17:41
【问题描述】:

我一直在尝试重新访问我的 python 代码以在神经网络上进行预测,运行代码后我意识到 model.predict_classes 自 2021 年 1 月 1 日起已弃用。

请您支持我知道我可以用什么代替我的代码?

代码行是:

y_pred_nn = model.predict_classes(X_test)

问题:

NameError
Traceback (most recent call last)
<ipython-input-11-fc1ddbecb622> in <module>
----> 1 print(y_pred_nn)

NameError: name 'y_pred_nn' is not defined

【问题讨论】:

标签: python tensorflow keras neural-network


【解决方案1】:

如何处理这个问题的最佳解释如下:

https://androidkt.com/get-class-labels-from-predict-method-in-keras/

首先使用model.predict() 提取类概率。然后根据类的数量执行以下操作:

二元分类

使用阈值选择将确定类别 0 或 1 的概率

np.where(y_pred &gt; threshold, 1,0)

例如使用 0.5 的阈值

多类分类

选择概率最高的类

np.argmax(predictions, axis=1)

多标签分类

如果每个示例可以有多个输出类,请使用阈值来选择应用哪些标签。

y_pred = model.predict(x, axis=1)
[i for i,prob in enumerate(y_pred) if prob > 0.5]

【讨论】:

    【解决方案2】:

    如果您的模型执行多类分类(例如,如果它使用 softmax 最后一层激活),请使用: np.argmax(model.predict(x), axis=-1)

    如果您的模型执行二元分类(例如,如果它使用 sigmoid 最后一层激活),请使用:

    (model.predict(X_test) &gt; 0.5).astype("int32")

    【讨论】:

      【解决方案3】:

      model.predict_classes 已弃用,正如@Kaveh 所述,请改用model.predict() 函数。

      np.argmax(model.predict(x_test), axis=-1)

      要了解有关 model.predict() 的更多信息,请查看link

      【讨论】:

      • 这仅涵盖多类分类情况,请参阅我对二元或多标签分类的回答。
      猜你喜欢
      • 1970-01-01
      • 2012-01-11
      • 1970-01-01
      • 1970-01-01
      • 2016-02-23
      • 2017-11-04
      • 2011-10-22
      • 2011-04-11
      • 2021-10-12
      相关资源
      最近更新 更多