【Posted】: 2020-09-24 12:37:42
【Question】:
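I'm trying to apply some transformations to a dataset with tf.data.Dataset.
I noticed that the transformations are executed in every epoch. Is it possible to make the map function run only in the first epoch?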
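The behavior described above can be reproduced with a minimal sketch, assuming TensorFlow 2.x. The 4-element in-memory dataset and the `_expensive`/`calls` counter are hypothetical illustrations. `tf.py_function` is used so the Python counter increments once per element, rather than only once when the map function is traced.

```python
import tensorflow as tf

calls = []  # hypothetical counter: one entry per element processed by map

def _expensive(x):
    # Runs eagerly inside tf.py_function, so this executes for every element
    calls.append(1)
    return x * 2.0

def transformation(x):
    return tf.py_function(_expensive, [x], tf.float32)

ds = tf.data.Dataset.from_tensor_slices(tf.constant([1.0, 2.0, 3.0, 4.0]))
ds = ds.map(transformation)

counts_per_epoch = []
for epoch in range(3):
    for _ in ds:  # each new pass re-runs the whole pipeline, map included
        pass
    counts_per_epoch.append(len(calls))

print(counts_per_epoch)  # [4, 8, 12]
```

The cumulative count grows by 4 every epoch, i.e. the mapped function runs again on each pass over the dataset.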
【Question discussion】:
Tags: python tensorflow machine-learning keras tensorflow2.0
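You can use two different datasets, which is easy to do in a custom training loop. First, build one pipeline that applies the transformation and one that does not: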
def transformation(inputs, labels):
    tf.print('With transformation!')
    return inputs, labels

def no_transformation(inputs, labels):
    tf.print('No transformation!')
    return inputs, labels

data_with_transform = data.take(4).map(transformation).batch(4)
data_no_transform = data.take(4).map(no_transformation).batch(4)
Then, inside the epoch loop, pick the dataset according to the current epoch:
if epoch < 1:
    ds = data_with_transform
else:
    ds = data_no_transform
for X_train, y_train in ds:
    train_step(X_train, y_train)
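The epoch-based switch can be checked without downloading anything. The sketch below is a stand-in under stated assumptions: it replaces the tfds iris data with a hypothetical 4-example in-memory dataset, and this `transformation` doubles the features so its effect shows up in the values instead of through `tf.print`. The helper `dataset_for_epoch` is also hypothetical.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the iris data: 4 examples, 2 features each
features = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)
labels = np.array([0, 1, 0, 1], dtype=np.int64)
data = tf.data.Dataset.from_tensor_slices((features, labels))

def transformation(inputs, labels):
    # Hypothetical transformation: scale the features by 2
    return inputs * 2.0, labels

# The untransformed pipeline needs no identity map at all
data_with_transform = data.map(transformation).batch(4)
data_no_transform = data.batch(4)

def dataset_for_epoch(epoch):
    # Use the transformed pipeline only in the first epoch (epoch 0)
    return data_with_transform if epoch < 1 else data_no_transform

first_values = []
for epoch in range(3):
    for x_batch, y_batch in dataset_for_epoch(epoch):
        first_values.append(float(x_batch[0, 0]))

print(first_values)  # [2.0, 1.0, 1.0]
```

Dropping the identity `no_transformation` map is a design choice: in the answer it exists only so `tf.print` can show which pipeline ran, and without it the untransformed data is the same.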
Full example:
import tensorflow_datasets as tfds
import tensorflow as tf

data, info = tfds.load('iris', split='train', as_supervised=True,
                       with_info=True)

def transformation(inputs, labels):
    tf.print('With transformation!')
    return inputs, labels

def no_transformation(inputs, labels):
    tf.print('No transformation!')
    return inputs, labels

data_with_transform = data.take(4).map(transformation).batch(4)
data_no_transform = data.take(4).map(no_transformation).batch(4)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(info.features['label'].num_classes)
])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs)
        loss = loss_object(labels, logits)
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)

def main(epochs=5):
    for epoch in range(epochs):
        train_loss.reset_states()
        train_acc.reset_states()
        if epoch < 1:
            ds = data_with_transform
        else:
            ds = data_no_transform
        for X_train, y_train in ds:
            train_step(X_train, y_train)

if __name__ == '__main__':
    main()
Output:

With transformation!
With transformation!
With transformation!
With transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
No transformation!
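If the goal is instead to run the map only once and reuse its results in later epochs (rather than skipping the transformation afterwards), `tf.data.Dataset.cache()` placed after `map` does that: the first full pass fills the cache, and later passes replay the stored elements without calling the mapped function again. A minimal sketch, reusing a hypothetical `tf.py_function` call counter as a probe. Note that caching after a random augmentation would freeze the first epoch's random results, which is one reason to prefer the answer's dataset switch in that case.

```python
import tensorflow as tf

calls = []  # hypothetical counter: one entry per call of the mapped function

def _expensive(x):
    calls.append(1)  # executes eagerly for every element that reaches map
    return x * 2.0

def transformation(x):
    return tf.py_function(_expensive, [x], tf.float32)

ds = tf.data.Dataset.from_tensor_slices(tf.constant([1.0, 2.0, 3.0, 4.0]))
# cache() after map: the first pass computes and stores results, later passes replay them
ds = ds.map(transformation).cache()

counts_per_epoch = []
values_per_epoch = []
for epoch in range(3):
    values_per_epoch.append([float(v) for v in ds])
    counts_per_epoch.append(len(calls))

print(counts_per_epoch)  # [4, 4, 4]
```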
【Discussion】: