【发布时间】:2019-03-12 21:36:35
【问题描述】:
我正在尝试提供从 2 个经过微调的 VGG16(每个都在不同的流上)提取的特征,然后对于 9 个数据对的每个序列,连接它们的 numpy 数组并将 9 个输出序列(连接)提供给Keras 中的双向 LSTM。
问题是我在尝试构建 LSTM 部分时遇到了错误。下面显示了我编写的生成器,用于读取 RGB 和光流,提取特征并连接每一对:
def generate_generator_multiple(generator,dir1, dir2, batch_rgb, batch_flow, img_height,img_width):
print("Processing inside generate multiple")
genX1 = generator.flow_from_directory(dir1,
target_size = (img_height,img_width),
class_mode = 'categorical',
batch_size = batch_rgb,
shuffle=False
)
genX2 = generator.flow_from_directory(dir2,
target_size = (img_height,img_width),
class_mode = 'categorical',
batch_size = batch_flow,
shuffle=False
)
while True:
imgs, labels = next(genX1)
X1i = RGB_model.predict(imgs, verbose=0)
imgs2, labels2 = next(genX2)
X2i = FLOW_model.predict(imgs2,verbose=0)
Xi = []
for i in range(9):
Xi.append(np.concatenate([X1i[i+1],X2i[i]]))
Xi = np.asarray(Xi)
if np.array_equal(labels[1:],labels2)==False:
print("ERROR !! problem of labels matching: RGB and FLOW have different labels")
yield Xi, labels2[2]
我希望生成器产生一个由 9 个数组组成的序列,所以当我强制循环运行两次时 Xi 的形状是:(9, 14, 7, 512)
当我使用 while True(如上面的代码)并尝试调用该方法来检查它返回的内容时,在 3 次迭代后我得到了错误:
ValueError: too many values to unpack (expected 2)
现在,假设生成器没有问题,我尝试将生成器返回的数据提供给双向 LSTM,如下所示:
n_frames = 9
seq = 100
Bi_LSTM = Sequential()
Bi_LSTM.add(Bidirectional(LSTM(seq, return_sequences=True, dropout=0.25, recurrent_dropout=0.1),input_shape=(n_frames,14,7,512)))
Bi_LSTM.add(GlobalMaxPool1D())
Bi_LSTM.add(TimeDistributed(Dense(100, activation="relu")))
Bi_LSTM.add(layers.Dropout(0.25))
Bi_LSTM.add(Dense(4, activation="relu"))
model.compile(Adam(lr=.00001), loss='categorical_crossentropy', metrics=['accuracy'])
但我不断收到以下错误:(错误日志有点长)
InvalidArgumentError: Shape must be rank 4 but is rank 2 for 'bidirectional_2/Tile_1' (op: 'Tile') with input shapes: [?,7,512,1], [2].
好像是这行引起的:
Bi_LSTM.add(Bidirectional(LSTM(seq, return_sequences=True, dropout=0.25, recurrent_dropout=0.1),input_shape=(n_frames,14,7,512)))
我不确定问题是我尝试构建 LSTM 的方式、我从生成器返回数据的方式,还是我定义 LSTM 输入的方式。
非常感谢您提供的任何帮助。
【问题讨论】: