【问题标题】:Can Numba be used with Tensorflow?Numba 可以与 Tensorflow 一起使用吗?
【发布时间】:2017-04-29 04:40:36
【问题描述】:

Numba 可以用来编译与 Tensorflow 接口的 Python 代码吗? IE。 Tensorflow 宇宙之外的计算将使用 Numba 运行以提高速度。我还没有找到有关如何执行此操作的任何资源。

【问题讨论】:

标签: python numpy tensorflow numba


【解决方案1】:

我知道这并不能直接回答您的问题,但它可能是一个不错的选择。 Numba 正在使用即时 (JIT) 编译。因此,您可以按照 TensorFlow 官方文档 here 中的说明,了解如何在 TensorFlow 中使用 JIT(但不在 Numba 生态系统中)。

【讨论】:

    【解决方案2】:

    您可以使用tf.numpy_functiontf.py_func 包装python 函数并将其用作TensorFlow 操作。这是我使用的一个示例:

    @jit
    def dice_coeff_nb(y_true, y_pred):
        "Calculates dice coefficient"
        smooth = np.float32(1)
        y_true_f = np.reshape(y_true, [-1])
        y_pred_f = np.reshape(y_pred, [-1])
        intersection = np.sum(y_true_f * y_pred_f)
        score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                                np.sum(y_pred_f) + smooth)
        return score
    
    @jit
    def dice_loss_nb(y_true, y_pred):
        "Calculates dice loss"
        loss = 1 - dice_coeff_nb(y_true, y_pred)
        return loss
    
    def bce_dice_loss_nb(y_true, y_pred):
        "Adds dice_loss to categorical_crossentropy"
        loss =  tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \
                tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        return loss
    

    然后我在训练一个 tf.keras 模型时使用了这个损失函数:

    ...
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss=bce_dice_loss_nb)
    

    【讨论】:

      猜你喜欢
      • 2017-08-27
      • 1970-01-01
      • 2016-07-15
      • 2023-04-08
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2011-08-15
      • 2018-02-16
      相关资源
      最近更新 更多