【发布时间】:2019-12-20 17:52:04
【问题描述】:
我正在尝试使用 keras 拟合多输入模型。数据由数值和分类组成,所以我定义了两个输入分支,分类使用实体嵌入,数值使用 1D CNN,这是我的数据:
dim(Y_train) = c(1000,1)
dim(X_categorical) = c(1000,20)
dim(X_numerical) = c(1000, 21, 50)
这是我的模型:
library(data.table)
library(keras)
unique_cat_num <- vector()
for(cat in ctgrc_cols){
unique_cat_num = append(unique_cat_num,length(unique(myDT[,get(cat)])))
}
embed_input <- list()
for(cat in ctgrc_vars){
embed_input = append(embed_input,layer_input(shape = 1, name = paste0("input_",cat)))
}
embed_dim_vec <- vector()
for(i in 1:length(ctgrc_vars)){
embed_dim_vec = append(embed_dim_vec,min((unique_cat_len[i] + 1) %/% 2, 50))
}
embed_out <- list()
for(i in 1:length(ctgrc_vars)){
layout = layer_embedding(object = embed_input[[i]], input_dim = unique_cat_len[i] + 1, output_dim = embed_dim_vec[i], name = paste0("embedding_",ctgrc_vars[i]))
layout = layer_flatten(layout)
embed_out = append(embed_out, layout)
}
Conv_input <- layer_input(shape = c(21, 50))
conv_flat <- Conv_input %>%
layer_conv_1d(filters = 64, kernel_size = 3, activation = "relu") %>%
layer_conv_1d(filters = 64, kernel_size = 3, activation = "relu") %>%
layer_max_pooling_1d() %>%
layer_conv_1d(filters = 32, kernel_size = 3, activation = "relu") %>%
layer_conv_1d(filters = 32, kernel_size = 3, activation = "relu") %>%
layer_global_max_pooling_1d()%>%
layer_dropout(0.3) %>%
layer_dense(units = 32, activation = "relu") %>%
layer_flatten()
output <- layer_concatenate(list(embed_out, conv_flat)) %>%
layer_dense(units = 32, activation = "relu") %>%
layer_dropout(0.3) %>%
layer_dense(units = out_neuron, activation = "sigmoid")
model <- keras_model(inputs = list(embed_input, Conv_input), outputs = output)
当我尝试尝试该模型时,我收到此值错误:
ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [[(None, 14), (None, 2), (None, 6), (None, 3), (None, 2), (None, 2), (None, 17), (None, 4), (None, 5), (None, 5), (None, 5), (None, 5), (None, 5), (None, 5), (None, 5), (None, 5),
Called from: py_call_impl(callable, dots$args, dots$keywords)
有什么建议吗?
【问题讨论】:
标签: r python-3.x keras