【问题标题】:Nested while loop in tensorflow张量流中的嵌套while循环
【发布时间】:2018-05-08 13:50:33
【问题描述】:

我正在尝试在 keras 中实现损失函数,例如以下伪代码

for i in range(N):
    for j in range(N):
        sum += some_calculations

但我读到张量流不支持这种 for 循环,因此我从 here 了解了 while_loop(cond, body, loop_vars) 函数

我在这里了解了 while 循环的基本工作原理,因此我实现了以下代码:

def body1(i):
    global data
    N = len(data)*positive_samples     //Some length
    j = tf.constant(0)    //iterators
    condition2 = lambda j, i :tf.less(j, N)   //one condition only j should be less than N
    tf.add(i, 1)   //increment previous index i
    result = 0

    def body2(j, i):
        global similarity_matrix, U, V
        result = (tf.transpose(U[:, i])*V[:, j])   //U and V are 2-d tensor Variables and here only a column is extracted and their final product is a single value
        return result

    tf.while_loop(condition2, body2, loop_vars=[j, i])
    return result


def loss_function(x):
    global data
    N = len(data)*positive_samples
    i = tf.constant(0)
    condition1 =  lambda i : tf.less(i, N)
    return tf.while_loop(condition1, body1, [i])

但是当我运行这段代码时,我得到了一个错误

ValueError: The two structures don't have the same number of elements. First structure: [<tf.Tensor 'lambda_1/while/while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'lambda_1/while/while/Identity_1:0' shape=() dtype=int32>], second structure: [0]

【问题讨论】:

    标签: tensorflow keras tensor


    【解决方案1】:

    tf.while_loop 可能很难使用,请务必仔细阅读文档。主体的返回值必须与循环变量具有相同的结构,tf.while_loop 操作的返回值是变量的最终值。为了进行计算,您应该传递一个额外的循环变量来存储部分结果。你可以这样做:

    def body1(i, result):
        global data
        N = len(data) * positive_samples
        j = tf.constant(0)
        condition2 = lambda j, i, result: tf.less(j, N)
        result = 0
    
        def body2(j, i, result):
            global similarity_matrix, U, V
            result_j = (tf.transpose(U[:, i]) * V[:, j])
            return j + 1, i, result + result_j
    
        j, i, result = tf.while_loop(condition2, body2, loop_vars=[j, i, result])
        return i + 1, result
    
    def loss_function(x):
        global data
        N = len(data)*positive_samples
        i = tf.constant(0)
        result = tf.constant(0, dtype=tf.float32)
        condition1 = lambda i, result: tf.less(i, N)
        i, result = tf.while_loop(condition1, body1, [i, result])
        return result
    

    从您的代码中不清楚在哪里使用x。但是,在这种情况下,操作的结果应该等于:

    result = tf.reduce_sum(tf.linalg.matmul(U, V, transpose_a=True))
    

    这也会更快。

    【讨论】:

      【解决方案2】:

      这里是另一个使用tf.while_loop 的 TensorFlow 嵌套循环示例。 在这一个中,张量 x 的第 i 个元素由张量 v 的第 i 个元素中给定的次数迭代连接。

      import tensorflow as tf
      x = tf.Variable([[1,1],[2,2],[3,3]])
      v = tf.constant([1,2,3])
      i = tf.constant(0)
      a_combined = tf.zeros([0, 2], dtype=tf.int32)
      
      
      def body(x,v,i,a_combined):
          x_slice = tf.slice(x,[i,0], [1, x.shape[1]])
          v_slice = tf.slice(v,[i],[1])
          j = tf.constant(0)
          b_combined = tf.zeros([0, 2], dtype=tf.int32)
          
          print("i: ", i)
          
          def body_supp(x_slice,v_slice,j, b_combined):
              
              print("j: ", j)
              
              j = tf.add(j,1)
              b_combined = tf.concat([b_combined,x_slice],0)
              return x_slice, v_slice, j, b_combined 
          
          while_condition_supp = lambda x_slice, v_slice, j, b_combined: tf.less(j, v_slice)
          
          x_slice, v_slice, j, b_combined = tf.while_loop(while_condition_supp, body_supp, [x_slice, v_slice, j, b_combined])
          
          i = tf.add(i,1)
      
          a_combined = tf.concat([a_combined,b_combined],0)
          return x, v, i, a_combined
      
      while_condition = lambda x, v, i, a_combined: i < v.shape[0]  
      
      x, v, i, a_combined = tf.while_loop(while_condition, body, [x, v, i, a_combined])
      
      a_combined 
      

      输出如下所示:

      <tf.Tensor: shape=(6, 2), dtype=int32, numpy=
      array([[1, 1],
             [2, 2],
             [2, 2],
             [3, 3],
             [3, 3],
             [3, 3]], dtype=int32)>

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2013-10-26
        • 2011-02-19
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多