【发布时间】:2019-10-13 17:23:53
【问题描述】:
我找到了这段代码,它运行良好。这个想法 - 拆分我的数据并在其上训练 KMeansClustering。所以我创建了 InitHook 和迭代器并将其用于训练。
class _IteratorInitHook(tf.train.SessionRunHook):
"""Hook to initialize data iterator after session is created."""
def __init__(self):
super(_IteratorInitHook, self).__init__()
self.iterator_initializer_fn = None
def after_create_session(self, session, coord):
"""Initialize the iterator after the session has been created."""
del coord
self.iterator_initializer_fn(session)
# Run K-means clustering.
def _get_input_fn():
"""Helper function to create input function and hook for training.
Returns:
input_fn: Input function for k-means Estimator training.
init_hook: Hook used to load data during training.
"""
init_hook = _IteratorInitHook()
def _input_fn():
"""Produces tf.data.Dataset object for k-means training.
Returns:
Tensor with the data for training.
"""
features_placeholder = tf.placeholder(tf.float32,
my_data.shape)
delf_dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
delf_dataset = delf_dataset.shuffle(1000).batch(
my_data.shape[0])
iterator = delf_dataset.make_initializable_iterator()
def _initializer_fn(sess):
"""Initialize dataset iterator, feed in the data."""
sess.run(
iterator.initializer,
feed_dict={features_placeholder: my_data})
init_hook.iterator_initializer_fn = _initializer_fn
return iterator.get_next()
return _input_fn, init_hook
input_fn, init_hook = _get_input_fn()
output_cluster_dir = 'parameters/clusters'
kmeans = tf.contrib.factorization.KMeansClustering(
num_clusters=1024,
model_dir=output_cluster_dir,
use_mini_batch=False,
)
print('Starting K-means clustering...')
kmeans.train(input_fn, hooks=[init_hook])
但如果我将 num_clusters 更改为 512 或 256,我会收到下一个错误:
InvalidArgumentError:segment_ids[0] = 600 超出范围 [0, 256)
[[节点UnsortedSegmentSum(定义在 /home/mikhail/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1112) ]] [[节点挤压(定义在 /home/mikhail/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1112) ]]
看起来我在将数据拆分为批次时遇到了一些问题,或者我的 KMeans 默认使用 1024 个集群,即使我设置了另一个值!
我不知道要进行哪些更改才能使其正常工作。 Traceback 很大,如果需要我可以附加为文件。
【问题讨论】:
-
您是否在使用不同的簇数重新运行之间清理图表?
-
@GPhilo 不!这是否意味着我仍然拥有包含 1024 个集群的相同图表?如何清除?
-
tf.reset_default_graph() -
@GPhilo 我在没有图表的情况下使用它。只需编写有问题的代码并运行。因此,如果我添加该行,则没有任何变化。我也尝试将我的代码添加到
with tf.Graph().as_default():但也不起作用:( -
@GPhilo 关闭 IDE 并打开它(以清除所有数据) - 如果我一开始使用 256 个集群,仍然会出现同样的错误。
标签: python tensorflow batch-processing k-means