【问题标题】:TensorFlow: alternate between datasets of different output shapesTensorFlow:在不同输出形状的数据集之间交替
【发布时间】:2018-08-24 04:18:52
【问题描述】:

我正在尝试将tf.Dataset 用于 3D 图像 CNN,其中从训练集和验证集输入的 3D 图像的形状不同(训练:(64、64、64),验证: (176、176、160))。我什至不知道这是可能的,但我正在根据一篇论文重新创建这个网络,并使用经典的 feed_dict 方法网络确实有效。出于性能原因(并且只是为了学习),我正在尝试将网络切换为使用 tf.Dataset

我有两个数据集和迭代器,如下所示:

def _data_parser(dataset, shape):
        features = {"input": tf.FixedLenFeature((), tf.string),
                    "label": tf.FixedLenFeature((), tf.string)}
        parsed_features = tf.parse_single_example(dataset, features)

        image = tf.decode_raw(parsed_features["input"], tf.float32)
        image = tf.reshape(image, shape + (1,))

        label = tf.decode_raw(parsed_features["label"], tf.float32)
        label = tf.reshape(label, shape + (1,))
        return image, label

train_datasets = ["train.tfrecord"]
train_dataset = tf.data.TFRecordDataset(train_datasets)
train_dataset = train_dataset.map(lambda x: _data_parser(x, (64, 64, 64)))
train_dataset = train_dataset.batch(batch_size) # batch_size = 16
train_iterator = train_dataset.make_initializable_iterator()

val_datasets = ["validation.tfrecord"]
val_dataset = tf.data.TFRecordDataset(val_datasets)
val_dataset = val_dataset.map(lambda x: _data_parser(x, (176, 176, 160)))
val_dataset = val_dataset.batch(1)
val_iterator = val_dataset.make_initializable_iterator()

TensorFlow documentation 有关于使用reinitializable_iteratorfeedable_iterator 在数据集之间切换的示例,但它们都在相同输出形状的迭代器之间切换,这里不是这种情况。

在我的情况下,我应该如何使用tf.Datasettf.data.Iterator 在训练集和验证集之间切换?

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    只需为尺寸不匹配的轴上的形状提供未指定的 (None) 值。例如

    import numpy as np
    import tensorflow as tf
    
    training_dataset = tf.data.Dataset.from_tensors(np.zeros((64, 64, 64), np.float32)).repeat().batch(4)
    validation_dataset = tf.data.Dataset.from_tensors(np.zeros((176, 176, 160), np.float32)).repeat().batch(1)
    
    iterator = tf.data.Iterator.from_structure(
        training_dataset.output_types,
        tf.TensorShape([None, None, None, None]))
    next_element = iterator.get_next()
    
    training_init_op = iterator.make_initializer(training_dataset)
    validation_init_op = iterator.make_initializer(validation_dataset)
    
    sess = tf.InteractiveSession()
    sess.run(training_init_op)
    print(sess.run(next_element).shape)
    sess.run(validation_init_op)
    print(sess.run(next_element).shape)
    
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-09-19
      • 2022-08-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-05-29
      相关资源
      最近更新 更多