我也是Estimator API 的新手,但我从 S.O. 学到了很多东西。社区,并将尝试回答您的问题。
首先,我想向您指出这个colab。这是目前我用于Estimators 的约定。
您是正确的,因为TRAIN 和EVAL 模式的input_fn 都是(features, labels) 形式的元组。
那么让我们来解决您的第一个问题:
如果我的初始特征已经是一个张量,比如“大小”,它是一个由 3 个双精度值组成的数组,该怎么办?如何通过 input_fn 输入?
这需要我稍微回溯一下,您的输入:
100 个示例的批次和一个名为“权重”的特征我将在特征字典中创建一个形状为 (100,1) 的张量的条目,
为了确保我理解正确,你是说,如果不是Tensor 形状为[100, 1],你有一个Tensor 或[100, <size>],在这种情况下是3 个双打,所以[100, 3] ?
如果是这样的话,那完全没有问题。在链接的colab 中,输入的单个示例具有形状[20, 7]。所以Tensor 的[3] 是直截了当的。
简短的回答是,无论您指定为元组的features 部分,都将传递给model_fn。所以你想传递一个Tensor 的[batch_size, size] 你返回一个([batch_size, size], labels) 的元组。然而,正如另一位用户在 S.O. 上向我指出的那样。我会给你同样的建议 - 使用字典,例如
my_data = # Tensor with shape [batch_size, size]
features = {'my_data': my_data}
...
return (features, labels)
作为参考,让我们以colab 的input_fn 为例,我按照上面的建议做同样的事情:
def input_fn(filenames:list, params):
mode = params['mode'] if 'mode' in params else 'train'
batch_size = params['batch_size']
shuffle(filenames) # <--- far more efficient than tf dataset shuffle
dataset = tf.data.TFRecordDataset(filenames)
# using fio's SCHEMA fill the TF Feature placeholders with values
dataset = dataset.map(lambda record: fio.from_record(record))
# using fio's SCHEMA restructure and unwrap (if possible) features (because tf records require wrapping everything into a list)
dataset = dataset.map(lambda context, features: fio.reconstitute((context, features)))
# dataset should be a tuple of (features, labels)
dataset = dataset.map(lambda context, features: (
{"input_tensors": features[I_FEATURE]}, # features <--- wrapping it in a dictionary
features[O_FEATURE] # labels
)
)
为简单起见,我假设您使用的是tf.data.Dataset。如果您的数据不是存储为 TF Records,则需要替换第 1 行。
1. dataset = tf.data.TFRecordDataset(filenames)
2. dataset = dataset.map(lambda record: fio.from_record(record))
3. dataset = dataset.map(lambda context, features: fio.reconstitute((context, features)))
无论您如何构建数据集,无论是 FeatureColumn、from_tensor_slices 等,并删除第 2 行和第 3 行,因为您不需要从 TF Records 恢复您的 (Sequence)Example。
现在让我们解决您的第二个问题,可变长度数组。
和上面一样!将其包装在字典中并返回。
这是正确的,除了从 TF Records 恢复您的 SequenceExample 之外,您将需要 VarLenFeature