ImageDataGenerator 是一个高级类,它允许从多个源(来自np arrays、来自目录...)产生数据,并且包括用于执行图像增强等的实用函数。
更新
从keras-preprocessing 1.0.4 开始,ImageDataGenerator 带有一个flow_from_dataframe method,它可以解决您的问题。它需要dataframe 和directory 参数定义如下:
dataframe: Pandas dataframe containing the filenames of the
images in a column and classes in another or column/s
that can be fed as raw target data.
directory: string, path to the target directory that contains all
the images mapped in the dataframe.
所以不再需要自己实现。
原答案如下
在您的情况下,使用您描述的数据框,您还可以编写自己的自定义生成器,利用 prepare_data 函数中的逻辑作为更简约的解决方案。最好使用 Keras 的 Sequence 对象来执行此操作,因为它允许使用多处理(如果您使用的是 GPU,这将有助于避免出现瓶颈)。
您可以查看Sequence 对象上的docs,它包含一个实现示例。最终,您的代码将是这样的(这是样板代码,您必须添加诸如 label2int 函数或图像预处理逻辑之类的细节):
from keras.utils import Sequence
class DataSequence(Sequence):
"""
Keras Sequence object to train a model on larger-than-memory data.
"""
def __init__(self, df, batch_size, mode='train'):
self.df = df # your pandas dataframe
self.bsz = batch_size # batch size
self.mode = mode # shuffle when in train mode
# Take labels and a list of image locations in memory
self.labels = self.df['label'].values
self.im_list = self.df['image_name'].tolist()
def __len__(self):
# compute number of batches to yield
return int(math.ceil(len(self.df) / float(self.bsz)))
def on_epoch_end(self):
# Shuffles indexes after each epoch if in training mode
self.indexes = range(len(self.im_list))
if self.mode == 'train':
self.indexes = random.sample(self.indexes, k=len(self.indexes))
def get_batch_labels(self, idx):
# Fetch a batch of labels
return self.labels[idx * self.bsz: (idx + 1) * self.bsz]
def get_batch_features(self, idx):
# Fetch a batch of inputs
return np.array([imread(im) for im in self.im_list[idx * self.bsz: (1 + idx) * self.bsz]])
def __getitem__(self, idx):
batch_x = self.get_batch_features(idx)
batch_y = self.get_batch_labels(idx)
return batch_x, batch_y
您可以像自定义生成器一样传递此对象来训练您的模型:
sequence = DataSequence(dataframe, batch_size)
model.fit_generator(sequence, epochs=1, use_multiprocessing=True)
如下所述,不需要实现洗牌逻辑。在fit_generator() 调用中将shuffle 参数设置为True 就足够了。来自docs:
随机播放:布尔值。是否打乱批次的顺序
每个时代的开始。仅与 Sequence 的实例一起使用
(keras.utils.Sequence)。当 steps_per_epoch 不是时无效
没有。