【问题标题】:index extraction in tensorflow-jstensorflow-js 中的索引提取
【发布时间】:2018-05-10 13:59:41
【问题描述】:

我正在使用tensorflow-js 在浏览器上进行一些图像处理。

我有一个bool 类型的张量,我想从中提取true/1 值的索引。

有没有一种方法可以在不通过Tensor.data() 将整个张量作为数组获取的情况下做到这一点?

目前我正在做这样的事情:

let array = await tensor.data()
for(let i = 0; i <array.length;i++) {
 if (array[i]){
 //do Something
 };
};

但是在 600 毫秒加上 CPU 上的大张量上花费的时间太长。

【问题讨论】:

    标签: javascript tensorflow tensorflow.js


    【解决方案1】:

    这仍然使用tensor.data(),但它为您提供了一个张量,其中包含您想要的所有索引。 (+1 因为否则无论第一个值是否为 0,乘​​法总是导致 0)

    tf.tidy(() => {
    
      const boolTensor = tf.randomUniform([10], 0, 2, "bool");
      boolTensor.print();
    
      const indices = tf.range(1, boolTensor.shape[0] + 1);
      //starting at 1 to prevent 0x0
      indices.print();
    
      const ones = boolTensor.cast("float32").mul(indices);
      ones.print();
    
    
      ind = Array.from(ones.dataSync()).filter(i => i > 0).map(i => i - 1);
      //map to go back to 0-indexed
      console.log(ind);
    });
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.10.0">
    </script>

    【讨论】:

    • 好主意,但 ones.dataSync() 仍然解析整个张量,所以延迟仍然存在
    猜你喜欢
    • 2018-02-19
    • 1970-01-01
    • 2020-03-17
    • 2017-01-15
    • 1970-01-01
    • 1970-01-01
    • 2019-12-18
    • 2021-10-03
    • 1970-01-01
    相关资源
    最近更新 更多