【问题标题】:Tensorflow Dataset api issueTensorflow 数据集 api 问题
【发布时间】:2018-03-10 08:51:45
【问题描述】:

这是我的代码 sn-p

features=np.array([1,2,3,4,5,6,7],dtype=float)
labels=np.array([1,2,3,4,5,6,7],dtype=float)
training_data=(features,labels)

train_dataset=tf.data.Dataset.from_tensor_slices(training_data)
train_dataset=train_dataset.batch(1)

iter=train_dataset.make_one_shot_iterator()
batch=iter.get_next()

with tf.Session() as sess:
    x,y=batch
    a=x.eval()
    b=y.eval()   
    print(a,"---------->",b)

输出[1] ---------> [2]

预期输出[1] ---------> [1]

我在这上面花了 6 个小时,当我遇到这个问题时,我正在训练 LSTM 模型。我错过了什么?

【问题讨论】:

    标签: tensorflow tensorflow-datasets


    【解决方案1】:

    问题在于,将batch 分解为x, y 后,您不会得到两个普通张量,而是会得到两个迭代器:

    In [15]: batch
    Out[15]:
    (<tf.Tensor 'IteratorGetNext_1:0' shape=(?,) dtype=float64>,
     <tf.Tensor 'IteratorGetNext_1:1' shape=(?,) dtype=float64>)
    

    因此,x.eval() 将迭代器增加 1,y.eval() 再次增加迭代器,导致您看到值 (1, 2)

    相反,这样做只运行一次迭代器:

    with tf.Session() as sess:
        a, b = sess.run(batch)
        print(a,"---------->",b)
    

    您应该会看到预期的结果。

    【讨论】:

    • 啊!!我的上帝,我没有看到这一点,非常感谢。
    猜你喜欢
    • 1970-01-01
    • 2019-02-08
    • 1970-01-01
    • 2018-08-22
    • 2019-04-29
    • 2019-10-01
    • 2020-06-24
    • 2018-06-05
    • 2021-05-18
    相关资源
    最近更新 更多