【问题标题】:How to get input_shape of each layer in a loop?如何在循环中获取每一层的 input_shape?
【发布时间】:2021-07-29 11:58:49
【问题描述】:

我有以下代码尝试使用 ResNET32 模型对 cifar10 数据集进行预测。但是,我正在检索一个错误。

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype('float32') 
X_test = X_test.astype('float32') 
X_train = X_train / 255.0 
X_test = X_test / 255.0
y_train = tf.keras.utils.to_categorical(y_train) 
y_test = tf.keras.utils.to_categorical(y_test) 
num_classes = y_test.shape[1]
#Prediction on a single image  
test_image1 =image.load_img('cat.jpg',target_size =(32,32))
test_image =image.img_to_array(test_image1) 
test_image =np.expand_dims(test_image, axis =0) 
def profiler(model, test_input):
    data_input = test_input
    for layer in model.layers:
        im_imput = Input( batch_shape=model.get_layer( layer.name ).get_input_shape_at( 0 ) )#error thrown on this line
        im_out = layer( im_imput )
        new_model = tf.keras.models.Model(inputs=im_imput, outputs=im_out )
        total_time = 0
        for i in range(averaging_steps):
          start = time.time()
          out = new_model.predict(data_input)
          end = time.time() - start
          milliseconds = end * 1000
          total_time += milliseconds
        avg_time = total_time / averaging_steps
        data_input = out
times = profiler(model,test_image)

错误跟踪:

---> 29 times = profiler(model,test_image)
8 frames
/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)
TypeError: Dimension value must be integer or None or have an __index__ method, got value '(None, 32, 32, 16)' with type '<class 'tuple'>'

【问题讨论】:

  • 您应该包含完整的回溯,而不是其中的一部分,因为没有完整的回溯就无法解释错误。

标签: machine-learning keras neural-network conv-neural-network


【解决方案1】:

print(new_model.summary()) 将打印输入形状以及其他详细信息。

如果需要每一层的输入形状,可以使用

for layer in model.layers:
    print(layer.output_shape)
    print(layer.input_shape)

【讨论】:

  • 感谢您的回复。我需要将变量中的 input_shape 存储在变量中。
  • 编辑了我的答案@ashgoharam
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2021-10-08
  • 1970-01-01
  • 1970-01-01
  • 2011-12-04
  • 1970-01-01
  • 2019-12-04
  • 2021-12-25
相关资源
最近更新 更多