【发布时间】:2018-05-07 20:46:40
【问题描述】:
我有一个从 tensorflow SparseTensorValue 获取批次的方法(如下所示)。但是,这种方法相当慢(批量大小为 32 的批次需要 10-20 秒),这是有问题的,因为它被调用了数千次。
def get_batch(index, tensors, batch_size, nItems):
xs, ys = tensors
begin = (index * batch_size)
end = min((index+1)*batch_size, nItems)
y_b = ys[begin:end]
(inds, vals, dsize) = xs
nInds = [[ind[0] - begin, ind[1]] for ind in inds if begin <= ind[0] < end]
nInds = np.array(nInds)
nVals = vals[:nInds.shape[0]]
nDsize = (end - begin, dsize[1])
x_b = tf.SparseTensorValue(nInds, nVals, nDsize)
return (x_b, y_b)
有没有办法让这种方法更高效?
【问题讨论】:
标签: tensorflow