【问题标题】:Tensorflow tfrecords won't allow setting shape?Tensorflow tfrecords 不允许设置形状?
【发布时间】:2017-07-12 03:49:26
【问题描述】:

情况

我正在尝试将图像数据存储在 tfrecords 中。

详情

图像具有形状 (256,256,4) 和标签 (17)。看来tfrecords保存正确(可以成功解码height和width属性)

问题

当我使用会话测试从 tfrecord 中提取图像和标签时,会引发错误。标签形状似乎有些不对劲

错误信息

INFO:tensorflow:Error 报告给 Coordinator: 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>,>reshape 的输入是一个有 34 个值的张量,但请求的形状有 17 个 [[节点:Reshape_4 = Reshape[T=DT_INT32, Tshape=DT_INT32, >_device="/job:localhost/replica:0/task:0/cpu:0"](DecodeRaw_5, >Reshape_4/shape)]]

代码

注意:我对第一部分非常有信心,因为它是直接从 tensorflow 文档示例中复制而来的

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

"""Converts a dataset to tfrecords."""
# Open files
train_filename = os.path.join('./data/train.tfrecords')
validation_filename = os.path.join('./data/validation.tfrecords')

# Create writers
train_writer = tf.python_io.TFRecordWriter(train_filename)
# validation_writer = tf.python_io.TFRecordWriter(validation_filename)

for i in range(200):
    label = y[i]
    img = io.imread(TRAINING_IMAGES_DIR + '/train_' + str(i) + '.tif')

    example = tf.train.Example(features=tf.train.Features(feature={
        'width': _int64_feature([img.shape[0]]),
        'height': _int64_feature([img.shape[1]]),
        'channels': _int64_feature([img.shape[2]]),
        'label': _bytes_feature(label.tostring()),
        'image': _bytes_feature(img.tostring())
    }))

#     if i in validation_indices:    
#         validation_writer.write(example.SerializeToString())
#     else:
    train_writer.write(example.SerializeToString())

train_writer.close()
# validation_writer.close()

错误部分。请注意,特别奇怪的是,如果我将 reshape 函数更改为 [34],我仍然会得到相同的错误。

data_path = './data/train.tfrecords'

with tf.Session() as sess:
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenFeature([], tf.string)}

    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)

    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)

    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['image'], tf.float32)

    # Cast label data into int32
    label = tf.decode_raw(features['label'], tf.int8)

    # Reshape image data into the original shape
    image = tf.reshape(image, [256, 256, 4])
    label = tf.reshape(label, [17])

    # Any preprocessing here ...

    # Creates batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch([image, label], batch_size=1, capacity=20, num_threads=1, min_after_dequeue=10)

    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)

    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    img, lbl = sess.run([images, labels])
    img

    # Stop the threads
    coord.request_stop()

    # Wait for threads to stop
    coord.join(threads)

    sess.close()

【问题讨论】:

    标签: python machine-learning tensorflow training-data


    【解决方案1】:

    如果您的标签在以字节为单位保存在 tfrecords 中之前为 tf.int16,则可能会出现此问题。因此,当您阅读时,tf.int8 它的数字是您预期的两倍。因此,您可以确保您的标签正确写入:label = tf.cast(y[i], tf.int8) 在您的 tfrecords 转换代码中。

    【讨论】:

    • 我尝试在保存之前将标签和图像都转换为 np.int32,并将读数都转换为 tf.int32,但没有运气。我仍然收到INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, Input to reshape is a tensor with 17 values, but the requested shape has 34 [[Node: Reshape_12 = Reshape[T=DT_INT32, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](DecodeRaw_13, Reshape_12/shape)]]
    • 在这种情况下检查 label (= y[i]) 的大小并确认它有 17 个值。
    • y[0].shape 返回(17,)img.shape 返回(256, 256, 4)
    • 有没有可能这个形状有时候不是17?当你写 tfrecords 时,这个大小不是 17 时,你能放一个assert 吗?
    • 感谢所有这些快速跟进 vijay,我在 tfrecrods 创建中添加了assert len(label) == 17,我没有收到任何错误。
    猜你喜欢
    • 1970-01-01
    • 2015-03-25
    • 2021-03-10
    • 1970-01-01
    • 2012-04-05
    • 1970-01-01
    • 2019-05-17
    • 1970-01-01
    • 2018-04-16
    相关资源
    最近更新 更多