【发布时间】:2018-10-13 01:57:15
【问题描述】:
我正在尝试在自己的图像和标签上训练 Tensorflow 官方 resnet 模型 (link)。
我创建了imagenet_main.py (my_data_main.py) 的副本,我在其中更改了与数据集相关的硬编码值,如下所示(我现在只是想让它适用于很少的图像):
"""Runs a ResNet model on the ImageNet dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
_DEFAULT_IMAGE_SIZE = 299
_NUM_CHANNELS = 3
_NUM_CLASSES = 9
# TODO: generate dynamically:
_NUM_IMAGES = {
'train': 484,
'validation': 121,
}
_NUM_TRAIN_FILES = 2
_NUM_VAL_FILES = 2
_SHUFFLE_BUFFER = 200
###############################################################################
# Data processing
###############################################################################
def get_filenames(is_training, data_dir):
"""Return filenames for dataset."""
if is_training:
return [
os.path.join(data_dir, 'my_data_train_%05d-of-%05d.tfrecord' % (i, _NUM_TRAIN_FILES))
for i in range(_NUM_TRAIN_FILES)]
else:
return [
os.path.join(data_dir, 'my_data_validation_%05d-of-%05d.tfrecord' % (i, _NUM_VAL_FILES))
for i in range(_NUM_VAL_FILES)]
# rest of program unchanged
为了加载我的数据,我为我添加到 data_dir 目录~/Projects/my_data/data/images/ 的 train 和 eval 创建了 TFRecords。
然后我启动程序:
python3 my_data_main.py \
--data_dir ~/Projects/my_data/data/images/ \
--model_dir /tmp/tests \
--export_dir /tmp/exports \
--train_epochs 10 \
--max_train_steps 200 \
--epochs_between_evals 1 \
--batch_size 256 \
--multi_gpu \
--hooks LoggingTensorHook \
--num_parallel_calls 12 \
--inter_op_parallelism_threads 0 \
--intra_op_parallelism_threads 0 \
--dtype fp32 \
--export_dir /tmp/resnet \
--version 1 \
--resnet_size 18
问题:图像被正确加载以用于训练,但不能用于评估。 def resnet_model_fn开头的resnet_run_loop.py中的以下行将图像加载到Tensorboard中:
# Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6)
我可以看到我的火车运行图像 但不适用于 Tensorboard 中的 eval
我检查的内容:
我检查了我的 TFRecords 是否已成功读取。
我查看了estimator.py 并打印了来自_get_features_and_labels_from_input_fn 的张量形状(在_evaluate_model 中调用)。我找不到任何问题。
我还没有做的事情:
我目前正在下载完整的 imagenet 数据,试图找出他们准备数据的方式的不同之处。
在写这篇文章之前,我已尽力在网上找到答案。感谢大家的时间。
【问题讨论】:
标签: tensorflow resnet