【问题标题】:MNIST Tensorflow: How to manipulate a tensor of the form [i] to a tensor of a form [...0,0,0,1,0,0...] where 1 is at ith position?MNIST Tensorflow:如何将 [i] 形式的张量操作为 [...0,0,0,1,0,0...] 形式的张量,其中 1 在第 i 个位置?
【发布时间】:2016-10-19 09:12:46
【问题描述】:

我想转换形式为

的张量(称为 logits)
int32 - [batch_size]

到形式的张量(称为标签)

 [batch_size, 10]

例如对于 batch_size=3

logits=[1,6,9]
labels=[[0,1,0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,1,0,0,0],
        [0,0,0,0,0,0,0,0,0,1]]

出现这个问题是因为我想在 tensorflow mnist 示例中将成本函数更改为二次函数 (https://github.com/tensorflow/tensorflow/tree/r0.9/tensorflow/examples/tutorials/mnist) 我使用fully_connected_feed.py 和mnist.py。在 mnist.py 我想改变:

    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='xentropy')
    loss = tf.reduce_mean(cross_entropy, name='xentropy_mean') 

loss= tf.reduce_sum(tf.squared_difference(logits,labels))

但问题在于:

Logits tensor, float - [batch_size, 10];  
Labels tensor, int64 - [batch_size].

所以我需要“矢量化”标签!? 有谁知道如何做到这一点?

【问题讨论】:

    标签: tensorflow mnist


    【解决方案1】:

    标签“矢量化”称为one-hot encoding。

    您正在寻找tf.one_hot 函数。

    这个函数需要:

    1. 索引列表(您的logits 向量)
    2. depth 参数:这是 one-hot 向量的深度(one-hot 编码标签的长度)
    3. on_value & off_value,您可以根据需要进行更改(但默认值 1 和 0 是您要查找的值)。
    4. dtype 就是张量输出类型。

    因此,您可以使用以下代码对标签进行一次性编码:

    one_hot_labels = tf.one_hot(logits, 10, dtype=tf.uint8)
    

    one_hot_labels 是一个tf.Tensor 对象。

    如果你需要从 python 访问它的内容,记得要 eval(或者运行它)。

    这是一个玩具示例:

    import tensorflow as tf.
    tf.InteractiveSession()
    logits=[1,6,9]
    one_hot_labels = tf.one_hot(logits, 10, dtype=tf.uint8)
    print(one_hot_labels.eval())
    

    输出:

    [[0 1 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 1 0 0 0]
     [0 0 0 0 0 0 0 0 0 1]]
    

    【讨论】:

    • 谢谢你 nessuno,这正是我想要的。但是如果我写 vectorized_labels= tf.one_hot(labels, 10) 我总是得到 TypeError: one_hot() 需要至少 4 个参数(给定 2 个)??
    猜你喜欢
    • 2017-06-06
    • 2017-08-17
    • 2017-04-01
    • 2018-05-15
    • 1970-01-01
    • 2019-03-18
    • 2019-09-28
    • 2017-06-14
    • 2020-09-01
    相关资源
    最近更新 更多