【问题标题】:Tensor Slicing on TPUTPU 上的张量切片
【发布时间】:2021-04-16 12:04:49
【问题描述】:

我想在 TPU(Google Cloud TPU)上运行一个模型。我试图减少到最低限度。我省略了模型代码,因为它不相关,我的问题发生得更早。

这是主要的python文件:

import tensorflow as tf
import os
from Model import Model
from DataGeneratorTPU import load_dataset

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=os.environ['TPU_NAME'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = Model(32768,7)
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    dg = load_dataset('gs://bucket/data.tf','gs://bucket/annotations.tf',32768).batch(32,drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
    model.fit(dg,epochs=10,verbose=1)
    model.save('test')

这是 DataGeneratorTPU.py:

import tensorflow as tf

def slice(i,data,annotations,lead_length):
    X = data[i:i+lead_length,:]
    y = annotations[i+lead_length,0,:]
    print(X.shape,y.shape) #OUTPUT2
    return X,y

def load_dataset(filename_data,filename_annotations,lead_length,step_size=1):
    data = tf.io.parse_tensor(tf.io.read_file(filename_data), tf.float32)
    annotations = tf.io.parse_tensor(tf.io.read_file(filename_annotations), tf.int32)
    print(data.shape,annotations.shape) #OUTPUT1
    rangeds = tf.data.Dataset.range(0,data.shape[0]-lead_length,step_size)
    def slice_(i):
        return slice(i,data,annotations,tf.constant(lead_length,dtype=tf.int64))

    return rangeds.map(slice_, tf.data.experimental.AUTOTUNE)

您可能注意到我用 OUTPUT1 和 OUTPUT2 标记了两个 print 语句,所以我可以告诉您输出是什么:

OUTPUT1 是(432001, 7) (432001, 7, 3)

OUTPUT2 是(None, 7) (3, )

但是,我认为 OUTPUT2 应该是 (32768, 7) (3, )

事实上,模型随后抱怨(仅来自一层的示例,还有更多,这是来自 conv1d 层):

  (0) Invalid argument: {{function_node __inference_train_function_33579}} Compilation failure: Dynamic Spatial Convolution is not supported: lhs shape is f32[4,1,<=32774,7]     
     [[{{node Model/conv1d/conv1d}}]]
        TPU compilation failed
         [[tpu_compile_succeeded_assert/_12623170171032432447/_5]] 
         [[tpu_compile_succeeded_assert/_12623170171032432447/_5/_303]]

抱怨我们正在谈论的维度(我在映射函数中打印)是动态的,而不是固定在 32768。但是它应该是静态的,因为我对切片使用恒定宽度32768,我什至确保范围不会查看可能出错的最后 32768 个元素。似乎只是能够估计这个小于 32774,我不知道这 6 个额外的元素是从哪里来的……

我做错了什么?我怎样才能得到这个静态?

【问题讨论】:

  • 为什么你认为这个形状应该是32768? Colab/Kaggle TPU 是免费使用的,是否可以先检查它们,如果可能提供可重现的代码?
  • 因为我在切片(第一维)i:i+32768(参数是常数)。但没关系,Lescurel 的回答似乎很到位。

标签: tensorflow tf.keras tpu


【解决方案1】:

似乎有一种情况,使用tf.strided_slice(这是__getitem__方法调用的函数)会丢失传递给它的张量的形状信息。我猜这是因为切片非常灵活,并且允许传递“不可能”的切片大小(例如,end 索引大于数组大小)。这样做会导致最终数据集中的元素形状可变。该函数无法确保数组的最终形状,因此默认为None

您的情况很简单,可以通过调用tf.slice 来代替,通过询问切片的大小来保留形状信息。

将您的 slice 函数替换为以下内容:

def slice(i, data, annotations, lead_length):
    X = tf.slice(data, [i,0], [lead_length, tf.shape(data)[1]])
    # I also used slice for y for the sake of it, but its probably more readable to use
    # y = annotations[i+lead_length,0,:]
    y = tf.squeeze(tf.slice(annotations, [i+lead_length,0,0], [1, 1, tf.shape(annotations)[2]]))
    return X, y

查看数据集的形状可以得到:

>>> ds = rangeds.map(slice_, tf.data.experimental.AUTOTUNE)
>>> ds
<ParallelMapDataset shapes: ((32768, 7), (3,)), types: (tf.float32, tf.float32)>

另一种可能性是在你的张量上调用set_shape如果你知道你可以保证形状是正确的(即i+lead_length永远不会比你的第一个尺寸大方面)。如果不能,将导致难以调试运行时错误。

def slice(i,data,annotations,lead_length):
    X = data[i:i+lead_length,:]
    y = annotations[i+lead_length,0,:]
    X.set_shape((lead_length,7))
    return X,y

我认为在你的情况下,让tf.slice 做这项工作更干净。

【讨论】:

  • 我可以确认我可以使用tf.slice 解决它,这实际上是问题所在。我还没有尝试过set_shape 变体,但我想它可以工作(事实上我可以保证这一点)。这不应该是一个错误吗?在tf.strided_slice 中或者应该在某处对此发出警告,而不是在tf.slice 的文档中建议使用__getitem__ 而不是tf.slice?赞成和接受。非常感谢!
  • 看来我必须等待 38 分钟才能获得赏金。我希望我不会忘记这一点。
  • “这不应该是一个错误吗?”这是 python 切片的默认行为,即list(range(10))[9:12] 是有效的,并且只会返回列表的最后一个元素。
  • 那么至少文档不建议使用它来代替tf.slice。它特别提到 getitem 作为“更pythonic”的替代品。事实上它不是(在 TPU 上),因为它可能会失去张量形状。我认为,有关此的警告将在文档中按顺序排列。
猜你喜欢
  • 2016-07-04
  • 2017-12-01
  • 1970-01-01
  • 2018-03-15
  • 1970-01-01
  • 1970-01-01
  • 2020-03-11
  • 1970-01-01
  • 2019-10-21
相关资源
最近更新 更多