【发布时间】:2022-12-30 02:03:19
【问题描述】:
我正在尝试加载在 tensorflow 2.9.1 中实现的模型的权重,但失败了
我使用model.save_weights("./saved_model/model")保存了模型
并使用加载模型
model = DepthEstimationModel()
model.load_weights(os.path.join("saved_model", "model"))
model.compile(optimizer, loss=cross_entropy)
model.built = True
model.summary()
这表明
Model: "depth_estimation_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
downscale_block (DownscaleB multiple 0 (unused)
lock)
downscale_block_1 (Downscal multiple 0 (unused)
eBlock)
downscale_block_2 (Downscal multiple 0 (unused)
eBlock)
downscale_block_3 (Downscal multiple 0 (unused)
eBlock)
bottle_neck_block (BottleNe multiple 0 (unused)
ckBlock)
upscale_block (UpscaleBlock multiple 0 (unused)
)
upscale_block_1 (UpscaleBlo multiple 0 (unused)
ck)
upscale_block_2 (UpscaleBlo multiple 0 (unused)
ck)
upscale_block_3 (UpscaleBlo multiple 0 (unused)
ck)
conv2d_18 (Conv2D) multiple 0 (unused)
=================================================================
Total params: 2
Trainable params: 0
Non-trainable params: 2
这表明参数未正确加载。
如果删除model.built = True,它会输出:
This model has not yet been built. Build the model first by calling `build()` or by calling the model on a batch of data.
参考:Tensorflow 2.0 ValueError while Loading weights from .h5 file
【问题讨论】:
-
您是否尝试过以与保存权重相同的方式使用路径加载权重?此外,您是否尝试过运行 model.build() 或使用评估数据集评估模型(编译后)?
-
@LucaKnaack 使用
model.evaluate()确实有效,谢谢
标签: python tensorflow