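【Posted】: 2018-03-10 13:50:04
【Question】:
Here is my code. My TensorFlow version is 1.6.0 and my Python version is 3.6.4. If I read the CSV file directly with the Dataset API, training works without errors. But after I convert the CSV file to a TFRecords file, training fails. I searched online, and almost everyone says to update TensorFlow, but that did not work for me.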
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)

feature_names = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth'
]


def my_input_fn(is_shuffle=False, repeat_count=1):
    dataset = tf.data.TFRecordDataset(['csv.tfrecords'])  # filename is a list

    def parser(record):
        keys_to_features = {
            'label': tf.FixedLenFeature((), dtype=tf.int64),
            'features': tf.FixedLenFeature(shape=(4,), dtype=tf.float32),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
        return parsed['features'], parsed['label']

    dataset = dataset.map(parser)
    if is_shuffle:
        # Randomizes input using a window of 256 elements (read into memory)
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.batch(32)
    dataset = dataset.repeat(repeat_count)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels


feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,  # The input features to our model
    hidden_units=[10, 10],  # Two layers, each with 10 neurons
    n_classes=3,
    model_dir='iris_model_2')  # Path to where checkpoints etc are stored

classifier.train(input_fn=lambda: my_input_fn(is_shuffle=True, repeat_count=100))
It returns this error message:
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'iris_model_2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1163d9f28>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
Traceback (most recent call last):
  File "/Users/huanghelin/Desktop/TFrecord/try2.py", line 45, in <module>
    classifier.train(input_fn=lambda: my_input_fn(is_shuffle=True, repeat_count=100))
  File "/Users/huanghelin/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 352, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/Users/huanghelin/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 812, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/Users/huanghelin/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 793, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/Users/huanghelin/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/canned/dnn.py", line 354, in _model_fn
    config=config)
  File "/Users/huanghelin/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/canned/dnn.py", line 161, in _dnn_model_fn
    'Given type: {}'.format(type(features)))
ValueError: features should be a dictionary of `Tensor`s. Given type: <class 'tensorflow.python.framework.ops.Tensor'>
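Reading the error literally: the canned DNNClassifier wants `features` to be a dictionary of tensors, while `parser` above returns the single `parsed['features']` tensor. Since the feature columns are built from `feature_names`, the dictionary keys would presumably need to be those four names. The snippet below is only a sketch of that reshaping step, not a confirmed fix. `to_feature_dict` is a hypothetical helper, and it is shown on a plain Python list so it runs without TensorFlow. Whether the same indexing applied to the parsed tensor inside `parser` is enough on TensorFlow 1.6.0 is an assumption.

```python
# Column names taken from the question; they double as the dictionary keys
# that the numeric feature columns look up.
feature_names = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth',
]


def to_feature_dict(feature_vector, names):
    """Map an indexable vector to {name: element}, one entry per name.

    Hypothetical helper for illustration. It only uses integer indexing,
    so it accepts a list here; applying it to a tensor is an assumption.
    """
    return {name: feature_vector[i] for i, name in enumerate(names)}


# Plain-Python demonstration with one iris-like row.
example = to_feature_dict([5.1, 3.5, 1.4, 0.2], feature_names)

# Assumed (untested) use inside the question's parser:
#     return to_feature_dict(parsed['features'], feature_names), parsed['label']
```

With this shape, the input function would hand the estimator a mapping from column name to values instead of one unnamed tensor, which is what the ValueError asks for.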
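【Comments】:
-
I ran into a very similar problem. Did you find out what was wrong with your input function?
-
I have fixed it, thank you very much!
-
Could you post the solution?
-
I have posted it, please take a look.
Tags: python-3.x tensorflow tfrecord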