【问题标题】:Balance Dataset for Tensorflow Object Detection用于 TensorFlow 对象检测的平衡数据集
【发布时间】:2018-11-12 12:03:47
【问题描述】:

我目前想使用 Tensorflows 对象检测 API 来解决我的自定义问题。 我已经创建了数据集,但它非常不平衡。 数据集有 3 个类,我的主要问题是,一个类有大约 16k 个样本,而另一个类只有大约 2.5k 个样本。

所以我认为我必须平衡数据集。有人告诉我,有一种叫做样本/类权重的东西(不确定这是否 100% 正确),它可以平衡训练样本,因此最大的类对训练的影响比最小的类要小。

我找不到这种平衡方法。有人可以给我一个提示从哪里开始吗?

谢谢!

【问题讨论】:

标签: tensorflow dataset object-detection


【解决方案1】:

你可以做正常的交叉熵,给你一个? x 1 张量,X 个损失

如果你想让班级数 N 多计算 T 次,你可以这样做

X = X * tf.reduce_sum(tf.multiply(one_hot_label, class_weight), axis = 1)

tf.multiply

按你想要的任何重量缩放标签,

tf.reduce_sum

将标签向量 a 转换为标量,所以你最终得到 a ? x 1 张量填充了类权重。然后,您只需将损失的张量乘以权重的张量即可获得所需的结果。

由于一个类的常见度是另一个类的 6.4 倍,因此我会将权重 1 和 6.4 分别应用于更常见和不太常见的类。这意味着每次出现较不常见的类时,它的影响是较常见的类的 6.4 倍,所以就像从每个类中看到相同数量的样本一样。

您可能需要对其进行修改,以使权重加起来等于类的数量。这匹配默认情况是所有权重都是 1。在这种情况下,我们有 1 /7.4 和 6.4/7.4

【讨论】:

  • 感谢您的回答!据我了解,那是用于编辑网络源文件。你知道在哪里编辑对象检测 API 吗?或者是否有其他方法可以为对象检测 api 执行此操作?
  • 抱歉,不熟悉那个 API。但是,如果您不想在内部进行太多操作,您也可以构建/修改数据集,以便对频率较低的类进行过采样。
猜你喜欢
  • 2019-01-25
  • 2020-06-21
  • 2019-01-13
  • 2019-07-28
  • 1970-01-01
  • 2016-01-06
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多