【问题标题】:When to use model.predict(x) vs model(x) in tensorflow何时在 tensorflow 中使用 model.predict(x) 与 model(x)
【发布时间】:2020-05-26 07:42:29
【问题描述】:

我有一个keras.models.Model,我用tf.keras.models.load_model 加载。

现在有两个选项可以使用此模型。我可以打电话给model.predict(x) 或者我可以打电话给model(x).numpy()。这两个选项都给了我相同的结果,但 model.predict(x) 的运行时间要长 10 倍以上。

source code 状态的 cmets:

计算是分批完成的。这种方法是为性能而设计的 大规模投入。对于适合一批的少量输入, 建议直接使用__call__ 以加快执行速度,例如, model(x),或model(x, training=False)

我已经用包含 1 的 x 进行了测试; 1,000,000;和 10,000,000 行和 model(x) 仍然表现更好。

输入需要多大才能被归类为大规模输入,model.predict(x) 才能更好地执行?

【问题讨论】:

  • This 可能会有所帮助

标签: tensorflow keras


【解决方案1】:

现有的堆栈溢出答案可能对您有用:https://stackoverflow.com/a/58385156/5666087。我在tensorflow/tensorflow#33340 上找到了它。该答案建议将experimental_run_tf_function=False 传递给model.compile 调用以恢复到模型执行的TF 1.x 版本。您也可以完全省略 model.compile 调用(预测不需要)。

输入需要多大才能被归类为大规模输入,model.predict(x) 才能更好地执行?

这是您可以测试的。正如文档所述,如果您的数据适合一批,model(x) 可能会比model.predict(x) 更快。 model.predict(x) 提供超过 model(x) 的一件事是能够预测多个批次。如果您想使用model(x) 预测多个批次,您必须自己编写循环。 model.predict 还提供其他功能,例如回调。

仅供参考,源代码中的文档是在提交 42f469be0f3e8c36624f0b01c571e7ed15f75faf 中添加的,这是 tensorflow/tensorflow#33340 的结果。

model.predict(x) 的主要行为实现here。它不仅包含模型的正向传递。这可能会导致一些速度差异。

我已经用包含 1 的 x 进行了测试; 1,000,000;和 10,000,000 行和 model(x) 仍然表现更好。

这 10,000,000 行是否适合一个批次...?

【讨论】:

  • 谢谢你。我想我对这一切的了解还不够,无法真正理解正在发生的事情。批次是我可以放入(GPU)内存中的任何数量的数据吗?
  • 批量大小是模型一次迭代中使用的示例数。您可以将批量大小设置为您的 GPU 可以容纳的任何大小,但是批量大小过大会导致过度拟合 (Keskar et al., 2016)。
猜你喜欢
  • 1970-01-01
  • 2020-03-05
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-10-31
  • 2020-08-24
  • 2020-05-06
  • 1970-01-01
相关资源
最近更新 更多