【Posted】: 2017-11-27 01:58:26
【Question】:
I am using the following code to generate a tfrecords file.
# Imports inferred from usage; the `pil` alias for PIL.Image is an assumption.
import io
import os
import numpy as np
import tensorflow as tf
from PIL import Image as pil

def generate_tfrecords(data_path, labels, name):
    """Converts a dataset to tfrecords."""
    filename = os.path.join(args.tfrecords_path, name + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)
    for index, data in enumerate(data_path):
        # Read the image file as-is: these are the encoded (compressed) bytes.
        with tf.gfile.GFile(data, 'rb') as fid:
            encoded_jpg = fid.read()
        print(len(encoded_jpg))  # 17904
        # Decode only to obtain the image dimensions.
        encoded_jpg_io = io.BytesIO(encoded_jpg)
        image = pil.open(encoded_jpg_io)
        image = np.asarray(image)
        print(image.shape)  # 112*112*3
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(int(image.shape[0])),
            'width': _int64_feature(int(image.shape[1])),
            'depth': _int64_feature(int(3)),
            'label': _int64_feature(int(labels[index])),
            # The encoded file bytes are stored, not the decoded pixel array.
            'image_raw': _bytes_feature(encoded_jpg)}))
        writer.write(example.SerializeToString())
    writer.close()
In the code above, the length of encoded_jpg is 17904, while the image shape is 112*112*3, so the two do not match.
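A length of 17904 next to a 112*112*3 = 37632 pixel array is what one would expect if encoded_jpg holds the compressed file contents rather than decoded pixels. The following is a minimal standard-library sketch of that effect, not the code from the question: `encode_png_rgb` is a hypothetical helper written only for this illustration (it is not part of PIL or TensorFlow), and the gradient image is made-up data.

```python
import struct
import zlib


def encode_png_rgb(pixels: bytes, width: int, height: int) -> bytes:
    """Hypothetical helper: encode raw 8-bit RGB pixels as a minimal PNG file."""
    def chunk(tag: bytes, data: bytes) -> bytes:
        # PNG chunk layout: length, type, data, CRC over type + data.
        crc = zlib.crc32(tag + data) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

    stride = width * 3
    # Every scanline is prefixed with filter type 0 (no filtering).
    scanlines = b"".join(
        b"\x00" + pixels[y * stride:(y + 1) * stride] for y in range(height)
    )
    # IHDR: width, height, bit depth 8, color type 2 (RGB), then
    # compression, filter and interlace methods all set to 0.
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    return (
        b"\x89PNG\r\n\x1a\n"
        + chunk(b"IHDR", ihdr)
        + chunk(b"IDAT", zlib.compress(scanlines))
        + chunk(b"IEND", b"")
    )


height, width, depth = 112, 112, 3
# Made-up image content: a horizontal gradient, one byte per channel.
pixels = bytes(
    (x * 2) % 256
    for _ in range(height)
    for x in range(width)
    for _ in range(depth)
)
encoded = encode_png_rgb(pixels, width, height)

print(len(pixels))   # 37632 decoded values (112 * 112 * 3)
print(len(encoded))  # size of the encoded file, which depends on the content
```

Reinterpreting the stored string byte by byte, which is what tf.decode_raw does with tf.uint8, yields len(encoded) values rather than height * width * depth values.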
I then parse the tfrecords with the following code:
def _parse_function(example_proto):
    features = {'height': tf.FixedLenFeature((), tf.int64, default_value=0),
                'width': tf.FixedLenFeature((), tf.int64, default_value=0),
                'depth': tf.FixedLenFeature((), tf.int64, default_value=0),
                'label': tf.FixedLenFeature((), tf.int64, default_value=0),
                'image_raw': tf.FixedLenFeature((), tf.string, default_value="")}
    parsed_features = tf.parse_single_example(example_proto, features)
    height = tf.cast(parsed_features["height"], tf.int32)  # 112
    width = tf.cast(parsed_features["width"], tf.int32)  # 112
    depth = tf.cast(parsed_features["depth"], tf.int32)  # 3
    label = parsed_features['label']
    img = tf.decode_raw(parsed_features['image_raw'], tf.uint8, little_endian=True)
    img = tf.reshape(img, [height, width, depth])
    return img, label
Running the code above produces the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 17904 values, but the requested shape has 37632
[[Node: Reshape = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw, Reshape/shape)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,?,?,?], [?]], output_types=[DT_UINT8, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]
How can I solve this? The images are PNG files, and 37632 = 112*112*3. Thanks!
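【Comments】:
- Are you sure the image type is PNG? The code mentions jpg, and encoded_jpg looks like a raw image. The error message suggests that one of the raw images has a size different from the one implied by the values stored in the height and depth fields... Could your dataset contain images of different sizes?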
Tags: tensorflow tensorflow-datasets