【发布时间】:2020-05-24 11:54:16
【问题描述】:
我正在学习 Google 机器学习强化课程。但它使用的是 TensorFlow 1.x 版本,所以我打算更改练习以便能够在 TensorFlow 2.0 中运行它们。但我被困在那个练习中:
具体代码:
def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
"""Trains a linear regression model of one feature.
Args:
features: pandas DataFrame of features
targets: pandas DataFrame of targets
batch_size: Size of batches to be passed to the model
shuffle: True or False. Whether to shuffle the data.
num_epochs: Number of epochs for which data should be repeated. None = repeat indefinitely
Returns:
Tuple of (features, labels) for next data batch
"""
# Convert pandas data into a dict of np arrays.
features = {key:np.array(value) for key,value in dict(features).items()}
# Construct a dataset, and configure batching/repeating.
ds = Dataset.from_tensor_slices((features,targets)) # warning: 2GB limit
ds = ds.batch(batch_size).repeat(num_epochs)
# Shuffle the data, if specified.
if shuffle:
ds = ds.shuffle(buffer_size=10000)
# Return the next batch of data.
features, labels = ds.make_one_shot_iterator().get_next()
return features, labels
我已将features, labels = ds.make_one_shot_iterator().get_next() 替换为features, labels = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
它似乎可以工作,但 make_one_shot_iterator() 已被弃用,那么,我该如何替换它?
也照https://github.com/tensorflow/tensorflow/issues/29252,我试过了
features, labels = ds.__iter__()
next(ds.__iter__())
return features, labels
但它返回错误__iter __ () is only supported inside of tf.function or when eager execution is enabled.
我在 python 方面非常缺乏经验,并且作为业余爱好者遵循课程。关于如何解决它的任何想法?谢谢。
【问题讨论】:
标签: tensorflow iterator tensorflow-datasets