【问题标题】:assign values to an array in a loop in tensorflow在tensorflow的循环中为数组赋值
【发布时间】:2017-08-04 08:54:24
【问题描述】:

我在 tensorflow 中有一个数组,我想根据 for 循环中的另一个数组更新它的值。代码如下:

def get_weights(labels, class_ratio=0.5):
        weights = tf.ones_like(labels, dtype=tf.float64))

        pos_num = class_ratio * 100
        neg_num = 100 - class_ratio * 100
        for i in range(labels.shape[0]):
            if labels[i] == 0:
               weights[i].assign(pos_num/neg_num)
            else:
               weights[i].assign(neg_num)

        return weights

然后我有这段代码来调用上面的函数:

with tf.Graph().as_default():
     labels = tf.placeholder(tf.int32, (5,))
     example_weights = get_weights(labels, class_ratio=0.1)
     with tf.Session() as sess:
          np_labels = np.random.randint(0, 2, 5)
          np_weights = sess.run(example_weights, feed_dict={labels: np_labels})
         print("Labels:  %r" % (np_labels,))
         print("Weights: %r" % (np_weights,))

但是当我运行它时,它给了我这个错误:

ValueError: Sliced assignment is only supported for variables

如何在 tensorflow 中分配/更新数组的值?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    TensorFlow 中的tf.Tensor 是一个只读值——实际上,它是一个用于计算只读值的符号表达式——因此您通常不能为其赋值。 (主要的例外是tf.Variable 对象。)这意味着我们鼓励您使用“函数式”操作来定义您的张量。例如,有几种方法可以函数式地生成 weights 张量:

    • 由于weights 被定义为labels 的元素转换,您可以使用tf.map_fn() 通过映射一个函数来创建一个新张量(包含tf.cond()replace the if statement):

      def get_weights(labels, class_ratio=0.5):
        pos_num = tf.constant(class_ratio * 100)
        neg_num = tf.constant(100 - class_ratio * 100)
      
        def compute_weight(x):
          return tf.cond(tf.equal(x, 0), lambda: pos_num / neg_num, lambda: neg_num)
      
        return tf.map_fn(compute_weight, labels, dtype=tf.float32)
      

      这个版本允许你对labels的每个元素应用任意复杂的函数。

    • 但是,由于该函数简单、计算成本低,并且可以使用简单的 TensorFlow 操作来表示,因此您可以避免使用 tf.map_fn(),而是使用 tf.where()

      def get_weights(labels, class_ratio=0.5):
        pos_num = tf.fill(tf.shape(labels), class_ratio * 100)
        neg_num = tf.fill(tf.shape(labels), 100 - class_ratio * 100)
        return tf.where(tf.equal(labels, 0), pos_num / neg_num, neg_num)
      

      (您也可以在tf.map_fn() 版本中使用tf.where() 代替tf.cond()。)

    【讨论】:

    • 非常感谢您的详细解释!效果很好!
    猜你喜欢
    • 2016-01-17
    • 2017-07-08
    • 1970-01-01
    • 2020-08-15
    • 2017-03-21
    • 2019-06-27
    • 2015-03-01
    • 1970-01-01
    • 2012-11-19
    相关资源
    最近更新 更多