【发布时间】:2019-05-29 14:26:22
【问题描述】:
我是 TensorFlow 的新手,并尝试使用 Estimator API 进行一些简单的分类实验。我在libsvm format 中有一个稀疏数据集。以下输入函数适用于小型数据集:
def libsvm_input_function(file):
def input_function():
indexes_raw = []
indicators_raw = []
values_raw = []
labels_raw = []
i=0
for line in open(file, "r"):
data = line.split(" ")
label = int(data[0])
for fea in data[1:]:
id, value = fea.split(":")
indexes_raw.append([i,int(id)])
indicators_raw.append(int(1))
values_raw.append(float(value))
labels_raw.append(label)
i=i+1
indexes = tf.SparseTensor(indices=indexes_raw,
values=indicators_raw,
dense_shape=[i, num_features])
values = tf.SparseTensor(indices=indexes_raw,
values=values_raw,
dense_shape=[i, num_features])
labels = tf.constant(labels_raw, dtype=tf.int32)
return {"indexes": indexes, "values": values}, labels
return input_function
但是,对于几 GB 大小的数据集,我收到以下错误:
ValueError: 无法创建内容大于 2GB 的张量原型。
如何避免此错误?我应该如何编写一个输入函数来读取中型稀疏数据集(libsvm 格式)?
【问题讨论】:
标签: tensorflow