【问题标题】:How to extract data without label from tensorflow dataset如何从张量流数据集中提取没有标签的数据
【发布时间】:2021-05-13 15:17:01
【问题描述】:

我有一个名为 train_ds 的 tf 数据集:

directory = 'Data/dataset_train'

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  directory,
  validation_split=0.2,
  subset="training",
    color_mode='grayscale',
  seed=123,
  image_size=(28, 28),
  batch_size=32)

这个数据集由 20000 张“假”图像和 20000 张“真实”图像组成,我想从这个 tf 数据集中以 numpy 形式提取 X_train 和 y_train,但我只设法用

y_train = np.concatenate([y for x, y in train_ds], axis=0)

我也试过这个,但它似乎没有遍历 20000 张图像:

for images, labels in train_ds.take(-1):  
    X_train = images.numpy()
    y_train = labels.numpy()

我真的想将图像提取到 X_train 并将标签提取到 y_train 但我想不通! 对于我所犯的任何错误,我提前道歉,并感谢我能得到的所有帮助:)

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    如果您没有对数据集应用进一步的转换,它将是BatchDataset。您可以创建两个列表来迭代数据集。我总共有 2936 张图片。

    x_train, y_train = [], []
    
    for images, labels in train_ds:
      x_train.append(images.numpy())
      y_train.append(labels.numpy())
    
    np.array(x_train).shape >> (92,)
    

    它正在生成批次。您可以使用np.concatenate 连接它们。

    x_train = np.concatenate(x_train, axis = 0) 
    x_train.shape >> (2936,28,28,3)
    

    或者您可以取消批处理数据集并对其进行迭代:

    for images, labels in train_ds.unbatch():
      x_train.append(images.numpy())
      y_train.append(labels.numpy())
    
    x_train = np.array(x_train)
    x_train.shape >> (2936,28,28,3)
    

    【讨论】:

      【解决方案2】:

      您可以使用 TF Dataset 方法 unbatch() 取消批量数据集,然后您可以轻松地从中检索数据和标签:

      data=[]
      for images, labels in ds.unbatch():
          data.append(images)
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2020-05-10
        • 1970-01-01
        • 2019-10-07
        • 2022-08-27
        • 2018-09-29
        • 2017-11-02
        • 2021-12-10
        • 2019-03-28
        相关资源
        最近更新 更多