【发布时间】:2019-11-20 03:17:03
【问题描述】:
我正在尝试使用 Tensorflow 2 的 LambdaCallback 在 Tensorboard 中的每个时期显示所有网络权重 (CNN) 的直方图,如下所示:
def log_hist_weights(model,writer):
model = model
writer = writer
def log_hist_weights(epoch, logs):
# predict images
Ws = model.get_weights()
with writer.as_default():
tf.summary.histogram("epoch: " + str(epoch), Ws)
return log_hist_weights
hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
但问题是 get_weights() 返回所有网络权重而没有任何名称(例如过滤器权重,
BatchNormalization 权重和其他东西)但我实际上只对 CNN 过滤器权重感兴趣。
如果我能在 Tensorflow 2 中实现 this one 之类的东西,那就太好了。
如何使用 Tensorflow 显示过滤器权重的直方图?
【问题讨论】:
标签: python tensorflow classification conv-neural-network tensorboard