【发布时间】:2022-12-25 19:01:35
【问题描述】:
:)
我有一个包含 70 种鸟类的 ~16,000 .wav 记录的数据集。 我正在使用 tensorflow 训练一个模型,以使用基于卷积的架构对这些录音的梅尔频谱图进行分类。
使用的架构之一是下面描述的简单多层卷积。 预处理阶段包括:
- 提取梅尔频谱图并转换为 dB 标度
- 将音频分段为 1 秒段(如果残差超过 250 毫秒,则用零或高斯噪声填充,否则丢弃)
- 训练数据的 z-score 归一化 - 减少均值并将结果除以标准差
推理时的预处理:
- 同上
- z-score normalization BY training data - 减少均值(训练)并将结果除以 std(训练数据)
我知道输出层的 sigmoid 激活概率不会累积到 1,但我得到很多 (8-10) 个非常高的预测 (~0.999) 概率。有些恰好是 0.5。
目前的测试集正确分类率是~84%,用10折交叉验证测试,所以看起来网络大部分运行良好。
笔记: 1.我知道不同鸟类的发声有相似的特征,但接收到的概率似乎没有正确反映它们 2. 例如概率 - 自然噪音的记录: 自然噪音:0.999 绿头鸭 - 0.981
我试图了解这些结果的原因,如果它与数据等广泛的错误标记(可能不是)或来自其他来源有关。
任何帮助都感激不尽! :)
编辑:我使用 sigmoid 是因为所有类的概率都是必要的,我不需要它们累加到 1。
def convnet1(input_shape, numClasses, activation='softmax'): # Define the network model = tf.keras.Sequential() model.add(InputLayer(input_shape=input_shape)) # model.add(Augmentations1(p=0.5, freq_type='mel', max_aug=2)) model.add(Conv2D(64, (3, 3), activation='relu', padding='same')) model.add(BatchNormalization()) model.add(MaxPooling2D(pool_size=(2, 1))) model.add(Conv2D(128, (3, 3), activation='relu', padding='same')) model.add(BatchNormalization()) model.add(MaxPooling2D(pool_size=(2, 1))) model.add(Conv2D(128, (5, 5), activation='relu', padding='same')) model.add(BatchNormalization()) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Conv2D(256, (5, 5), activation='relu', padding='same')) model.add(BatchNormalization()) model.add(Flatten()) # model.add(Dense(numClasses, activation='relu')) model.add(Dropout(0.2)) model.add(Dense(numClasses, activation='sigmoid')) model.compile( loss='categorical_crossentropy', metrics=['accuracy'], optimizer=optimizers.Adam(learning_rate=0.001), run_eagerly=False) # this parameter allows to debug and use regular functions inside layers: print(), save() etc.. return model
【问题讨论】:
-
输出端的激活应该是 softmax,而不是 sigmoid。
-
@Dr.Snoopy 谢谢,我使用 sigmoid 激活,因为除了分类之外,我还需要其他类别的概率来理解相似之处。
-
不,这不是你使用的损失期望 softmax 输出的工作原理。
-
@Dr.Snoopy 那么,如果我想对所有其他类别进行相似性估计,我应该使用哪种损失?
-
据我了解,Softmax 提供了这些概率。它们加起来为 1。所以如果你想调查相似性,你应该使用它。
标签: python tensorflow keras deep-learning sigmoid