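【Posted】: 2020-08-26 12:02:09
【Question】:
I would like to know how to interpret the following model summary output from the keras library. The results below come from keras version 2.3.1.
In keras, we can set a layer's trainable attribute so that its weights do not change during training.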
from keras.models import Sequential
from keras.layers import Dense
model = Sequential([
    Dense(5, input_dim=3), Dense(1)
])
model.summary()
print("***")
model.layers[0].trainable = False
model.summary()
Model: "sequential_36"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_101 (Dense) (None, 5) 20
_________________________________________________________________
dense_102 (Dense) (None, 1) 6
=================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
_________________________________________________________________
***
Model: "sequential_36"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_101 (Dense) (None, 5) 20
_________________________________________________________________
dense_102 (Dense) (None, 1) 6
=================================================================
Total params: 26
Trainable params: 6
Non-trainable params: 20
The result above is intuitive: because I set the first layer to non-trainable, there are fewer trainable parameters.
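The per-layer counts in the summaries follow from the usual fully connected layer arithmetic: one weight per input/unit pair plus one bias per unit. A minimal sketch, where `dense_params` is a hypothetical helper written for illustration and not part of keras:

```python
def dense_params(n_inputs, n_units, use_bias=True):
    """Number of weights in a fully connected layer (hypothetical helper)."""
    return n_inputs * n_units + (n_units if use_bias else 0)

first = dense_params(3, 5)   # Dense(5, input_dim=3)
second = dense_params(5, 1)  # Dense(1), fed by the 5 outputs above
print(first, second, first + second)  # 20 6 26
```

Freezing the first layer moves its 20 parameters from the trainable column to the non-trainable one, leaving 6 trainable and a total of 26.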
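If I compile the model before changing the attribute (which is not standard practice, but can happen in some applications), I get the following result.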
model = Sequential([
    Dense(5, input_dim=3), Dense(1)
])
model.compile(loss="mse", optimizer="adam")
model.summary()
print("***")
model.layers[0].trainable = False
model.summary()
Model: "sequential_38"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_105 (Dense) (None, 5) 20
_________________________________________________________________
dense_106 (Dense) (None, 1) 6
=================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
_________________________________________________________________
***
Model: "sequential_38"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_105 (Dense) (None, 5) 20
_________________________________________________________________
dense_106 (Dense) (None, 1) 6
=================================================================
Total params: 46
Trainable params: 26
Non-trainable params: 20
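This reports more parameters than before: the total rises from 26 to 46 even though the layers themselves are unchanged. Can someone clarify how these numbers should be interpreted?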
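One possible reading of the second summary, offered as an assumption rather than confirmed library behavior: the trainable count may still reflect the weights collected when `compile()` was called, while the non-trainable count reflects the current layer state, so the frozen layer would be counted twice. The arithmetic below only checks that this reading matches the printed numbers; it does not call keras.

```python
# Numbers printed by the second summary after compile().
total_reported = 46
trainable_reported = 26
non_trainable_reported = 20

# Actual weights in the model: Dense(5) on 3 inputs, then Dense(1).
layer_params = {"dense_105": 3 * 5 + 5, "dense_106": 5 * 1 + 1}
actual_total = sum(layer_params.values())

# Assumption: the trainable count is a snapshot taken at compile time,
# when all 26 weights were trainable; the non-trainable count is current.
stale_trainable = actual_total
current_non_trainable = layer_params["dense_105"]  # frozen after compile()

assert stale_trainable == trainable_reported
assert current_non_trainable == non_trainable_reported
assert stale_trainable + current_non_trainable == total_reported

overcount = total_reported - actual_total
print(overcount)  # 20, the size of the frozen layer
```

Under this reading the model still holds 26 weights, and the extra 20 in the total is a reporting artifact rather than new parameters.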
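[Edit]
Judging from the answers received, this appears to be a buggy feature whose behavior depends on the package version. Here is another example, this time using the tensorflow keras API. Unlike in @lukasz-tracewski's answer, I still get the same parameter counts as above, along with a different warning message. Perhaps the versions differ slightly?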
import tensorflow as tf
print("tensorflow version is", tf.__version__)
print("keras version is", tf.keras.__version__)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
model = Sequential([
    Dense(5, input_dim=3), Dense(1)
])
model.compile(loss="mse", optimizer="adam")
model.summary()
print("***")
model.layers[0].trainable = False
model.summary()
tensorflow version is 2.1.0
keras version is 2.2.4-tf
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 5) 20
_________________________________________________________________
dense_1 (Dense) (None, 1) 6
=================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
_________________________________________________________________
***
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 5) 20
_________________________________________________________________
dense_1 (Dense) (None, 1) 6
=================================================================
WARNING:tensorflow:Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
Total params: 46
Trainable params: 26
Non-trainable params: 20
【Comments】:
Tags: python tensorflow keras