【发布时间】:2020-07-31 04:10:52
【问题描述】:
我正在为 Tensorflow 2 中的数据集编写映射函数。 数据集包含几张图像和相应的标签,更具体地说,标签只有三个可能的值,13、17 和 34。 映射函数应该获取标签并将它们转换为分类标签。
可能有更好的方法来实现这个功能(请随意提出建议),但这是我的实现:
def map_labels(dataset):
def convert_labels_to_categorical(image, labels):
labels = [1.0, 0., 0.] if labels == 13 else [0., 1.0, 0.] if labels == 17 else [0., 0., 1.0]
return image, labels
categorical_dataset = dataset.map(convert_labels_to_categorical)
return categorical_dataset
主要问题是我收到以下错误:
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
我真的不知道这个错误是什么意思,而且互联网上没有那么多其他来源记录同样的错误。有什么想法吗?
编辑(新的非工作实现):
def map_labels(dataset):
def convert_labels_to_categorical(image, labels):
labels = tf.Variable([1.0, 0., 0.]) if tf.reduce_any(tf.math.equal(labels, tf.constant(0,dtype=tf.int64))) \
else tf.Variable([0., 1.0, 0.]) if tf.reduce_any(tf.math.equal(labels, tf.constant(90,dtype=tf.int64))) \
else tf.Variable([0., 0., 1.0])
return image, labels
categorical_dataset = dataset.map(convert_labels_to_categorical)
return categorical_dataset
【问题讨论】:
标签: tensorflow2.0 tensorflow-datasets tensorflow2.x