【问题标题】:tf.keras replace lower layer in pretrained resnet50tf.keras 替换预训练 resnet50 中的较低层
【发布时间】:2020-05-14 00:48:28
【问题描述】:

是否可以删除/替换 tf.keras.applications 中预训练的 ResNet50 模型的 BOTTOM 层?

例如,我尝试过这样做:

import tensorflow as tf
pretrained_resnet = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
inputs = tf.keras.Input(shape=(256,256,1))
x = tf.keras.layers.ZeroPadding2D()(inputs)
x = tf.keras.layers.Conv2D(filters=64,
                           kernel_size=(7,7),
                           strides=(2,2),
                           padding='same')(x)
outputs = pretrained_resnet.layers[3](x)
test = tf.keras.Model(inputs, pretrained_resnet.output)

但它给出了这个错误:ValueError: Graph disconnected: cannot get value for tensor Tensor("input_2:0", .......

我也尝试过使用 tf.keras Sequential API,但这不起作用,因为 ResNet 不是顺序模型。我基本上只是想用一个新的替换 ResNet50 中的第一个 Conv2D 层。这可能吗?还是我必须重写整个 ResNet 模型?

任何建议将不胜感激!

【问题讨论】:

  • ZeroPadding2DConv2D (7*7, 64, stride 2)2nd3rd 层的 Resnet50 网络。您能否确认一下,您是否只想替换第一层,即输入层?如果是,在答案部分,我已经提供了解决方案。谢谢!

标签: tensorflow keras tensorflow2.0 tf.keras


【解决方案1】:

ZeroPadding2DConv2D (7*7, 64, stride 2)Resnet50 网络的 2nd3rd 层。

因此,此处显示仅替换Resnet50 中的第一层(即输入层)

from tensorflow.keras.applications import ResNet50
import tensorflow as tf

model = ResNet50(include_top = False, weights = 'imagenet')
model.save('model.h5')

res50_model = tf.keras.models.load_model('model.h5')
#res50_model.summary()

要从网络中删除第一层,您可以运行如下代码

 res50_model._layers.pop(0)

Resnet50 expects the input must have 3 channels,因此将输入层形状添加为(256,256,3) 而不是(256,256,1)

要添加新的输入层,您可以运行如下代码

newInput = tf.keras.Input(shape=(256,256,3))
newOutputs = res50_model(newInput)
newModel = tf.keras.Model(newInput, newOutputs)
newModel.summary()

输出:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
resnet50 (Model)             multiple                  23587712  
=================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
_________________________________________________________________

【讨论】:

  • 请问您为什么在更换图层之前保存和加载模型?
猜你喜欢
  • 2018-12-21
  • 1970-01-01
  • 1970-01-01
  • 2021-05-20
  • 1970-01-01
  • 1970-01-01
  • 2019-07-04
  • 2021-08-04
  • 2021-02-11
相关资源
最近更新 更多