【发布时间】:2021-12-31 04:37:36
【问题描述】:
我想用他们的标签创建一个图像生成器。 首先,从 csv 导入数据,然后使用以下代码映射 43 个类:
label_map = {v:i for i, v in enumerate(classes)}
输出将是这样的:
{'Danger': 4,
'Give Way': 5,
'Hump': 6,
'Left Bend': 7,
'Left Margin': 8,...}
然后将使用以下命令从目录加载图像:
train_images = glob('/Desktop/dataset/resized_train/*')
现在我使用 csv 文件映射标签:
train_labels = df['label'].map(label_map)
现在,当我想用相应的标签显示每张图片时,我不能。
我使用了这个代码:
img = tf.io.read_file(image_path)
img = tf.image.decode_image(img, channels=3)
img.set_shape([None,None,3])
img = tf.image.resize(img, [image_w, image_h])
img = img/255.0
return img
def load_data(image_path, label):
image = read_img(image_path)
return image, label
def data_generator(features,labels):
dataset = tf.data.Dataset.from_tensor_slices((features,labels))
dataset = dataset.shuffle(buffer_size=100)
autotune = tf.data.experimental.AUTOTUNE
dataset = dataset.map(load_data, num_parallel_calls=autotune)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.repeat()
dataset = dataset.prefetch(autotune)
return dataset
def show_img(dataset):
plt.figure(figsize=(15,15))
for i in range(8):
for val in dataset.take(1):
img = val[0][i]*255.0
plt.subplot(4,2,i+1)
plt.imshow(tf.cast(img,tf.uint8))
plt.title(val[1][i].numpy())
plt.subplots_adjust(hspace=1)
plt.show()
train_dataset = data_generator(train_images,train_labels)
val_dataset = data_generator(val_images,val_labels)
show_img(train_dataset)
当我运行 show_img 时,它会显示图像,但标签都是 0。
【问题讨论】:
标签: python tensorflow image-processing tensorflow2.0 tensorflow-datasets