【发布时间】:2020-05-29 11:04:46
【问题描述】:
有没有办法根据 XLA 编译函数中的随机数生成器对张量进行动态切片?例如:
@tf.function(experimental_compile=True)
def random_slice(input, max_slice_size):
offset = tf.squeeze(tf.random.uniform([1], minval=0, maxval=input.shape[0]-max_slice_size, dtype=tf.int32))
sz = tf.squeeze(tf.random.uniform([1], minval=1, maxval=max_slice_size, dtype=tf.int32))
indices = tf.range(offset, offset+sz) # Non-XLA-able due to non-static bounds
return tf.gather(input, indices)
x = tf.ones([50, 50])
y = random_slice(x, 4)
此代码无法编译,因为 XLA 要求 tf.range 的参数在编译时已知。有推荐的解决方法吗?
【问题讨论】:
标签: python tensorflow tensorflow-xla