【发布时间】:2021-07-09 07:35:00
【问题描述】:
我正在尝试编写一个在数据集上运行 KMeans 并输出集群质心的函数。我的目标是在自定义的keras 层中使用它,所以我使用TensorFlow 的 KMeans 实现,它将张量作为输入数据集。
但是,我的问题是,即使作为独立功能,我也无法使其工作。问题来自KMeans 接受一个 generator function 提供小批量而不是普通张量的事实,但是当我使用闭包来执行此操作时,我得到一个 graph disconnected 错误:
import tensorflow as tf # version: 2.4.1
from tensorflow.compat.v1.estimator.experimental import KMeans
@tf.function
def KMeansCentroids(inputs, num_clusters, steps, use_mini_batch=False):
# `inputs` is a 2D tensor
def input_fn():
# Each one of the lines below results in the same "Graph Disconnected" error. Tuples don't really needed but just to be consistent with the documentation
return (inputs, None)
return (tf.data.Dataset.from_tensor_slices(inputs), None)
return (tf.convert_to_tensor(inputs), None)
kmeans = KMeans(
num_clusters=num_clusters,
use_mini_batch=use_mini_batch)
kmeans.train(input_fn, steps=steps) # This is where the error happens
return kmeans.cluster_centers()
>>> x = tf.random.uniform((100, 2))
>>> c = KMeansCentroids(x, 5, 10)
确切的错误是:
值错误:
Tensor("strided_slice:0", shape=(), dtype=int32)必须来自同一图表Tensor("Equal:0", shape=(), dtype=bool)(图表为FuncGraph(name=KMeansCentroids, id=..)和<tensorflow.python.framework.ops.Graph object at ...>)。
- 如果我要使用
numpy数据集并在函数内转换为张量,代码就可以正常工作。 - 另外,使
input_fn()直接返回tf.random.uniform((100, 2))(忽略输入参数)将再次起作用。这就是为什么我猜测 tensorflow 不支持闭包,因为它需要在一开始就构建计算图。
但我不知道如何解决这个问题。 由于 KMeans 是compat.v1.experimental模块,会不会是版本错误?
请注意,documentation of KMeans 状态为 input_fn():
该函数应构造并返回以下之一:
- tf.data.Dataset 对象:Dataset 对象的输出必须是具有以下相同约束的元组(特征、标签)。
- 元组(特征、标签):其中 features 是 tf.Tensor 或字符串特征名称到 Tensor 的字典,标签是 Tensor 或字符串标签名称到 Tensor 的字典。特征和标签都由 model_fn 使用。它们应该满足输入对 model_fn 的期望。
【问题讨论】:
标签: tensorflow machine-learning keras k-means