【问题标题】:Map RGB Semantic Maps to One Hot Encodings and vice versa in TensorFlow在 TensorFlow 中将 RGB 语义映射映射到一个热编码,反之亦然
【发布时间】:2017-10-24 06:47:50
【问题描述】:

下图是来自Cityscapes Dataset 的示例语义图。它以 RGB 图像的形式提供,其中每种特定颜色代表一个类。

在一些深度学习任务中,我们希望将其映射为 one hot 编码。例如,如果它有 20 个类,则该图像将从 H x W x 3 映射到 H x W x 20

我们如何在 TensorFlow 中做到这一点?

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    我的解决方案如下。期待有关如何提高效率的建议,或者更有效的答案。

    import tensorflow as tf
    import numpy as np
    import scipy.misc
    
    img = scipy.misc.imread('aachen_000000_000019_gtFine_color.png', mode = 'RGB')
    palette = np.array(
    [[128,  64, 128],
     [244,  35, 232],
     [ 70,  70,  70],
     [102, 102, 156],
     [190, 153, 153],
     [153, 153, 153],
     [250, 170,  30],
     [220, 220,   0],
     [107, 142,  35],
     [152, 251, 152],
     [ 70, 130, 180],
     [220,  20,  60],
     [255,   0,   0],
     [  0,   0, 142],
     [  0,   0,  70],
     [  0,  60, 100],
     [  0,  80, 100],
     [  0,   0, 230],
     [119,  11,  32],
     [  0,   0,   0],
     [255, 255, 255]], np.uint8)
    
    semantic_map = []
    for colour in palette:
      class_map = tf.reduce_all(tf.equal(img, colour), axis=-1)
      semantic_map.append(class_map)
    semantic_map = tf.stack(semantic_map, axis=-1)
    # NOTE cast to tf.float32 because most neural networks operate in float32.
    semantic_map = tf.cast(semantic_map, tf.float32)
    magic_number = tf.reduce_sum(semantic_map)
    print semantic_map.shape
    
    palette = tf.constant(palette, dtype=tf.uint8)
    class_indexes = tf.argmax(semantic_map, axis=-1)
    # NOTE this operation flattens class_indexes
    class_indexes = tf.reshape(class_indexes, [-1])
    color_image = tf.gather(palette, class_indexes)
    color_image = tf.reshape(color_image, [1024, 2048, 3])
    
    sess = tf.Session()
    # NOTE magic_number checks that there are only 1024*2048 1s in the entire
    # 1024*2048*21 tensor.
    magic_number_val = sess.run(magic_number)
    assert magic_number_val == 1024*2048
    color_image_val = sess.run(color_image)
    scipy.misc.imsave('test.png', color_image_val)
    

    【讨论】:

    • 你可以使用 tf.gather 函数 (tensorflow.org/api_docs/python/tf/gather) 加速你的代码在 'color_image' : tf.gather(palette, class_indexes) 然后像你一样重塑。
    • 我已经进行了更改并对其进行了测试。它确实更快。感谢您的建议!
    • 你能像这样分享你正在使用的图像我可以测试其他东西吗? :)
    • 我不能根据许可共享它。您必须从 Cityscapes 下载它。
    • 那么你真正想做的是,从这张图片 [H, W, 3] 创建一个热矩阵 [H, W, 21],对吧?
    猜你喜欢
    • 2011-06-24
    • 2022-06-11
    • 2021-10-25
    • 2018-02-05
    • 2017-11-17
    • 1970-01-01
    • 1970-01-01
    • 2011-06-15
    相关资源
    最近更新 更多