【问题标题】:Tensorflow 2 displaying a histogram of weightsTensorflow 2 显示权重直方图
【发布时间】:2019-11-20 03:17:03
【问题描述】:

我正在尝试使用 Tensorflow 2 的 LambdaCallback 在 Tensorboard 中的每个时期显示所有网络权重 (CNN) 的直方图,如下所示:

def log_hist_weights(model,writer):
    model = model
    writer = writer
    
    def log_hist_weights(epoch, logs):
        # predict images
        Ws = model.get_weights()
        with writer.as_default():
            tf.summary.histogram("epoch: " + str(epoch), Ws)
    return log_hist_weights

hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))

但问题是 get_weights() 返回所有网络权重而没有任何名称(例如过滤器权重, BatchNormalization 权重和其他东西)但我实际上只对 CNN 过滤器权重感兴趣。

如果我能在 Tensorflow 2 中实现 this one 之类的东西,那就太好了。

如何使用 Tensorflow 显示过滤器权重的直方图?

【问题讨论】:

    标签: python tensorflow classification conv-neural-network tensorboard


    【解决方案1】:

    对于其他有同样问题的人,这是我最终使用 Tensorflow 2 解决的方法:

    def log_hist_weights(model,writer):
        model = model
        writer = writer
    
        def log_hist(epoch, logs):
            # predict images
            with writer.as_default():
                for tf_var in baseline_model.trainable_weights:
                        tf.summary.histogram(tf_var.name, tf_var.numpy(), step=epoch)
        return log_hist
    
        hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-01-06
      • 2019-03-30
      • 1970-01-01
      • 2015-08-02
      • 2017-07-07
      • 1970-01-01
      • 2016-08-29
      • 2015-10-12
      相关资源
      最近更新 更多