【发布时间】:2019-12-26 04:34:32
【问题描述】:
简而言之,我想将标量w_ij 组装成对称矩阵W,如下所示:
W[i, j] = w_ij
W[j, i] = w_ij
在为此苦苦挣扎并在互联网和 SE 上查找材料之后,我找不到从 w_ij 构造矩阵 W 的方法,我不知道如何做到这一点。任何帮助将不胜感激。
详细说明和 MWE 如下。
问题
在我的研究中,我正在尝试训练一个将source 映射到标量w_ij 的网络。其中输出w_ij 旨在表示对称矩阵W 中的元素i,j。
因此,训练损失是通过将许多相同网络的输出(具有共享权重,但每个网络看到不同的输入,并驱动矩阵中的不同元素)组合成矩阵形式,如下所示:
W[i, j] = w_ij
W[j, i] = w_ij
然后以损失形式训练这些多个网络:
L2_loss(f(W) - f(True_W))
(其中f() 是一个运行f(Y) = d' Y d 二次形式的函数——矩阵乘以左右固定向量的乘积。)
我需要通过这个损失对每个网络运行梯度。
我尝试了什么
-
tensorflow不支持简单的张量切片,即,不支持
W[i, j] = w_ij。 使用
tf.scatter_update()不允许通过它运行渐变。-
最后,虽然我已经接近解决方案,但我尝试将
tf.Variable用于矩阵W,如下所示:W_flat = tf.Variable(initial_value=[0] * (2 * 2), dtype='float32')然后通过切片
W_falt[0].assign(w_ij)分配给这个W_flat,但似乎我对这个变量的分配不起作用(参见MWE)。
MWE
Bellow 是一个简短的 MWE,其中W 是一个对角线为零的 2×2 对称矩阵,所以我只有一个网络必须驱动的独立元素(所以这里我只有一个网络),即,我想让W 拥有这些值
W = [[0, w_ij] [w_ij, 0]]
所以我尝试更新:
W_flat[1].assign(w_ij)
W_flat[2].assign(w_ij)
并将其转回矩阵:
W = tf.reshape(W_flat, (2, 2))
最终这个更新没有通过,print 的输出显示W 仍然全为零。
代码
import tensorflow as tf
def train():
with tf.Graph().as_default():
with tf.device('/cpu'):
source = tf.placeholder(tf.float32, shape=(2, 3))
is_training = tf.placeholder(tf.bool, shape=())
w_ij = tf.reduce_sum(source)
W_flat = tf.Variable(initial_value=[0] * (2 * 2), dtype='float32')
W_flat[1].assign(w_ij)
W_flat[2].assign(w_ij)
tf.assign(W_flat[1], w_ij)
tf.assign(W_flat[2], w_ij)
W = tf.reshape(W_flat, (2, 2))
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init, {is_training: True})
ops = {'W_flat': W_flat,
'source' : source,
'w_ij' : w_ij,
'W' : W}
for epoch in range(2):
feed_dict = {ops['source']: [[1,1,1], [7,7,7]]}
res_W_flat, res_wij, res_W = sess.run([ops['W_flat'], ops['w_ij'], ops['W']], feed_dict=feed_dict)
print("epoch:" , epoch)
print("W_flat:", res_W_flat)
print("wij:", res_wij)
print("W:", res_W)
if __name__ == "__main__" :
train()
print() 输出
epoch: 0
W_flat: [0. 0. 0. 0.]
wij: 24.0
W: [[0. 0.]
[0. 0.]]
epoch: 1
W_flat: [0. 0. 0. 0.]
wij: 24.0
W: [[0. 0.]
[0. 0.]]
所以W 和W_flat 不会被w_ij 的值更新,其值为24,但W 和W_flat 保持为零。
【问题讨论】:
-
请编辑您的问题,避免同时提出多个不同的问题。
标签: python-3.x tensorflow matrix deep-learning loss-function