【问题标题】:visualise filters in keras cnn在 keras cnn 中可视化过滤器
【发布时间】:2018-04-11 12:44:04
【问题描述】:
def build_model(network):
    model = Sequential()
    model.add(Conv2D(6, (5,5), padding='valid', activation = 'relu', kernel_initializer='he_normal', input_shape=(32,32,3)))
    print(np.asarray(model.get_weights())[0].shape)
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))
    model.add(Conv2D(16, (5,5), padding='valid', activation = 'relu', kernel_initializer='he_normal'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))

    model.add(Flatten())
    model.add(Dense(120, activation = 'relu', kernel_initializer='he_normal'))
    model.add(Dense(84, activation = 'relu', kernel_initializer='he_normal'))
    model.add(Dense(10, activation = 'softmax', kernel_initializer='he_normal'))
    sgd = optimizers.SGD(lr=learning_rate)
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return model

这是 Keras 中的模型。我想在第一个 conv 层中可视化过滤器。我想绘制过滤器本身,而不是当我们将梯度一直反向传播到图像时出现的模式。

我找到了一种获取权重的方法 - 使用 model.get_weights()

如何绘制这些权重? np.asarray(model.get_weights())[0] 的形状是 (5,5,3,6)。

如何用它制作六个尺寸为 5x5x3 的滤镜?

【问题讨论】:

  • 你打算如何可视化 3 维数组?
  • 也许是 matplotlib ?
  • 那么你在纠结什么?
  • 形状看起来有点不同。 (5,5,3,6)。我不知道我该如何处理。
  • 我不确定,但如果是 (6,5,5,3),那么我可以将图像拼接在一起。

标签: python tensorflow keras


【解决方案1】:

使用 matplotlib 可以绘制切片。

你应该自己将切片分开,例如:

import matplotlib.pyplot as plt

#normalize these filters first, otherwise they won't be in a suitable range for plotting:

maxVal = filters.max()
minVal = filters.min()
absMax = max(abs(minVal),abs(maxVal))

filters = (filters / absMax)*255


for outputChannel in range(6):
    for inputChannel in range(3):
        filt = filters[:,:,inputChannel,outputChannel]

      #a trick to see negatives as blue and positives as red
        imageRed = np.array(filt)
        imageBlue = np.array(filt)
        imageRed[imageRed<0] = 0
        imageBlue[imageBlue>0]= 0

        print(imageRed)
        print(imageBlue)

        finalImage = np.zeros((filt.shape[0],filt.shape[1],3))
        finalImage[:,:,0] = imageRed
        finalImage[:,:,2] = -imageBlue

        #plot image here
        plt.figure()  
        plt.imshow(finalImage)          

【讨论】:

  • filt = filters[:,:,inputChannel,outputChannel] 我猜 outputChannel 应该是 outputFilter ?
  • 得到这个错误:filters = int((filters / absMax) * 255) TypeError: only length-1 arrays can be converted to Python scalars
  • 当然:),对不起。
  • 刚刚删除了“int”部分。也许您需要finalImage.astype(np.uint8) 进行绘图。
  • 我更正了imageBlue 中的错误符号。
猜你喜欢
  • 1970-01-01
  • 2019-07-24
  • 2017-11-08
  • 2018-08-03
  • 2023-03-26
  • 1970-01-01
  • 2023-04-08
  • 2018-02-14
  • 1970-01-01
相关资源
最近更新 更多