【问题标题】:How to access embedding layer's variables in tensorflow?如何在张量流中访问嵌入层的变量?
【发布时间】:2021-02-21 15:58:34
【问题描述】:

假设我有这样的嵌入层e

import tensorflow as tf
e = tf.keras.layers.Embedding(5,3)

如何打印它的 numpy 值?

【问题讨论】:

    标签: tensorflow tensorflow2.0 embedding tensorflow2.x


    【解决方案1】:

    您需要构建嵌入层才能访问嵌入矩阵:

    import tensorflow as tf
    
    emb = tf.keras.layers.Embedding(5, 3)
    emb.build(())
    emb.trainable_variables[0].numpy()
    # array([[-0.00595363,  0.03049802,  0.01821234],
    #        [ 0.01515153, -0.01006874,  0.02568189],
    #        [-0.01845006,  0.02135053, -0.03916124],
    #        [-0.00822829,  0.00922295,  0.00091892],
    #        [-0.00727308, -0.03537174, -0.01419405]], dtype=float32)
    

    【讨论】:

      【解决方案2】:

      感谢@vald 的回答。我认为e.embeddings 更 Pythonic 并且可能更高效。

      import tensorflow as tf
      e = tf.keras.layers.Embedding(5,3)
      
      e.build(()) # You should build it before using.
      
      print(e.embeddings)
      

      >>>
      <tf.Variable 'embeddings:0' shape=(5, 3) dtype=float32, numpy=
      array([[ 0.02099125,  0.01865673,  0.03652272],
             [ 0.02714007, -0.00316695, -0.00252246],
             [-0.02411103,  0.02043924, -0.01297874],
             [ 0.00766286, -0.03511617,  0.03460207],
             [ 0.00256425, -0.03659264, -0.01796588]], dtype=float32)>
      

      【讨论】:

        猜你喜欢
        • 2019-10-17
        • 2023-03-04
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2019-05-07
        • 1970-01-01
        • 2018-08-14
        • 2018-07-31
        相关资源
        最近更新 更多