【问题标题】:How to avoid 64-bit data types to run my Python Estimator model natively on tfjs-node?如何避免 64 位数据类型在 tfjs-node 上本地运行我的 Python Estimator 模型?
【发布时间】:2020-02-13 06:18:17
【问题描述】:

我在 Python 中创建了一个 tf.estimator 模型和一个 tf.data 管道,并在 TF 2.1 中将其保存为 tf.saved_model 格式。由于 tfjs-node 不支持 int64 或 float64 类型,因此无法加载模型。

在 Tensorboard 上,我观察到一些输入管道 Python 变量被自动声明为 64 位类型。

例如上面的batch_sizeepochs。如何避免这个问题并在 tfjs-node 中加载tf.estimator 模型而不进行转换?

为了重现,

【问题讨论】:

  • 一种可能的解决方法是使用 keras API 训练模型并将其保存为已保存的模型格式。在 tfjs-node 中加载它没有任何问题

标签: python tensorflow tensorflow.js tfjs-node


【解决方案1】:

由于 tfjs-node 不支持 INT64 和 FLOAT64,如果输入/输出 tensor dtypes 为 INT64/FLOAT64,则无法直接加载和执行模型。一种解决方法是使用张量转换函数包装模型:

  1. 在python中,加载原模型,新建一个输入张量为INT32或FLOAT32的模型
  2. 在新模型中将输入张量转换为 DTYPE INT64/FLOAT64
  3. 将转换后的张量馈送到原始模型,并返回模型输出(如果需要,转换值 dtype)
  4. 将此新模型导出为 TF SavedModel

因此,通过这种包装,tfjs-node 支持模型的输入/输出 dtype,尽管它可能会失去一些准确性。

【讨论】:

  • 谢谢康怡。将尝试此解决方法来解除对自己的阻止。有没有计划在 tfjs-node 的 loadSavedModel API 中自动处理这个问题?
  • 还没有计划,因为它需要节点绑定解析模型输入/输出来决定是否转换数据类型,目前还不支持。从长远来看,添加它是一个很好的特性,因为它显然会解除对模型使用不受支持的数据类型的许多用户的阻止。我们将在规划未来路线图时进行讨论。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2014-07-02
  • 1970-01-01
  • 2021-12-06
  • 1970-01-01
  • 1970-01-01
  • 2015-09-04
  • 2019-04-22
相关资源
最近更新 更多