【问题标题】:CoreML custom layer: Pixelwise Normalization with Metal ShadersCoreML 自定义层:使用金属着色器的像素标准化
【发布时间】:2019-02-25 08:20:36
【问题描述】:

我正在将 Nvidia 的渐进式 GAN 生成器转换为 coreML。我已经设法将所有内容都转移到了 coreML,但像素标准化 (Lambda) 层除外,我计划在 Swift/Metal 中将其实现为自定义 coreML 层。

在 TensorFlow.Keras 中,我将像素范数实现为

def pixelwise_norm(a):
    return a / tf.sqrt(tf.reduce_mean(a * a, axis=3, keep_dims=True) + 1e-8)

现在,我几乎没有使用过着色器/Metal,但按照此处的说明:http://machinethink.net/blog/coreml-custom-layers/,我设置了一个自定义层以使用 Metal 进行前馈操作。我正在使用 MTLComputePipelineState (调用?编码?)图层操作的以下着色器:

#include <metal_stdlib>
using namespace metal;


kernel void pixelwise_norm(
              texture2d_array<half, access::read> inTexture [[texture(0)]],
              texture2d_array<half, access::write> outTexture [[texture(1)]],
              ushort3 gid [[thread_position_in_grid]])
{
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height()) {
        return;
    }

    const float4 x = float4(inTexture.read(gid.xy, gid.z));
    const float4 y = 0.0000001f + (x / sqrt(pow(x,2)));
    outTexture.write(half4(y), gid.xy, gid.z);
}

我无法确定“reduce_mean”的金属等价物,现在这个着色器实现了一个 ~tensorflow ~ 操作,如

return a / tf.sqrt((a * a) + 1e-8) 

有没有人指点? 谢谢

【问题讨论】:

    标签: c++ tensorflow keras metal coreml


    【解决方案1】:

    如果我没看错的话,对于特征图中的每个像素,这会将该像素除以该像素通道上的 L2 范数吗?

    在这种情况下,您需要使用 for 循环来读取该像素的通道,将这些数字相加,然后除以通道数。 (如果通道数大于4,则只需要执行此循环。)

    另请注意,您的 1e-8 需要在 sqrt() 内或至少在分母内。

    【讨论】:

    • 是的,就是这个想法,根据层在网络中的位置,将有 3 到 512 个通道。正如现在所写,x 指的是(xy 中的单个像素/位置,以及给定 xy 的通道 z),对吗?还是它指的是由 z 索引的特定 xy 平面?对于基本问题,抱歉,这远远超出了我通常的工作领域
    • 在您提供的代码中,x 包含 4 个通道的数据,因为您正在从单个纹理切片(特别是 gid.z 的切片)中读取数据,这为您提供了float4。因此,不是将像素除以所有通道上的范数,而是将每个像素除以这个纹理切片中的 4 个通道上的范数。
    • 知道了,由于金属的工作原理,我无法将 x 作为所有 512 个通道(至少使用 gid.xyzwabc... 格式)。相反,我必须对 inTexture 做一个循环?
    • 是的,这是正确的。因为像素 x 将有 512 个通道,所以您需要从所有 512 / 4 = 128 个纹理切片中读取该像素的值,将它们平方并求和。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-06-29
    • 1970-01-01
    • 1970-01-01
    • 2017-05-03
    • 2023-02-08
    • 1970-01-01
    相关资源
    最近更新 更多